1 //===-- AArch64ISelLowering.cpp - AArch64 DAG Lowering Implementation ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the AArch64TargetLowering class.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "AArch64ISelLowering.h"
14 #include "AArch64CallingConvention.h"
15 #include "AArch64ExpandImm.h"
16 #include "AArch64MachineFunctionInfo.h"
17 #include "AArch64PerfectShuffle.h"
18 #include "AArch64RegisterInfo.h"
19 #include "AArch64Subtarget.h"
20 #include "MCTargetDesc/AArch64AddressingModes.h"
21 #include "Utils/AArch64BaseInfo.h"
22 #include "llvm/ADT/APFloat.h"
23 #include "llvm/ADT/APInt.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallSet.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/Statistic.h"
29 #include "llvm/ADT/StringRef.h"
30 #include "llvm/ADT/Twine.h"
31 #include "llvm/Analysis/LoopInfo.h"
32 #include "llvm/Analysis/MemoryLocation.h"
33 #include "llvm/Analysis/ObjCARCUtil.h"
34 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
35 #include "llvm/Analysis/TargetTransformInfo.h"
36 #include "llvm/Analysis/ValueTracking.h"
37 #include "llvm/Analysis/VectorUtils.h"
38 #include "llvm/CodeGen/Analysis.h"
39 #include "llvm/CodeGen/CallingConvLower.h"
40 #include "llvm/CodeGen/ComplexDeinterleavingPass.h"
41 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
42 #include "llvm/CodeGen/GlobalISel/Utils.h"
43 #include "llvm/CodeGen/ISDOpcodes.h"
44 #include "llvm/CodeGen/MachineBasicBlock.h"
45 #include "llvm/CodeGen/MachineFrameInfo.h"
46 #include "llvm/CodeGen/MachineFunction.h"
47 #include "llvm/CodeGen/MachineInstr.h"
48 #include "llvm/CodeGen/MachineInstrBuilder.h"
49 #include "llvm/CodeGen/MachineMemOperand.h"
50 #include "llvm/CodeGen/MachineRegisterInfo.h"
51 #include "llvm/CodeGen/RuntimeLibcallUtil.h"
52 #include "llvm/CodeGen/SelectionDAG.h"
53 #include "llvm/CodeGen/SelectionDAGNodes.h"
54 #include "llvm/CodeGen/TargetCallingConv.h"
55 #include "llvm/CodeGen/TargetInstrInfo.h"
56 #include "llvm/CodeGen/TargetOpcodes.h"
57 #include "llvm/CodeGen/ValueTypes.h"
58 #include "llvm/CodeGenTypes/MachineValueType.h"
59 #include "llvm/IR/Attributes.h"
60 #include "llvm/IR/Constants.h"
61 #include "llvm/IR/DataLayout.h"
62 #include "llvm/IR/DebugLoc.h"
63 #include "llvm/IR/DerivedTypes.h"
64 #include "llvm/IR/Function.h"
65 #include "llvm/IR/GetElementPtrTypeIterator.h"
66 #include "llvm/IR/GlobalValue.h"
67 #include "llvm/IR/IRBuilder.h"
68 #include "llvm/IR/Instruction.h"
69 #include "llvm/IR/Instructions.h"
70 #include "llvm/IR/IntrinsicInst.h"
71 #include "llvm/IR/Intrinsics.h"
72 #include "llvm/IR/IntrinsicsAArch64.h"
73 #include "llvm/IR/Module.h"
74 #include "llvm/IR/PatternMatch.h"
75 #include "llvm/IR/Type.h"
76 #include "llvm/IR/Use.h"
77 #include "llvm/IR/Value.h"
78 #include "llvm/MC/MCRegisterInfo.h"
79 #include "llvm/Support/AtomicOrdering.h"
80 #include "llvm/Support/Casting.h"
81 #include "llvm/Support/CodeGen.h"
82 #include "llvm/Support/CommandLine.h"
83 #include "llvm/Support/Debug.h"
84 #include "llvm/Support/ErrorHandling.h"
85 #include "llvm/Support/InstructionCost.h"
86 #include "llvm/Support/KnownBits.h"
87 #include "llvm/Support/MathExtras.h"
88 #include "llvm/Support/SipHash.h"
89 #include "llvm/Support/raw_ostream.h"
90 #include "llvm/Target/TargetMachine.h"
91 #include "llvm/Target/TargetOptions.h"
92 #include "llvm/TargetParser/Triple.h"
93 #include <algorithm>
94 #include <bitset>
95 #include <cassert>
96 #include <cctype>
97 #include <cstdint>
98 #include <cstdlib>
99 #include <iterator>
100 #include <limits>
101 #include <optional>
102 #include <tuple>
103 #include <utility>
104 #include <vector>
105
106 using namespace llvm;
107 using namespace llvm::PatternMatch;
108
109 #define DEBUG_TYPE "aarch64-lower"
110
111 STATISTIC(NumTailCalls, "Number of tail calls");
112 STATISTIC(NumShiftInserts, "Number of vector shift inserts");
113 STATISTIC(NumOptimizedImms, "Number of times immediates were optimized");
114
115 // FIXME: The necessary dtprel relocations don't seem to be supported
116 // well in the GNU bfd and gold linkers at the moment. Therefore, by
117 // default, for now, fall back to GeneralDynamic code generation.
118 cl::opt<bool> EnableAArch64ELFLocalDynamicTLSGeneration(
119 "aarch64-elf-ldtls-generation", cl::Hidden,
120 cl::desc("Allow AArch64 Local Dynamic TLS code generation"),
121 cl::init(false));
122
123 static cl::opt<bool>
124 EnableOptimizeLogicalImm("aarch64-enable-logical-imm", cl::Hidden,
125 cl::desc("Enable AArch64 logical imm instruction "
126 "optimization"),
127 cl::init(true));
128
129 // Temporary option added for the purpose of testing functionality added
130 // to DAGCombiner.cpp in D92230. It is expected that this can be removed
131 // in future when both implementations will be based off MGATHER rather
132 // than the GLD1 nodes added for the SVE gather load intrinsics.
133 static cl::opt<bool>
134 EnableCombineMGatherIntrinsics("aarch64-enable-mgather-combine", cl::Hidden,
135 cl::desc("Combine extends of AArch64 masked "
136 "gather intrinsics"),
137 cl::init(true));
138
139 static cl::opt<bool> EnableExtToTBL("aarch64-enable-ext-to-tbl", cl::Hidden,
140 cl::desc("Combine ext and trunc to TBL"),
141 cl::init(true));
142
143 // All of the XOR, OR and CMP use ALU ports, and data dependency will become the
144 // bottleneck after this transform on high end CPU. So this max leaf node
145 // limitation is guard cmp+ccmp will be profitable.
146 static cl::opt<unsigned> MaxXors("aarch64-max-xors", cl::init(16), cl::Hidden,
147 cl::desc("Maximum of xors"));
148
149 // By turning this on, we will not fallback to DAG ISel when encountering
150 // scalable vector types for all instruction, even if SVE is not yet supported
151 // with some instructions.
152 // See [AArch64TargetLowering::fallbackToDAGISel] for implementation details.
153 cl::opt<bool> EnableSVEGISel(
154 "aarch64-enable-gisel-sve", cl::Hidden,
155 cl::desc("Enable / disable SVE scalable vectors in Global ISel"),
156 cl::init(false));
157
158 /// Value type used for condition codes.
159 static const MVT MVT_CC = MVT::i32;
160
161 static const MCPhysReg GPRArgRegs[] = {AArch64::X0, AArch64::X1, AArch64::X2,
162 AArch64::X3, AArch64::X4, AArch64::X5,
163 AArch64::X6, AArch64::X7};
164 static const MCPhysReg FPRArgRegs[] = {AArch64::Q0, AArch64::Q1, AArch64::Q2,
165 AArch64::Q3, AArch64::Q4, AArch64::Q5,
166 AArch64::Q6, AArch64::Q7};
167
getGPRArgRegs()168 ArrayRef<MCPhysReg> llvm::AArch64::getGPRArgRegs() { return GPRArgRegs; }
169
getFPRArgRegs()170 ArrayRef<MCPhysReg> llvm::AArch64::getFPRArgRegs() { return FPRArgRegs; }
171
getPackedSVEVectorVT(EVT VT)172 static inline EVT getPackedSVEVectorVT(EVT VT) {
173 switch (VT.getSimpleVT().SimpleTy) {
174 default:
175 llvm_unreachable("unexpected element type for vector");
176 case MVT::i8:
177 return MVT::nxv16i8;
178 case MVT::i16:
179 return MVT::nxv8i16;
180 case MVT::i32:
181 return MVT::nxv4i32;
182 case MVT::i64:
183 return MVT::nxv2i64;
184 case MVT::f16:
185 return MVT::nxv8f16;
186 case MVT::f32:
187 return MVT::nxv4f32;
188 case MVT::f64:
189 return MVT::nxv2f64;
190 case MVT::bf16:
191 return MVT::nxv8bf16;
192 }
193 }
194
195 // NOTE: Currently there's only a need to return integer vector types. If this
196 // changes then just add an extra "type" parameter.
getPackedSVEVectorVT(ElementCount EC)197 static inline EVT getPackedSVEVectorVT(ElementCount EC) {
198 switch (EC.getKnownMinValue()) {
199 default:
200 llvm_unreachable("unexpected element count for vector");
201 case 16:
202 return MVT::nxv16i8;
203 case 8:
204 return MVT::nxv8i16;
205 case 4:
206 return MVT::nxv4i32;
207 case 2:
208 return MVT::nxv2i64;
209 }
210 }
211
getPromotedVTForPredicate(EVT VT)212 static inline EVT getPromotedVTForPredicate(EVT VT) {
213 assert(VT.isScalableVector() && (VT.getVectorElementType() == MVT::i1) &&
214 "Expected scalable predicate vector type!");
215 switch (VT.getVectorMinNumElements()) {
216 default:
217 llvm_unreachable("unexpected element count for vector");
218 case 2:
219 return MVT::nxv2i64;
220 case 4:
221 return MVT::nxv4i32;
222 case 8:
223 return MVT::nxv8i16;
224 case 16:
225 return MVT::nxv16i8;
226 }
227 }
228
229 /// Returns true if VT's elements occupy the lowest bit positions of its
230 /// associated register class without any intervening space.
231 ///
232 /// For example, nxv2f16, nxv4f16 and nxv8f16 are legal types that belong to the
233 /// same register class, but only nxv8f16 can be treated as a packed vector.
isPackedVectorType(EVT VT,SelectionDAG & DAG)234 static inline bool isPackedVectorType(EVT VT, SelectionDAG &DAG) {
235 assert(VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
236 "Expected legal vector type!");
237 return VT.isFixedLengthVector() ||
238 VT.getSizeInBits().getKnownMinValue() == AArch64::SVEBitsPerBlock;
239 }
240
241 // Returns true for ####_MERGE_PASSTHRU opcodes, whose operands have a leading
242 // predicate and end with a passthru value matching the result type.
isMergePassthruOpcode(unsigned Opc)243 static bool isMergePassthruOpcode(unsigned Opc) {
244 switch (Opc) {
245 default:
246 return false;
247 case AArch64ISD::BITREVERSE_MERGE_PASSTHRU:
248 case AArch64ISD::BSWAP_MERGE_PASSTHRU:
249 case AArch64ISD::REVH_MERGE_PASSTHRU:
250 case AArch64ISD::REVW_MERGE_PASSTHRU:
251 case AArch64ISD::REVD_MERGE_PASSTHRU:
252 case AArch64ISD::CTLZ_MERGE_PASSTHRU:
253 case AArch64ISD::CTPOP_MERGE_PASSTHRU:
254 case AArch64ISD::DUP_MERGE_PASSTHRU:
255 case AArch64ISD::ABS_MERGE_PASSTHRU:
256 case AArch64ISD::NEG_MERGE_PASSTHRU:
257 case AArch64ISD::FNEG_MERGE_PASSTHRU:
258 case AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU:
259 case AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU:
260 case AArch64ISD::FCEIL_MERGE_PASSTHRU:
261 case AArch64ISD::FFLOOR_MERGE_PASSTHRU:
262 case AArch64ISD::FNEARBYINT_MERGE_PASSTHRU:
263 case AArch64ISD::FRINT_MERGE_PASSTHRU:
264 case AArch64ISD::FROUND_MERGE_PASSTHRU:
265 case AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU:
266 case AArch64ISD::FTRUNC_MERGE_PASSTHRU:
267 case AArch64ISD::FP_ROUND_MERGE_PASSTHRU:
268 case AArch64ISD::FP_EXTEND_MERGE_PASSTHRU:
269 case AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU:
270 case AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU:
271 case AArch64ISD::FCVTZU_MERGE_PASSTHRU:
272 case AArch64ISD::FCVTZS_MERGE_PASSTHRU:
273 case AArch64ISD::FSQRT_MERGE_PASSTHRU:
274 case AArch64ISD::FRECPX_MERGE_PASSTHRU:
275 case AArch64ISD::FABS_MERGE_PASSTHRU:
276 return true;
277 }
278 }
279
280 // Returns true if inactive lanes are known to be zeroed by construction.
isZeroingInactiveLanes(SDValue Op)281 static bool isZeroingInactiveLanes(SDValue Op) {
282 switch (Op.getOpcode()) {
283 default:
284 return false;
285 // We guarantee i1 splat_vectors to zero the other lanes
286 case ISD::SPLAT_VECTOR:
287 case AArch64ISD::PTRUE:
288 case AArch64ISD::SETCC_MERGE_ZERO:
289 return true;
290 case ISD::INTRINSIC_WO_CHAIN:
291 switch (Op.getConstantOperandVal(0)) {
292 default:
293 return false;
294 case Intrinsic::aarch64_sve_ptrue:
295 case Intrinsic::aarch64_sve_pnext:
296 case Intrinsic::aarch64_sve_cmpeq:
297 case Intrinsic::aarch64_sve_cmpne:
298 case Intrinsic::aarch64_sve_cmpge:
299 case Intrinsic::aarch64_sve_cmpgt:
300 case Intrinsic::aarch64_sve_cmphs:
301 case Intrinsic::aarch64_sve_cmphi:
302 case Intrinsic::aarch64_sve_cmpeq_wide:
303 case Intrinsic::aarch64_sve_cmpne_wide:
304 case Intrinsic::aarch64_sve_cmpge_wide:
305 case Intrinsic::aarch64_sve_cmpgt_wide:
306 case Intrinsic::aarch64_sve_cmplt_wide:
307 case Intrinsic::aarch64_sve_cmple_wide:
308 case Intrinsic::aarch64_sve_cmphs_wide:
309 case Intrinsic::aarch64_sve_cmphi_wide:
310 case Intrinsic::aarch64_sve_cmplo_wide:
311 case Intrinsic::aarch64_sve_cmpls_wide:
312 case Intrinsic::aarch64_sve_fcmpeq:
313 case Intrinsic::aarch64_sve_fcmpne:
314 case Intrinsic::aarch64_sve_fcmpge:
315 case Intrinsic::aarch64_sve_fcmpgt:
316 case Intrinsic::aarch64_sve_fcmpuo:
317 case Intrinsic::aarch64_sve_facgt:
318 case Intrinsic::aarch64_sve_facge:
319 case Intrinsic::aarch64_sve_whilege:
320 case Intrinsic::aarch64_sve_whilegt:
321 case Intrinsic::aarch64_sve_whilehi:
322 case Intrinsic::aarch64_sve_whilehs:
323 case Intrinsic::aarch64_sve_whilele:
324 case Intrinsic::aarch64_sve_whilelo:
325 case Intrinsic::aarch64_sve_whilels:
326 case Intrinsic::aarch64_sve_whilelt:
327 case Intrinsic::aarch64_sve_match:
328 case Intrinsic::aarch64_sve_nmatch:
329 case Intrinsic::aarch64_sve_whilege_x2:
330 case Intrinsic::aarch64_sve_whilegt_x2:
331 case Intrinsic::aarch64_sve_whilehi_x2:
332 case Intrinsic::aarch64_sve_whilehs_x2:
333 case Intrinsic::aarch64_sve_whilele_x2:
334 case Intrinsic::aarch64_sve_whilelo_x2:
335 case Intrinsic::aarch64_sve_whilels_x2:
336 case Intrinsic::aarch64_sve_whilelt_x2:
337 return true;
338 }
339 }
340 }
341
342 static std::tuple<SDValue, SDValue>
extractPtrauthBlendDiscriminators(SDValue Disc,SelectionDAG * DAG)343 extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
344 SDLoc DL(Disc);
345 SDValue AddrDisc;
346 SDValue ConstDisc;
347
348 // If this is a blend, remember the constant and address discriminators.
349 // Otherwise, it's either a constant discriminator, or a non-blended
350 // address discriminator.
351 if (Disc->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
352 Disc->getConstantOperandVal(0) == Intrinsic::ptrauth_blend) {
353 AddrDisc = Disc->getOperand(1);
354 ConstDisc = Disc->getOperand(2);
355 } else {
356 ConstDisc = Disc;
357 }
358
359 // If the constant discriminator (either the blend RHS, or the entire
360 // discriminator value) isn't a 16-bit constant, bail out, and let the
361 // discriminator be computed separately.
362 const auto *ConstDiscN = dyn_cast<ConstantSDNode>(ConstDisc);
363 if (!ConstDiscN || !isUInt<16>(ConstDiscN->getZExtValue()))
364 return std::make_tuple(DAG->getTargetConstant(0, DL, MVT::i64), Disc);
365
366 // If there's no address discriminator, use NoRegister, which we'll later
367 // replace with XZR, or directly use a Z variant of the inst. when available.
368 if (!AddrDisc)
369 AddrDisc = DAG->getRegister(AArch64::NoRegister, MVT::i64);
370
371 return std::make_tuple(
372 DAG->getTargetConstant(ConstDiscN->getZExtValue(), DL, MVT::i64),
373 AddrDisc);
374 }
375
AArch64TargetLowering(const TargetMachine & TM,const AArch64Subtarget & STI)376 AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
377 const AArch64Subtarget &STI)
378 : TargetLowering(TM), Subtarget(&STI) {
379 // AArch64 doesn't have comparisons which set GPRs or setcc instructions, so
380 // we have to make something up. Arbitrarily, choose ZeroOrOne.
381 setBooleanContents(ZeroOrOneBooleanContent);
382 // When comparing vectors the result sets the different elements in the
383 // vector to all-one or all-zero.
384 setBooleanVectorContents(ZeroOrNegativeOneBooleanContent);
385
386 // Set up the register classes.
387 addRegisterClass(MVT::i32, &AArch64::GPR32allRegClass);
388 addRegisterClass(MVT::i64, &AArch64::GPR64allRegClass);
389
390 if (Subtarget->hasLS64()) {
391 addRegisterClass(MVT::i64x8, &AArch64::GPR64x8ClassRegClass);
392 setOperationAction(ISD::LOAD, MVT::i64x8, Custom);
393 setOperationAction(ISD::STORE, MVT::i64x8, Custom);
394 }
395
396 if (Subtarget->hasFPARMv8()) {
397 addRegisterClass(MVT::f16, &AArch64::FPR16RegClass);
398 addRegisterClass(MVT::bf16, &AArch64::FPR16RegClass);
399 addRegisterClass(MVT::f32, &AArch64::FPR32RegClass);
400 addRegisterClass(MVT::f64, &AArch64::FPR64RegClass);
401 addRegisterClass(MVT::f128, &AArch64::FPR128RegClass);
402 }
403
404 if (Subtarget->hasNEON()) {
405 addRegisterClass(MVT::v16i8, &AArch64::FPR8RegClass);
406 addRegisterClass(MVT::v8i16, &AArch64::FPR16RegClass);
407
408 addDRType(MVT::v2f32);
409 addDRType(MVT::v8i8);
410 addDRType(MVT::v4i16);
411 addDRType(MVT::v2i32);
412 addDRType(MVT::v1i64);
413 addDRType(MVT::v1f64);
414 addDRType(MVT::v4f16);
415 addDRType(MVT::v4bf16);
416
417 addQRType(MVT::v4f32);
418 addQRType(MVT::v2f64);
419 addQRType(MVT::v16i8);
420 addQRType(MVT::v8i16);
421 addQRType(MVT::v4i32);
422 addQRType(MVT::v2i64);
423 addQRType(MVT::v8f16);
424 addQRType(MVT::v8bf16);
425 }
426
427 if (Subtarget->isSVEorStreamingSVEAvailable()) {
428 // Add legal sve predicate types
429 addRegisterClass(MVT::nxv1i1, &AArch64::PPRRegClass);
430 addRegisterClass(MVT::nxv2i1, &AArch64::PPRRegClass);
431 addRegisterClass(MVT::nxv4i1, &AArch64::PPRRegClass);
432 addRegisterClass(MVT::nxv8i1, &AArch64::PPRRegClass);
433 addRegisterClass(MVT::nxv16i1, &AArch64::PPRRegClass);
434
435 // Add legal sve data types
436 addRegisterClass(MVT::nxv16i8, &AArch64::ZPRRegClass);
437 addRegisterClass(MVT::nxv8i16, &AArch64::ZPRRegClass);
438 addRegisterClass(MVT::nxv4i32, &AArch64::ZPRRegClass);
439 addRegisterClass(MVT::nxv2i64, &AArch64::ZPRRegClass);
440
441 addRegisterClass(MVT::nxv2f16, &AArch64::ZPRRegClass);
442 addRegisterClass(MVT::nxv4f16, &AArch64::ZPRRegClass);
443 addRegisterClass(MVT::nxv8f16, &AArch64::ZPRRegClass);
444 addRegisterClass(MVT::nxv2f32, &AArch64::ZPRRegClass);
445 addRegisterClass(MVT::nxv4f32, &AArch64::ZPRRegClass);
446 addRegisterClass(MVT::nxv2f64, &AArch64::ZPRRegClass);
447
448 addRegisterClass(MVT::nxv2bf16, &AArch64::ZPRRegClass);
449 addRegisterClass(MVT::nxv4bf16, &AArch64::ZPRRegClass);
450 addRegisterClass(MVT::nxv8bf16, &AArch64::ZPRRegClass);
451
452 if (Subtarget->useSVEForFixedLengthVectors()) {
453 for (MVT VT : MVT::integer_fixedlen_vector_valuetypes())
454 if (useSVEForFixedLengthVectorVT(VT))
455 addRegisterClass(VT, &AArch64::ZPRRegClass);
456
457 for (MVT VT : MVT::fp_fixedlen_vector_valuetypes())
458 if (useSVEForFixedLengthVectorVT(VT))
459 addRegisterClass(VT, &AArch64::ZPRRegClass);
460 }
461 }
462
463 if (Subtarget->hasSVE2p1() || Subtarget->hasSME2()) {
464 addRegisterClass(MVT::aarch64svcount, &AArch64::PPRRegClass);
465 setOperationPromotedToType(ISD::LOAD, MVT::aarch64svcount, MVT::nxv16i1);
466 setOperationPromotedToType(ISD::STORE, MVT::aarch64svcount, MVT::nxv16i1);
467
468 setOperationAction(ISD::SELECT, MVT::aarch64svcount, Custom);
469 setOperationAction(ISD::SELECT_CC, MVT::aarch64svcount, Expand);
470 }
471
472 // Compute derived properties from the register classes
473 computeRegisterProperties(Subtarget->getRegisterInfo());
474
475 // Provide all sorts of operation actions
476 setOperationAction(ISD::GlobalAddress, MVT::i64, Custom);
477 setOperationAction(ISD::GlobalTLSAddress, MVT::i64, Custom);
478 setOperationAction(ISD::SETCC, MVT::i32, Custom);
479 setOperationAction(ISD::SETCC, MVT::i64, Custom);
480 setOperationAction(ISD::SETCC, MVT::bf16, Custom);
481 setOperationAction(ISD::SETCC, MVT::f16, Custom);
482 setOperationAction(ISD::SETCC, MVT::f32, Custom);
483 setOperationAction(ISD::SETCC, MVT::f64, Custom);
484 setOperationAction(ISD::STRICT_FSETCC, MVT::bf16, Custom);
485 setOperationAction(ISD::STRICT_FSETCC, MVT::f16, Custom);
486 setOperationAction(ISD::STRICT_FSETCC, MVT::f32, Custom);
487 setOperationAction(ISD::STRICT_FSETCC, MVT::f64, Custom);
488 setOperationAction(ISD::STRICT_FSETCCS, MVT::f16, Custom);
489 setOperationAction(ISD::STRICT_FSETCCS, MVT::f32, Custom);
490 setOperationAction(ISD::STRICT_FSETCCS, MVT::f64, Custom);
491 setOperationAction(ISD::BITREVERSE, MVT::i32, Legal);
492 setOperationAction(ISD::BITREVERSE, MVT::i64, Legal);
493 setOperationAction(ISD::BRCOND, MVT::Other, Custom);
494 setOperationAction(ISD::BR_CC, MVT::i32, Custom);
495 setOperationAction(ISD::BR_CC, MVT::i64, Custom);
496 setOperationAction(ISD::BR_CC, MVT::f16, Custom);
497 setOperationAction(ISD::BR_CC, MVT::f32, Custom);
498 setOperationAction(ISD::BR_CC, MVT::f64, Custom);
499 setOperationAction(ISD::SELECT, MVT::i32, Custom);
500 setOperationAction(ISD::SELECT, MVT::i64, Custom);
501 setOperationAction(ISD::SELECT, MVT::f16, Custom);
502 setOperationAction(ISD::SELECT, MVT::bf16, Custom);
503 setOperationAction(ISD::SELECT, MVT::f32, Custom);
504 setOperationAction(ISD::SELECT, MVT::f64, Custom);
505 setOperationAction(ISD::SELECT_CC, MVT::i32, Custom);
506 setOperationAction(ISD::SELECT_CC, MVT::i64, Custom);
507 setOperationAction(ISD::SELECT_CC, MVT::f16, Custom);
508 setOperationAction(ISD::SELECT_CC, MVT::bf16, Custom);
509 setOperationAction(ISD::SELECT_CC, MVT::f32, Custom);
510 setOperationAction(ISD::SELECT_CC, MVT::f64, Custom);
511 setOperationAction(ISD::BR_JT, MVT::Other, Custom);
512 setOperationAction(ISD::JumpTable, MVT::i64, Custom);
513 setOperationAction(ISD::BRIND, MVT::Other, Custom);
514 setOperationAction(ISD::SETCCCARRY, MVT::i64, Custom);
515
516 setOperationAction(ISD::PtrAuthGlobalAddress, MVT::i64, Custom);
517
518 setOperationAction(ISD::SHL_PARTS, MVT::i64, Custom);
519 setOperationAction(ISD::SRA_PARTS, MVT::i64, Custom);
520 setOperationAction(ISD::SRL_PARTS, MVT::i64, Custom);
521
522 setOperationAction(ISD::FREM, MVT::f32, Expand);
523 setOperationAction(ISD::FREM, MVT::f64, Expand);
524 setOperationAction(ISD::FREM, MVT::f80, Expand);
525
526 setOperationAction(ISD::BUILD_PAIR, MVT::i64, Expand);
527
528 // Custom lowering hooks are needed for XOR
529 // to fold it into CSINC/CSINV.
530 setOperationAction(ISD::XOR, MVT::i32, Custom);
531 setOperationAction(ISD::XOR, MVT::i64, Custom);
532
533 // Virtually no operation on f128 is legal, but LLVM can't expand them when
534 // there's a valid register class, so we need custom operations in most cases.
535 setOperationAction(ISD::FABS, MVT::f128, Expand);
536 setOperationAction(ISD::FADD, MVT::f128, LibCall);
537 setOperationAction(ISD::FCOPYSIGN, MVT::f128, Expand);
538 setOperationAction(ISD::FCOS, MVT::f128, Expand);
539 setOperationAction(ISD::FDIV, MVT::f128, LibCall);
540 setOperationAction(ISD::FMA, MVT::f128, Expand);
541 setOperationAction(ISD::FMUL, MVT::f128, LibCall);
542 setOperationAction(ISD::FNEG, MVT::f128, Expand);
543 setOperationAction(ISD::FPOW, MVT::f128, Expand);
544 setOperationAction(ISD::FREM, MVT::f128, Expand);
545 setOperationAction(ISD::FRINT, MVT::f128, Expand);
546 setOperationAction(ISD::FSIN, MVT::f128, Expand);
547 setOperationAction(ISD::FSINCOS, MVT::f128, Expand);
548 setOperationAction(ISD::FSQRT, MVT::f128, Expand);
549 setOperationAction(ISD::FSUB, MVT::f128, LibCall);
550 setOperationAction(ISD::FTAN, MVT::f128, Expand);
551 setOperationAction(ISD::FTRUNC, MVT::f128, Expand);
552 setOperationAction(ISD::SETCC, MVT::f128, Custom);
553 setOperationAction(ISD::STRICT_FSETCC, MVT::f128, Custom);
554 setOperationAction(ISD::STRICT_FSETCCS, MVT::f128, Custom);
555 setOperationAction(ISD::BR_CC, MVT::f128, Custom);
556 setOperationAction(ISD::SELECT, MVT::f128, Custom);
557 setOperationAction(ISD::SELECT_CC, MVT::f128, Custom);
558 setOperationAction(ISD::FP_EXTEND, MVT::f128, Custom);
559 // FIXME: f128 FMINIMUM and FMAXIMUM (including STRICT versions) currently
560 // aren't handled.
561
562 // Lowering for many of the conversions is actually specified by the non-f128
563 // type. The LowerXXX function will be trivial when f128 isn't involved.
564 setOperationAction(ISD::FP_TO_SINT, MVT::i32, Custom);
565 setOperationAction(ISD::FP_TO_SINT, MVT::i64, Custom);
566 setOperationAction(ISD::FP_TO_SINT, MVT::i128, Custom);
567 setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
568 setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i64, Custom);
569 setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i128, Custom);
570 setOperationAction(ISD::FP_TO_UINT, MVT::i32, Custom);
571 setOperationAction(ISD::FP_TO_UINT, MVT::i64, Custom);
572 setOperationAction(ISD::FP_TO_UINT, MVT::i128, Custom);
573 setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i32, Custom);
574 setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i64, Custom);
575 setOperationAction(ISD::STRICT_FP_TO_UINT, MVT::i128, Custom);
576 setOperationAction(ISD::SINT_TO_FP, MVT::i32, Custom);
577 setOperationAction(ISD::SINT_TO_FP, MVT::i64, Custom);
578 setOperationAction(ISD::SINT_TO_FP, MVT::i128, Custom);
579 setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i32, Custom);
580 setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i64, Custom);
581 setOperationAction(ISD::STRICT_SINT_TO_FP, MVT::i128, Custom);
582 setOperationAction(ISD::UINT_TO_FP, MVT::i32, Custom);
583 setOperationAction(ISD::UINT_TO_FP, MVT::i64, Custom);
584 setOperationAction(ISD::UINT_TO_FP, MVT::i128, Custom);
585 setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i32, Custom);
586 setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i64, Custom);
587 setOperationAction(ISD::STRICT_UINT_TO_FP, MVT::i128, Custom);
588 if (Subtarget->hasFPARMv8()) {
589 setOperationAction(ISD::FP_ROUND, MVT::f16, Custom);
590 setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
591 }
592 setOperationAction(ISD::FP_ROUND, MVT::f32, Custom);
593 setOperationAction(ISD::FP_ROUND, MVT::f64, Custom);
594 if (Subtarget->hasFPARMv8()) {
595 setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Custom);
596 setOperationAction(ISD::STRICT_FP_ROUND, MVT::bf16, Custom);
597 }
598 setOperationAction(ISD::STRICT_FP_ROUND, MVT::f32, Custom);
599 setOperationAction(ISD::STRICT_FP_ROUND, MVT::f64, Custom);
600
601 setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i32, Custom);
602 setOperationAction(ISD::FP_TO_UINT_SAT, MVT::i64, Custom);
603 setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i32, Custom);
604 setOperationAction(ISD::FP_TO_SINT_SAT, MVT::i64, Custom);
605
606 // Variable arguments.
607 setOperationAction(ISD::VASTART, MVT::Other, Custom);
608 setOperationAction(ISD::VAARG, MVT::Other, Custom);
609 setOperationAction(ISD::VACOPY, MVT::Other, Custom);
610 setOperationAction(ISD::VAEND, MVT::Other, Expand);
611
612 // Variable-sized objects.
613 setOperationAction(ISD::STACKSAVE, MVT::Other, Expand);
614 setOperationAction(ISD::STACKRESTORE, MVT::Other, Expand);
615
616 // Lowering Funnel Shifts to EXTR
617 setOperationAction(ISD::FSHR, MVT::i32, Custom);
618 setOperationAction(ISD::FSHR, MVT::i64, Custom);
619 setOperationAction(ISD::FSHL, MVT::i32, Custom);
620 setOperationAction(ISD::FSHL, MVT::i64, Custom);
621
622 setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom);
623
624 // Constant pool entries
625 setOperationAction(ISD::ConstantPool, MVT::i64, Custom);
626
627 // BlockAddress
628 setOperationAction(ISD::BlockAddress, MVT::i64, Custom);
629
630 // AArch64 lacks both left-rotate and popcount instructions.
631 setOperationAction(ISD::ROTL, MVT::i32, Expand);
632 setOperationAction(ISD::ROTL, MVT::i64, Expand);
633 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
634 setOperationAction(ISD::ROTL, VT, Expand);
635 setOperationAction(ISD::ROTR, VT, Expand);
636 }
637
638 // AArch64 doesn't have i32 MULH{S|U}.
639 setOperationAction(ISD::MULHU, MVT::i32, Expand);
640 setOperationAction(ISD::MULHS, MVT::i32, Expand);
641
642 // AArch64 doesn't have {U|S}MUL_LOHI.
643 setOperationAction(ISD::UMUL_LOHI, MVT::i32, Expand);
644 setOperationAction(ISD::SMUL_LOHI, MVT::i32, Expand);
645 setOperationAction(ISD::UMUL_LOHI, MVT::i64, Expand);
646 setOperationAction(ISD::SMUL_LOHI, MVT::i64, Expand);
647
648 if (Subtarget->hasCSSC()) {
649 setOperationAction(ISD::CTPOP, MVT::i32, Legal);
650 setOperationAction(ISD::CTPOP, MVT::i64, Legal);
651 setOperationAction(ISD::CTPOP, MVT::i128, Expand);
652
653 setOperationAction(ISD::PARITY, MVT::i128, Expand);
654
655 setOperationAction(ISD::CTTZ, MVT::i32, Legal);
656 setOperationAction(ISD::CTTZ, MVT::i64, Legal);
657 setOperationAction(ISD::CTTZ, MVT::i128, Expand);
658
659 setOperationAction(ISD::ABS, MVT::i32, Legal);
660 setOperationAction(ISD::ABS, MVT::i64, Legal);
661
662 setOperationAction(ISD::SMAX, MVT::i32, Legal);
663 setOperationAction(ISD::SMAX, MVT::i64, Legal);
664 setOperationAction(ISD::UMAX, MVT::i32, Legal);
665 setOperationAction(ISD::UMAX, MVT::i64, Legal);
666
667 setOperationAction(ISD::SMIN, MVT::i32, Legal);
668 setOperationAction(ISD::SMIN, MVT::i64, Legal);
669 setOperationAction(ISD::UMIN, MVT::i32, Legal);
670 setOperationAction(ISD::UMIN, MVT::i64, Legal);
671 } else {
672 setOperationAction(ISD::CTPOP, MVT::i32, Custom);
673 setOperationAction(ISD::CTPOP, MVT::i64, Custom);
674 setOperationAction(ISD::CTPOP, MVT::i128, Custom);
675
676 setOperationAction(ISD::PARITY, MVT::i64, Custom);
677 setOperationAction(ISD::PARITY, MVT::i128, Custom);
678
679 setOperationAction(ISD::ABS, MVT::i32, Custom);
680 setOperationAction(ISD::ABS, MVT::i64, Custom);
681 }
682
683 setOperationAction(ISD::SDIVREM, MVT::i32, Expand);
684 setOperationAction(ISD::SDIVREM, MVT::i64, Expand);
685 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
686 setOperationAction(ISD::SDIVREM, VT, Expand);
687 setOperationAction(ISD::UDIVREM, VT, Expand);
688 }
689 setOperationAction(ISD::SREM, MVT::i32, Expand);
690 setOperationAction(ISD::SREM, MVT::i64, Expand);
691 setOperationAction(ISD::UDIVREM, MVT::i32, Expand);
692 setOperationAction(ISD::UDIVREM, MVT::i64, Expand);
693 setOperationAction(ISD::UREM, MVT::i32, Expand);
694 setOperationAction(ISD::UREM, MVT::i64, Expand);
695
696 // Custom lower Add/Sub/Mul with overflow.
697 setOperationAction(ISD::SADDO, MVT::i32, Custom);
698 setOperationAction(ISD::SADDO, MVT::i64, Custom);
699 setOperationAction(ISD::UADDO, MVT::i32, Custom);
700 setOperationAction(ISD::UADDO, MVT::i64, Custom);
701 setOperationAction(ISD::SSUBO, MVT::i32, Custom);
702 setOperationAction(ISD::SSUBO, MVT::i64, Custom);
703 setOperationAction(ISD::USUBO, MVT::i32, Custom);
704 setOperationAction(ISD::USUBO, MVT::i64, Custom);
705 setOperationAction(ISD::SMULO, MVT::i32, Custom);
706 setOperationAction(ISD::SMULO, MVT::i64, Custom);
707 setOperationAction(ISD::UMULO, MVT::i32, Custom);
708 setOperationAction(ISD::UMULO, MVT::i64, Custom);
709
710 setOperationAction(ISD::UADDO_CARRY, MVT::i32, Custom);
711 setOperationAction(ISD::UADDO_CARRY, MVT::i64, Custom);
712 setOperationAction(ISD::USUBO_CARRY, MVT::i32, Custom);
713 setOperationAction(ISD::USUBO_CARRY, MVT::i64, Custom);
714 setOperationAction(ISD::SADDO_CARRY, MVT::i32, Custom);
715 setOperationAction(ISD::SADDO_CARRY, MVT::i64, Custom);
716 setOperationAction(ISD::SSUBO_CARRY, MVT::i32, Custom);
717 setOperationAction(ISD::SSUBO_CARRY, MVT::i64, Custom);
718
719 setOperationAction(ISD::FSIN, MVT::f32, Expand);
720 setOperationAction(ISD::FSIN, MVT::f64, Expand);
721 setOperationAction(ISD::FCOS, MVT::f32, Expand);
722 setOperationAction(ISD::FCOS, MVT::f64, Expand);
723 setOperationAction(ISD::FPOW, MVT::f32, Expand);
724 setOperationAction(ISD::FPOW, MVT::f64, Expand);
725 setOperationAction(ISD::FCOPYSIGN, MVT::f64, Custom);
726 setOperationAction(ISD::FCOPYSIGN, MVT::f32, Custom);
727 if (Subtarget->hasFullFP16()) {
728 setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
729 setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
730 } else {
731 setOperationAction(ISD::FCOPYSIGN, MVT::f16, Promote);
732 setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Promote);
733 }
734
735 for (auto Op : {ISD::FREM, ISD::FPOW, ISD::FPOWI,
736 ISD::FCOS, ISD::FSIN, ISD::FSINCOS,
737 ISD::FACOS, ISD::FASIN, ISD::FATAN,
738 ISD::FCOSH, ISD::FSINH, ISD::FTANH,
739 ISD::FTAN, ISD::FEXP, ISD::FEXP2,
740 ISD::FEXP10, ISD::FLOG, ISD::FLOG2,
741 ISD::FLOG10, ISD::STRICT_FREM, ISD::STRICT_FPOW,
742 ISD::STRICT_FPOWI, ISD::STRICT_FCOS, ISD::STRICT_FSIN,
743 ISD::STRICT_FACOS, ISD::STRICT_FASIN, ISD::STRICT_FATAN,
744 ISD::STRICT_FCOSH, ISD::STRICT_FSINH, ISD::STRICT_FTANH,
745 ISD::STRICT_FEXP, ISD::STRICT_FEXP2, ISD::STRICT_FLOG,
746 ISD::STRICT_FLOG2, ISD::STRICT_FLOG10, ISD::STRICT_FTAN}) {
747 setOperationAction(Op, MVT::f16, Promote);
748 setOperationAction(Op, MVT::v4f16, Expand);
749 setOperationAction(Op, MVT::v8f16, Expand);
750 setOperationAction(Op, MVT::bf16, Promote);
751 setOperationAction(Op, MVT::v4bf16, Expand);
752 setOperationAction(Op, MVT::v8bf16, Expand);
753 }
754
755 auto LegalizeNarrowFP = [this](MVT ScalarVT) {
756 for (auto Op : {
757 ISD::SETCC,
758 ISD::SELECT_CC,
759 ISD::BR_CC,
760 ISD::FADD,
761 ISD::FSUB,
762 ISD::FMUL,
763 ISD::FDIV,
764 ISD::FMA,
765 ISD::FCEIL,
766 ISD::FSQRT,
767 ISD::FFLOOR,
768 ISD::FNEARBYINT,
769 ISD::FRINT,
770 ISD::FROUND,
771 ISD::FROUNDEVEN,
772 ISD::FTRUNC,
773 ISD::FMINNUM,
774 ISD::FMAXNUM,
775 ISD::FMINIMUM,
776 ISD::FMAXIMUM,
777 ISD::STRICT_FADD,
778 ISD::STRICT_FSUB,
779 ISD::STRICT_FMUL,
780 ISD::STRICT_FDIV,
781 ISD::STRICT_FMA,
782 ISD::STRICT_FCEIL,
783 ISD::STRICT_FFLOOR,
784 ISD::STRICT_FSQRT,
785 ISD::STRICT_FRINT,
786 ISD::STRICT_FNEARBYINT,
787 ISD::STRICT_FROUND,
788 ISD::STRICT_FTRUNC,
789 ISD::STRICT_FROUNDEVEN,
790 ISD::STRICT_FMINNUM,
791 ISD::STRICT_FMAXNUM,
792 ISD::STRICT_FMINIMUM,
793 ISD::STRICT_FMAXIMUM,
794 })
795 setOperationAction(Op, ScalarVT, Promote);
796
797 for (auto Op : {ISD::FNEG, ISD::FABS})
798 setOperationAction(Op, ScalarVT, Legal);
799
800 // Round-to-integer need custom lowering for fp16, as Promote doesn't work
801 // because the result type is integer.
802 for (auto Op : {ISD::LROUND, ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
803 ISD::STRICT_LROUND, ISD::STRICT_LLROUND, ISD::STRICT_LRINT,
804 ISD::STRICT_LLRINT})
805 setOperationAction(Op, ScalarVT, Custom);
806
807 // promote v4f16 to v4f32 when that is known to be safe.
808 auto V4Narrow = MVT::getVectorVT(ScalarVT, 4);
809 setOperationPromotedToType(ISD::FADD, V4Narrow, MVT::v4f32);
810 setOperationPromotedToType(ISD::FSUB, V4Narrow, MVT::v4f32);
811 setOperationPromotedToType(ISD::FMUL, V4Narrow, MVT::v4f32);
812 setOperationPromotedToType(ISD::FDIV, V4Narrow, MVT::v4f32);
813 setOperationPromotedToType(ISD::FCEIL, V4Narrow, MVT::v4f32);
814 setOperationPromotedToType(ISD::FFLOOR, V4Narrow, MVT::v4f32);
815 setOperationPromotedToType(ISD::FROUND, V4Narrow, MVT::v4f32);
816 setOperationPromotedToType(ISD::FTRUNC, V4Narrow, MVT::v4f32);
817 setOperationPromotedToType(ISD::FROUNDEVEN, V4Narrow, MVT::v4f32);
818 setOperationPromotedToType(ISD::FRINT, V4Narrow, MVT::v4f32);
819 setOperationPromotedToType(ISD::FNEARBYINT, V4Narrow, MVT::v4f32);
820
821 setOperationAction(ISD::FABS, V4Narrow, Legal);
822 setOperationAction(ISD::FNEG, V4Narrow, Legal);
823 setOperationAction(ISD::FMA, V4Narrow, Expand);
824 setOperationAction(ISD::SETCC, V4Narrow, Custom);
825 setOperationAction(ISD::BR_CC, V4Narrow, Expand);
826 setOperationAction(ISD::SELECT, V4Narrow, Expand);
827 setOperationAction(ISD::SELECT_CC, V4Narrow, Expand);
828 setOperationAction(ISD::FCOPYSIGN, V4Narrow, Custom);
829 setOperationAction(ISD::FSQRT, V4Narrow, Expand);
830
831 auto V8Narrow = MVT::getVectorVT(ScalarVT, 8);
832 setOperationAction(ISD::FABS, V8Narrow, Legal);
833 setOperationAction(ISD::FADD, V8Narrow, Legal);
834 setOperationAction(ISD::FCEIL, V8Narrow, Legal);
835 setOperationAction(ISD::FCOPYSIGN, V8Narrow, Custom);
836 setOperationAction(ISD::FDIV, V8Narrow, Legal);
837 setOperationAction(ISD::FFLOOR, V8Narrow, Legal);
838 setOperationAction(ISD::FMA, V8Narrow, Expand);
839 setOperationAction(ISD::FMUL, V8Narrow, Legal);
840 setOperationAction(ISD::FNEARBYINT, V8Narrow, Legal);
841 setOperationAction(ISD::FNEG, V8Narrow, Legal);
842 setOperationAction(ISD::FROUND, V8Narrow, Legal);
843 setOperationAction(ISD::FROUNDEVEN, V8Narrow, Legal);
844 setOperationAction(ISD::FRINT, V8Narrow, Legal);
845 setOperationAction(ISD::FSQRT, V8Narrow, Expand);
846 setOperationAction(ISD::FSUB, V8Narrow, Legal);
847 setOperationAction(ISD::FTRUNC, V8Narrow, Legal);
848 setOperationAction(ISD::SETCC, V8Narrow, Expand);
849 setOperationAction(ISD::BR_CC, V8Narrow, Expand);
850 setOperationAction(ISD::SELECT, V8Narrow, Expand);
851 setOperationAction(ISD::SELECT_CC, V8Narrow, Expand);
852 setOperationAction(ISD::FP_EXTEND, V8Narrow, Expand);
853 };
854
855 if (!Subtarget->hasFullFP16()) {
856 LegalizeNarrowFP(MVT::f16);
857 }
858 LegalizeNarrowFP(MVT::bf16);
859 setOperationAction(ISD::FP_ROUND, MVT::v4f32, Custom);
860 setOperationAction(ISD::FP_ROUND, MVT::v4bf16, Custom);
861
862 // AArch64 has implementations of a lot of rounding-like FP operations.
863 for (auto Op :
864 {ISD::FFLOOR, ISD::FNEARBYINT, ISD::FCEIL,
865 ISD::FRINT, ISD::FTRUNC, ISD::FROUND,
866 ISD::FROUNDEVEN, ISD::FMINNUM, ISD::FMAXNUM,
867 ISD::FMINIMUM, ISD::FMAXIMUM, ISD::LROUND,
868 ISD::LLROUND, ISD::LRINT, ISD::LLRINT,
869 ISD::STRICT_FFLOOR, ISD::STRICT_FCEIL, ISD::STRICT_FNEARBYINT,
870 ISD::STRICT_FRINT, ISD::STRICT_FTRUNC, ISD::STRICT_FROUNDEVEN,
871 ISD::STRICT_FROUND, ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM,
872 ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM, ISD::STRICT_LROUND,
873 ISD::STRICT_LLROUND, ISD::STRICT_LRINT, ISD::STRICT_LLRINT}) {
874 for (MVT Ty : {MVT::f32, MVT::f64})
875 setOperationAction(Op, Ty, Legal);
876 if (Subtarget->hasFullFP16())
877 setOperationAction(Op, MVT::f16, Legal);
878 }
879
880 // Basic strict FP operations are legal
881 for (auto Op : {ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
882 ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FSQRT}) {
883 for (MVT Ty : {MVT::f32, MVT::f64})
884 setOperationAction(Op, Ty, Legal);
885 if (Subtarget->hasFullFP16())
886 setOperationAction(Op, MVT::f16, Legal);
887 }
888
889 // Strict conversion to a larger type is legal
890 for (auto VT : {MVT::f32, MVT::f64})
891 setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
892
893 setOperationAction(ISD::PREFETCH, MVT::Other, Custom);
894
895 setOperationAction(ISD::GET_ROUNDING, MVT::i32, Custom);
896 setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
897 setOperationAction(ISD::GET_FPMODE, MVT::i32, Custom);
898 setOperationAction(ISD::SET_FPMODE, MVT::i32, Custom);
899 setOperationAction(ISD::RESET_FPMODE, MVT::Other, Custom);
900
901 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, Custom);
902 if (!Subtarget->hasLSE() && !Subtarget->outlineAtomics()) {
903 setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, LibCall);
904 setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i64, LibCall);
905 } else {
906 setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i32, Expand);
907 setOperationAction(ISD::ATOMIC_LOAD_SUB, MVT::i64, Expand);
908 }
909 setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i32, Custom);
910 setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i64, Custom);
911
912 // Generate outline atomics library calls only if LSE was not specified for
913 // subtarget
914 if (Subtarget->outlineAtomics() && !Subtarget->hasLSE()) {
915 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i8, LibCall);
916 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i16, LibCall);
917 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i32, LibCall);
918 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i64, LibCall);
919 setOperationAction(ISD::ATOMIC_CMP_SWAP, MVT::i128, LibCall);
920 setOperationAction(ISD::ATOMIC_SWAP, MVT::i8, LibCall);
921 setOperationAction(ISD::ATOMIC_SWAP, MVT::i16, LibCall);
922 setOperationAction(ISD::ATOMIC_SWAP, MVT::i32, LibCall);
923 setOperationAction(ISD::ATOMIC_SWAP, MVT::i64, LibCall);
924 setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i8, LibCall);
925 setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i16, LibCall);
926 setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i32, LibCall);
927 setOperationAction(ISD::ATOMIC_LOAD_ADD, MVT::i64, LibCall);
928 setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i8, LibCall);
929 setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i16, LibCall);
930 setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i32, LibCall);
931 setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i64, LibCall);
932 setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i8, LibCall);
933 setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i16, LibCall);
934 setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i32, LibCall);
935 setOperationAction(ISD::ATOMIC_LOAD_CLR, MVT::i64, LibCall);
936 setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i8, LibCall);
937 setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i16, LibCall);
938 setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i32, LibCall);
939 setOperationAction(ISD::ATOMIC_LOAD_XOR, MVT::i64, LibCall);
940 #define LCALLNAMES(A, B, N) \
941 setLibcallName(A##N##_RELAX, #B #N "_relax"); \
942 setLibcallName(A##N##_ACQ, #B #N "_acq"); \
943 setLibcallName(A##N##_REL, #B #N "_rel"); \
944 setLibcallName(A##N##_ACQ_REL, #B #N "_acq_rel");
945 #define LCALLNAME4(A, B) \
946 LCALLNAMES(A, B, 1) \
947 LCALLNAMES(A, B, 2) LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8)
948 #define LCALLNAME5(A, B) \
949 LCALLNAMES(A, B, 1) \
950 LCALLNAMES(A, B, 2) \
951 LCALLNAMES(A, B, 4) LCALLNAMES(A, B, 8) LCALLNAMES(A, B, 16)
952 LCALLNAME5(RTLIB::OUTLINE_ATOMIC_CAS, __aarch64_cas)
953 LCALLNAME4(RTLIB::OUTLINE_ATOMIC_SWP, __aarch64_swp)
954 LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDADD, __aarch64_ldadd)
955 LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDSET, __aarch64_ldset)
956 LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDCLR, __aarch64_ldclr)
957 LCALLNAME4(RTLIB::OUTLINE_ATOMIC_LDEOR, __aarch64_ldeor)
958 #undef LCALLNAMES
959 #undef LCALLNAME4
960 #undef LCALLNAME5
961 }
962
963 if (Subtarget->hasLSE128()) {
964 // Custom lowering because i128 is not legal. Must be replaced by 2x64
965 // values. ATOMIC_LOAD_AND also needs op legalisation to emit LDCLRP.
966 setOperationAction(ISD::ATOMIC_LOAD_AND, MVT::i128, Custom);
967 setOperationAction(ISD::ATOMIC_LOAD_OR, MVT::i128, Custom);
968 setOperationAction(ISD::ATOMIC_SWAP, MVT::i128, Custom);
969 }
970
971 // 128-bit loads and stores can be done without expanding
972 setOperationAction(ISD::LOAD, MVT::i128, Custom);
973 setOperationAction(ISD::STORE, MVT::i128, Custom);
974
975 // Aligned 128-bit loads and stores are single-copy atomic according to the
976 // v8.4a spec. LRCPC3 introduces 128-bit STILP/LDIAPP but still requires LSE2.
977 if (Subtarget->hasLSE2()) {
978 setOperationAction(ISD::ATOMIC_LOAD, MVT::i128, Custom);
979 setOperationAction(ISD::ATOMIC_STORE, MVT::i128, Custom);
980 }
981
982 // 256 bit non-temporal stores can be lowered to STNP. Do this as part of the
983 // custom lowering, as there are no un-paired non-temporal stores and
984 // legalization will break up 256 bit inputs.
985 setOperationAction(ISD::STORE, MVT::v32i8, Custom);
986 setOperationAction(ISD::STORE, MVT::v16i16, Custom);
987 setOperationAction(ISD::STORE, MVT::v16f16, Custom);
988 setOperationAction(ISD::STORE, MVT::v16bf16, Custom);
989 setOperationAction(ISD::STORE, MVT::v8i32, Custom);
990 setOperationAction(ISD::STORE, MVT::v8f32, Custom);
991 setOperationAction(ISD::STORE, MVT::v4f64, Custom);
992 setOperationAction(ISD::STORE, MVT::v4i64, Custom);
993
994 // 256 bit non-temporal loads can be lowered to LDNP. This is done using
995 // custom lowering, as there are no un-paired non-temporal loads legalization
996 // will break up 256 bit inputs.
997 setOperationAction(ISD::LOAD, MVT::v32i8, Custom);
998 setOperationAction(ISD::LOAD, MVT::v16i16, Custom);
999 setOperationAction(ISD::LOAD, MVT::v16f16, Custom);
1000 setOperationAction(ISD::LOAD, MVT::v16bf16, Custom);
1001 setOperationAction(ISD::LOAD, MVT::v8i32, Custom);
1002 setOperationAction(ISD::LOAD, MVT::v8f32, Custom);
1003 setOperationAction(ISD::LOAD, MVT::v4f64, Custom);
1004 setOperationAction(ISD::LOAD, MVT::v4i64, Custom);
1005
1006 // Lower READCYCLECOUNTER using an mrs from CNTVCT_EL0.
1007 setOperationAction(ISD::READCYCLECOUNTER, MVT::i64, Legal);
1008
1009 if (getLibcallName(RTLIB::SINCOS_STRET_F32) != nullptr &&
1010 getLibcallName(RTLIB::SINCOS_STRET_F64) != nullptr) {
1011 // Issue __sincos_stret if available.
1012 setOperationAction(ISD::FSINCOS, MVT::f64, Custom);
1013 setOperationAction(ISD::FSINCOS, MVT::f32, Custom);
1014 } else {
1015 setOperationAction(ISD::FSINCOS, MVT::f64, Expand);
1016 setOperationAction(ISD::FSINCOS, MVT::f32, Expand);
1017 }
1018
1019 // Make floating-point constants legal for the large code model, so they don't
1020 // become loads from the constant pool.
1021 if (Subtarget->isTargetMachO() && TM.getCodeModel() == CodeModel::Large) {
1022 setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
1023 setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
1024 }
1025
1026 // AArch64 does not have floating-point extending loads, i1 sign-extending
1027 // load, floating-point truncating stores, or v2i32->v2i16 truncating store.
1028 for (MVT VT : MVT::fp_valuetypes()) {
1029 setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
1030 setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
1031 setLoadExtAction(ISD::EXTLOAD, VT, MVT::f32, Expand);
1032 setLoadExtAction(ISD::EXTLOAD, VT, MVT::f64, Expand);
1033 setLoadExtAction(ISD::EXTLOAD, VT, MVT::f80, Expand);
1034 }
1035 for (MVT VT : MVT::integer_valuetypes())
1036 setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Expand);
1037
1038 for (MVT WideVT : MVT::fp_valuetypes()) {
1039 for (MVT NarrowVT : MVT::fp_valuetypes()) {
1040 if (WideVT.getScalarSizeInBits() > NarrowVT.getScalarSizeInBits()) {
1041 setTruncStoreAction(WideVT, NarrowVT, Expand);
1042 }
1043 }
1044 }
1045
1046 if (Subtarget->hasFPARMv8()) {
1047 setOperationAction(ISD::BITCAST, MVT::i16, Custom);
1048 setOperationAction(ISD::BITCAST, MVT::f16, Custom);
1049 setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
1050 }
1051
1052 // Indexed loads and stores are supported.
1053 for (unsigned im = (unsigned)ISD::PRE_INC;
1054 im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
1055 setIndexedLoadAction(im, MVT::i8, Legal);
1056 setIndexedLoadAction(im, MVT::i16, Legal);
1057 setIndexedLoadAction(im, MVT::i32, Legal);
1058 setIndexedLoadAction(im, MVT::i64, Legal);
1059 setIndexedLoadAction(im, MVT::f64, Legal);
1060 setIndexedLoadAction(im, MVT::f32, Legal);
1061 setIndexedLoadAction(im, MVT::f16, Legal);
1062 setIndexedLoadAction(im, MVT::bf16, Legal);
1063 setIndexedStoreAction(im, MVT::i8, Legal);
1064 setIndexedStoreAction(im, MVT::i16, Legal);
1065 setIndexedStoreAction(im, MVT::i32, Legal);
1066 setIndexedStoreAction(im, MVT::i64, Legal);
1067 setIndexedStoreAction(im, MVT::f64, Legal);
1068 setIndexedStoreAction(im, MVT::f32, Legal);
1069 setIndexedStoreAction(im, MVT::f16, Legal);
1070 setIndexedStoreAction(im, MVT::bf16, Legal);
1071 }
1072
1073 // Trap.
1074 setOperationAction(ISD::TRAP, MVT::Other, Legal);
1075 setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
1076 setOperationAction(ISD::UBSANTRAP, MVT::Other, Legal);
1077
1078 // We combine OR nodes for bitfield operations.
1079 setTargetDAGCombine(ISD::OR);
1080 // Try to create BICs for vector ANDs.
1081 setTargetDAGCombine(ISD::AND);
1082
1083 // llvm.init.trampoline and llvm.adjust.trampoline
1084 setOperationAction(ISD::INIT_TRAMPOLINE, MVT::Other, Custom);
1085 setOperationAction(ISD::ADJUST_TRAMPOLINE, MVT::Other, Custom);
1086
1087 // Vector add and sub nodes may conceal a high-half opportunity.
1088 // Also, try to fold ADD into CSINC/CSINV..
1089 setTargetDAGCombine({ISD::ADD, ISD::ABS, ISD::SUB, ISD::XOR, ISD::SINT_TO_FP,
1090 ISD::UINT_TO_FP});
1091
1092 setTargetDAGCombine({ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1093 ISD::FP_TO_UINT_SAT, ISD::FADD});
1094
1095 // Try and combine setcc with csel
1096 setTargetDAGCombine(ISD::SETCC);
1097
1098 setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
1099
1100 setTargetDAGCombine({ISD::ANY_EXTEND, ISD::ZERO_EXTEND, ISD::SIGN_EXTEND,
1101 ISD::SIGN_EXTEND_INREG, ISD::CONCAT_VECTORS,
1102 ISD::EXTRACT_SUBVECTOR, ISD::INSERT_SUBVECTOR,
1103 ISD::STORE, ISD::BUILD_VECTOR});
1104 setTargetDAGCombine(ISD::TRUNCATE);
1105 setTargetDAGCombine(ISD::LOAD);
1106
1107 setTargetDAGCombine(ISD::MSTORE);
1108
1109 setTargetDAGCombine(ISD::MUL);
1110
1111 setTargetDAGCombine({ISD::SELECT, ISD::VSELECT});
1112
1113 setTargetDAGCombine({ISD::INTRINSIC_VOID, ISD::INTRINSIC_W_CHAIN,
1114 ISD::INSERT_VECTOR_ELT, ISD::EXTRACT_VECTOR_ELT,
1115 ISD::VECREDUCE_ADD, ISD::STEP_VECTOR});
1116
1117 setTargetDAGCombine({ISD::MGATHER, ISD::MSCATTER});
1118
1119 setTargetDAGCombine(ISD::FP_EXTEND);
1120
1121 setTargetDAGCombine(ISD::GlobalAddress);
1122
1123 setTargetDAGCombine(ISD::CTLZ);
1124
1125 setTargetDAGCombine(ISD::VECREDUCE_AND);
1126 setTargetDAGCombine(ISD::VECREDUCE_OR);
1127 setTargetDAGCombine(ISD::VECREDUCE_XOR);
1128
1129 setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
1130
1131 // In case of strict alignment, avoid an excessive number of byte wide stores.
1132 MaxStoresPerMemsetOptSize = 8;
1133 MaxStoresPerMemset =
1134 Subtarget->requiresStrictAlign() ? MaxStoresPerMemsetOptSize : 32;
1135
1136 MaxGluedStoresPerMemcpy = 4;
1137 MaxStoresPerMemcpyOptSize = 4;
1138 MaxStoresPerMemcpy =
1139 Subtarget->requiresStrictAlign() ? MaxStoresPerMemcpyOptSize : 16;
1140
1141 MaxStoresPerMemmoveOptSize = 4;
1142 MaxStoresPerMemmove = 4;
1143
1144 MaxLoadsPerMemcmpOptSize = 4;
1145 MaxLoadsPerMemcmp =
1146 Subtarget->requiresStrictAlign() ? MaxLoadsPerMemcmpOptSize : 8;
1147
1148 setStackPointerRegisterToSaveRestore(AArch64::SP);
1149
1150 setSchedulingPreference(Sched::Hybrid);
1151
1152 EnableExtLdPromotion = true;
1153
1154 // Set required alignment.
1155 setMinFunctionAlignment(Align(4));
1156 // Set preferred alignments.
1157
1158 // Don't align loops on Windows. The SEH unwind info generation needs to
1159 // know the exact length of functions before the alignments have been
1160 // expanded.
1161 if (!Subtarget->isTargetWindows())
1162 setPrefLoopAlignment(STI.getPrefLoopAlignment());
1163 setMaxBytesForAlignment(STI.getMaxBytesForLoopAlignment());
1164 setPrefFunctionAlignment(STI.getPrefFunctionAlignment());
1165
1166 // Only change the limit for entries in a jump table if specified by
1167 // the sub target, but not at the command line.
1168 unsigned MaxJT = STI.getMaximumJumpTableSize();
1169 if (MaxJT && getMaximumJumpTableSize() == UINT_MAX)
1170 setMaximumJumpTableSize(MaxJT);
1171
1172 setHasExtractBitsInsn(true);
1173
1174 setMaxDivRemBitWidthSupported(128);
1175
1176 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
1177
1178 if (Subtarget->isNeonAvailable()) {
1179 // FIXME: v1f64 shouldn't be legal if we can avoid it, because it leads to
1180 // silliness like this:
1181 // clang-format off
1182 for (auto Op :
1183 {ISD::SELECT, ISD::SELECT_CC,
1184 ISD::BR_CC, ISD::FADD, ISD::FSUB,
1185 ISD::FMUL, ISD::FDIV, ISD::FMA,
1186 ISD::FNEG, ISD::FABS, ISD::FCEIL,
1187 ISD::FSQRT, ISD::FFLOOR, ISD::FNEARBYINT,
1188 ISD::FSIN, ISD::FCOS, ISD::FTAN,
1189 ISD::FASIN, ISD::FACOS, ISD::FATAN,
1190 ISD::FSINH, ISD::FCOSH, ISD::FTANH,
1191 ISD::FPOW, ISD::FLOG, ISD::FLOG2,
1192 ISD::FLOG10, ISD::FEXP, ISD::FEXP2,
1193 ISD::FEXP10, ISD::FRINT, ISD::FROUND,
1194 ISD::FROUNDEVEN, ISD::FTRUNC, ISD::FMINNUM,
1195 ISD::FMAXNUM, ISD::FMINIMUM, ISD::FMAXIMUM,
1196 ISD::STRICT_FADD, ISD::STRICT_FSUB, ISD::STRICT_FMUL,
1197 ISD::STRICT_FDIV, ISD::STRICT_FMA, ISD::STRICT_FCEIL,
1198 ISD::STRICT_FFLOOR, ISD::STRICT_FSQRT, ISD::STRICT_FRINT,
1199 ISD::STRICT_FNEARBYINT, ISD::STRICT_FROUND, ISD::STRICT_FTRUNC,
1200 ISD::STRICT_FROUNDEVEN, ISD::STRICT_FMINNUM, ISD::STRICT_FMAXNUM,
1201 ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM})
1202 setOperationAction(Op, MVT::v1f64, Expand);
1203 // clang-format on
1204 for (auto Op :
1205 {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
1206 ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
1207 ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
1208 ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_ROUND})
1209 setOperationAction(Op, MVT::v1i64, Expand);
1210
1211 // AArch64 doesn't have a direct vector ->f32 conversion instructions for
1212 // elements smaller than i32, so promote the input to i32 first.
1213 setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i8, MVT::v4i32);
1214 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i8, MVT::v4i32);
1215
1216 // Similarly, there is no direct i32 -> f64 vector conversion instruction.
1217 // Or, direct i32 -> f16 vector conversion. Set it so custom, so the
1218 // conversion happens in two steps: v4i32 -> v4f32 -> v4f16
1219 for (auto Op : {ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::STRICT_SINT_TO_FP,
1220 ISD::STRICT_UINT_TO_FP})
1221 for (auto VT : {MVT::v2i32, MVT::v2i64, MVT::v4i32})
1222 setOperationAction(Op, VT, Custom);
1223
1224 if (Subtarget->hasFullFP16()) {
1225 setOperationAction(ISD::ConstantFP, MVT::f16, Legal);
1226 setOperationAction(ISD::ConstantFP, MVT::bf16, Legal);
1227
1228 setOperationAction(ISD::SINT_TO_FP, MVT::v8i8, Custom);
1229 setOperationAction(ISD::UINT_TO_FP, MVT::v8i8, Custom);
1230 setOperationAction(ISD::SINT_TO_FP, MVT::v16i8, Custom);
1231 setOperationAction(ISD::UINT_TO_FP, MVT::v16i8, Custom);
1232 setOperationAction(ISD::SINT_TO_FP, MVT::v4i16, Custom);
1233 setOperationAction(ISD::UINT_TO_FP, MVT::v4i16, Custom);
1234 setOperationAction(ISD::SINT_TO_FP, MVT::v8i16, Custom);
1235 setOperationAction(ISD::UINT_TO_FP, MVT::v8i16, Custom);
1236 } else {
1237 // when AArch64 doesn't have fullfp16 support, promote the input
1238 // to i32 first.
1239 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i8, MVT::v8i32);
1240 setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i8, MVT::v8i32);
1241 setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v16i8, MVT::v16i32);
1242 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v16i8, MVT::v16i32);
1243 setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v4i16, MVT::v4i32);
1244 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v4i16, MVT::v4i32);
1245 setOperationPromotedToType(ISD::SINT_TO_FP, MVT::v8i16, MVT::v8i32);
1246 setOperationPromotedToType(ISD::UINT_TO_FP, MVT::v8i16, MVT::v8i32);
1247 }
1248
1249 setOperationAction(ISD::CTLZ, MVT::v1i64, Expand);
1250 setOperationAction(ISD::CTLZ, MVT::v2i64, Expand);
1251 setOperationAction(ISD::BITREVERSE, MVT::v8i8, Legal);
1252 setOperationAction(ISD::BITREVERSE, MVT::v16i8, Legal);
1253 setOperationAction(ISD::BITREVERSE, MVT::v2i32, Custom);
1254 setOperationAction(ISD::BITREVERSE, MVT::v4i32, Custom);
1255 setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
1256 setOperationAction(ISD::BITREVERSE, MVT::v2i64, Custom);
1257 for (auto VT : {MVT::v1i64, MVT::v2i64}) {
1258 setOperationAction(ISD::UMAX, VT, Custom);
1259 setOperationAction(ISD::SMAX, VT, Custom);
1260 setOperationAction(ISD::UMIN, VT, Custom);
1261 setOperationAction(ISD::SMIN, VT, Custom);
1262 }
1263
1264 // Custom handling for some quad-vector types to detect MULL.
1265 setOperationAction(ISD::MUL, MVT::v8i16, Custom);
1266 setOperationAction(ISD::MUL, MVT::v4i32, Custom);
1267 setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1268 setOperationAction(ISD::MUL, MVT::v4i16, Custom);
1269 setOperationAction(ISD::MUL, MVT::v2i32, Custom);
1270 setOperationAction(ISD::MUL, MVT::v1i64, Custom);
1271
1272 // Saturates
1273 for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1274 MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
1275 setOperationAction(ISD::SADDSAT, VT, Legal);
1276 setOperationAction(ISD::UADDSAT, VT, Legal);
1277 setOperationAction(ISD::SSUBSAT, VT, Legal);
1278 setOperationAction(ISD::USUBSAT, VT, Legal);
1279 }
1280
1281 for (MVT VT : {MVT::v8i8, MVT::v4i16, MVT::v2i32, MVT::v16i8, MVT::v8i16,
1282 MVT::v4i32}) {
1283 setOperationAction(ISD::AVGFLOORS, VT, Legal);
1284 setOperationAction(ISD::AVGFLOORU, VT, Legal);
1285 setOperationAction(ISD::AVGCEILS, VT, Legal);
1286 setOperationAction(ISD::AVGCEILU, VT, Legal);
1287 setOperationAction(ISD::ABDS, VT, Legal);
1288 setOperationAction(ISD::ABDU, VT, Legal);
1289 }
1290
1291 // Vector reductions
1292 for (MVT VT : { MVT::v4f16, MVT::v2f32,
1293 MVT::v8f16, MVT::v4f32, MVT::v2f64 }) {
1294 if (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()) {
1295 setOperationAction(ISD::VECREDUCE_FMAX, VT, Legal);
1296 setOperationAction(ISD::VECREDUCE_FMIN, VT, Legal);
1297 setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Legal);
1298 setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Legal);
1299
1300 setOperationAction(ISD::VECREDUCE_FADD, VT, Legal);
1301 }
1302 }
1303 for (MVT VT : { MVT::v8i8, MVT::v4i16, MVT::v2i32,
1304 MVT::v16i8, MVT::v8i16, MVT::v4i32 }) {
1305 setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1306 setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1307 setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1308 setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1309 setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1310 setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1311 setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1312 setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1313 }
1314 setOperationAction(ISD::VECREDUCE_ADD, MVT::v2i64, Custom);
1315 setOperationAction(ISD::VECREDUCE_AND, MVT::v2i64, Custom);
1316 setOperationAction(ISD::VECREDUCE_OR, MVT::v2i64, Custom);
1317 setOperationAction(ISD::VECREDUCE_XOR, MVT::v2i64, Custom);
1318
1319 setOperationAction(ISD::ANY_EXTEND, MVT::v4i32, Legal);
1320 setTruncStoreAction(MVT::v2i32, MVT::v2i16, Expand);
1321 // Likewise, narrowing and extending vector loads/stores aren't handled
1322 // directly.
1323 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1324 setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
1325
1326 if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32) {
1327 setOperationAction(ISD::MULHS, VT, Legal);
1328 setOperationAction(ISD::MULHU, VT, Legal);
1329 } else {
1330 setOperationAction(ISD::MULHS, VT, Expand);
1331 setOperationAction(ISD::MULHU, VT, Expand);
1332 }
1333 setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1334 setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1335
1336 setOperationAction(ISD::BSWAP, VT, Expand);
1337 setOperationAction(ISD::CTTZ, VT, Expand);
1338
1339 for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1340 setTruncStoreAction(VT, InnerVT, Expand);
1341 setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1342 setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1343 setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1344 }
1345 }
1346
1347 // AArch64 has implementations of a lot of rounding-like FP operations.
1348 for (auto Op :
1349 {ISD::FFLOOR, ISD::FNEARBYINT, ISD::FCEIL, ISD::FRINT, ISD::FTRUNC,
1350 ISD::FROUND, ISD::FROUNDEVEN, ISD::STRICT_FFLOOR,
1351 ISD::STRICT_FNEARBYINT, ISD::STRICT_FCEIL, ISD::STRICT_FRINT,
1352 ISD::STRICT_FTRUNC, ISD::STRICT_FROUND, ISD::STRICT_FROUNDEVEN}) {
1353 for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1354 setOperationAction(Op, Ty, Legal);
1355 if (Subtarget->hasFullFP16())
1356 for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1357 setOperationAction(Op, Ty, Legal);
1358 }
1359
1360 // LRINT and LLRINT.
1361 for (auto Op : {ISD::LRINT, ISD::LLRINT}) {
1362 for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1363 setOperationAction(Op, Ty, Custom);
1364 if (Subtarget->hasFullFP16())
1365 for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1366 setOperationAction(Op, Ty, Custom);
1367 }
1368
1369 setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
1370
1371 setOperationAction(ISD::BITCAST, MVT::i2, Custom);
1372 setOperationAction(ISD::BITCAST, MVT::i4, Custom);
1373 setOperationAction(ISD::BITCAST, MVT::i8, Custom);
1374 setOperationAction(ISD::BITCAST, MVT::i16, Custom);
1375
1376 setOperationAction(ISD::BITCAST, MVT::v2i8, Custom);
1377 setOperationAction(ISD::BITCAST, MVT::v2i16, Custom);
1378 setOperationAction(ISD::BITCAST, MVT::v4i8, Custom);
1379
1380 setLoadExtAction(ISD::EXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1381 setLoadExtAction(ISD::SEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1382 setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i16, MVT::v4i8, Custom);
1383 setLoadExtAction(ISD::EXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1384 setLoadExtAction(ISD::SEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1385 setLoadExtAction(ISD::ZEXTLOAD, MVT::v4i32, MVT::v4i8, Custom);
1386
1387 // ADDP custom lowering
1388 for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
1389 setOperationAction(ISD::ADD, VT, Custom);
1390 // FADDP custom lowering
1391 for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
1392 setOperationAction(ISD::FADD, VT, Custom);
1393 } else /* !isNeonAvailable */ {
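    // Without NEON, mark every fixed-length vector operation as Expand by
    // default, then re-enable plain loads, stores and (for little-endian)
    // bitcasts, which only need the FP/SIMD register file.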
1394 for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
1395 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1396 setOperationAction(Op, VT, Expand);
1397
1398 if (VT.is128BitVector() || VT.is64BitVector()) {
1399 setOperationAction(ISD::LOAD, VT, Legal);
1400 setOperationAction(ISD::STORE, VT, Legal);
1401 setOperationAction(ISD::BITCAST, VT,
1402 Subtarget->isLittleEndian() ? Legal : Expand);
1403 }
1404 for (MVT InnerVT : MVT::fixedlen_vector_valuetypes()) {
1405 setTruncStoreAction(VT, InnerVT, Expand);
1406 setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1407 setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1408 setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1409 }
1410 }
1411 }
1412
1413 if (Subtarget->hasSME()) {
1414 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
1415 }
1416
1417 // FIXME: Move lowering for more nodes here if those are common between
1418 // SVE and SME.
1419 if (Subtarget->isSVEorStreamingSVEAvailable()) {
1420 for (auto VT :
1421 {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
1422 setOperationAction(ISD::SPLAT_VECTOR, VT, Custom);
1423 setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1424 setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1425 setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1426 }
1427 }
1428
1429 if (Subtarget->isSVEorStreamingSVEAvailable()) {
1430 for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64}) {
1431 setOperationAction(ISD::BITREVERSE, VT, Custom);
1432 setOperationAction(ISD::BSWAP, VT, Custom);
1433 setOperationAction(ISD::CTLZ, VT, Custom);
1434 setOperationAction(ISD::CTPOP, VT, Custom);
1435 setOperationAction(ISD::CTTZ, VT, Custom);
1436 setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1437 setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1438 setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1439 setOperationAction(ISD::FP_TO_UINT, VT, Custom);
1440 setOperationAction(ISD::FP_TO_SINT, VT, Custom);
1441 setOperationAction(ISD::MLOAD, VT, Custom);
1442 setOperationAction(ISD::MUL, VT, Custom);
1443 setOperationAction(ISD::MULHS, VT, Custom);
1444 setOperationAction(ISD::MULHU, VT, Custom);
1445 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1446 setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1447 setOperationAction(ISD::SELECT, VT, Custom);
1448 setOperationAction(ISD::SETCC, VT, Custom);
1449 setOperationAction(ISD::SDIV, VT, Custom);
1450 setOperationAction(ISD::UDIV, VT, Custom);
1451 setOperationAction(ISD::SMIN, VT, Custom);
1452 setOperationAction(ISD::UMIN, VT, Custom);
1453 setOperationAction(ISD::SMAX, VT, Custom);
1454 setOperationAction(ISD::UMAX, VT, Custom);
1455 setOperationAction(ISD::SHL, VT, Custom);
1456 setOperationAction(ISD::SRL, VT, Custom);
1457 setOperationAction(ISD::SRA, VT, Custom);
1458 setOperationAction(ISD::ABS, VT, Custom);
1459 setOperationAction(ISD::ABDS, VT, Custom);
1460 setOperationAction(ISD::ABDU, VT, Custom);
1461 setOperationAction(ISD::VECREDUCE_ADD, VT, Custom);
1462 setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1463 setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1464 setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1465 setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
1466 setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
1467 setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
1468 setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1469 setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1470 setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1471
1472 setOperationAction(ISD::UMUL_LOHI, VT, Expand);
1473 setOperationAction(ISD::SMUL_LOHI, VT, Expand);
1474 setOperationAction(ISD::SELECT_CC, VT, Expand);
1475 setOperationAction(ISD::ROTL, VT, Expand);
1476 setOperationAction(ISD::ROTR, VT, Expand);
1477
1478 setOperationAction(ISD::SADDSAT, VT, Legal);
1479 setOperationAction(ISD::UADDSAT, VT, Legal);
1480 setOperationAction(ISD::SSUBSAT, VT, Legal);
1481 setOperationAction(ISD::USUBSAT, VT, Legal);
1482 setOperationAction(ISD::UREM, VT, Expand);
1483 setOperationAction(ISD::SREM, VT, Expand);
1484 setOperationAction(ISD::SDIVREM, VT, Expand);
1485 setOperationAction(ISD::UDIVREM, VT, Expand);
1486
1487 setOperationAction(ISD::AVGFLOORS, VT, Custom);
1488 setOperationAction(ISD::AVGFLOORU, VT, Custom);
1489 setOperationAction(ISD::AVGCEILS, VT, Custom);
1490 setOperationAction(ISD::AVGCEILU, VT, Custom);
1491
1492 if (!Subtarget->isLittleEndian())
1493 setOperationAction(ISD::BITCAST, VT, Expand);
1494
1495 if (Subtarget->hasSVE2() ||
1496 (Subtarget->hasSME() && Subtarget->isStreaming()))
1497 // For SLI/SRI.
1498 setOperationAction(ISD::OR, VT, Custom);
1499 }
1500
1501 // Illegal unpacked integer vector types.
1502 for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
1503 setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1504 setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1505 }
1506
1507 // Legalize unpacked bitcasts to REINTERPRET_CAST.
1508 for (auto VT : {MVT::nxv2i16, MVT::nxv4i16, MVT::nxv2i32, MVT::nxv2bf16,
1509 MVT::nxv4bf16, MVT::nxv2f16, MVT::nxv4f16, MVT::nxv2f32})
1510 setOperationAction(ISD::BITCAST, VT, Custom);
1511
1512 for (auto VT :
1513 { MVT::nxv2i8, MVT::nxv2i16, MVT::nxv2i32, MVT::nxv2i64, MVT::nxv4i8,
1514 MVT::nxv4i16, MVT::nxv4i32, MVT::nxv8i8, MVT::nxv8i16 })
1515 setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Legal);
1516
1517 for (auto VT :
1518 {MVT::nxv16i1, MVT::nxv8i1, MVT::nxv4i1, MVT::nxv2i1, MVT::nxv1i1}) {
1519 setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1520 setOperationAction(ISD::SELECT, VT, Custom);
1521 setOperationAction(ISD::SETCC, VT, Custom);
1522 setOperationAction(ISD::TRUNCATE, VT, Custom);
1523 setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1524 setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1525 setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1526
1527 setOperationAction(ISD::SELECT_CC, VT, Expand);
1528 setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1529 setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1530
1531 // There are no legal MVT::nxv16f## based types.
1532 if (VT != MVT::nxv16i1) {
1533 setOperationAction(ISD::SINT_TO_FP, VT, Custom);
1534 setOperationAction(ISD::UINT_TO_FP, VT, Custom);
1535 }
1536 }
1537
1538 // NEON doesn't support masked loads/stores, but SME and SVE do.
1539 for (auto VT :
1540 {MVT::v4f16, MVT::v8f16, MVT::v2f32, MVT::v4f32, MVT::v1f64,
1541 MVT::v2f64, MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1542 MVT::v2i32, MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1543 setOperationAction(ISD::MLOAD, VT, Custom);
1544 setOperationAction(ISD::MSTORE, VT, Custom);
1545 }
1546
1547     // First, mark all scalable-vector extending loads and truncating stores
1548     // as Expand, covering both integer and floating-point scalable vectors.
1549 for (MVT VT : MVT::scalable_vector_valuetypes()) {
1550 for (MVT InnerVT : MVT::scalable_vector_valuetypes()) {
1551 setTruncStoreAction(VT, InnerVT, Expand);
1552 setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Expand);
1553 setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Expand);
1554 setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Expand);
1555 }
1556 }
1557
1558 // Then, selectively enable those which we directly support.
1559 setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i8, Legal);
1560 setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i16, Legal);
1561 setTruncStoreAction(MVT::nxv2i64, MVT::nxv2i32, Legal);
1562 setTruncStoreAction(MVT::nxv4i32, MVT::nxv4i8, Legal);
1563 setTruncStoreAction(MVT::nxv4i32, MVT::nxv4i16, Legal);
1564 setTruncStoreAction(MVT::nxv8i16, MVT::nxv8i8, Legal);
1565 for (auto Op : {ISD::ZEXTLOAD, ISD::SEXTLOAD, ISD::EXTLOAD}) {
1566 setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i8, Legal);
1567 setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i16, Legal);
1568 setLoadExtAction(Op, MVT::nxv2i64, MVT::nxv2i32, Legal);
1569 setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i8, Legal);
1570 setLoadExtAction(Op, MVT::nxv4i32, MVT::nxv4i16, Legal);
1571 setLoadExtAction(Op, MVT::nxv8i16, MVT::nxv8i8, Legal);
1572 }
1573
1574     // SVE supports truncating stores of 64-bit and 128-bit vectors.
1575 setTruncStoreAction(MVT::v2i64, MVT::v2i8, Custom);
1576 setTruncStoreAction(MVT::v2i64, MVT::v2i16, Custom);
1577 setTruncStoreAction(MVT::v2i64, MVT::v2i32, Custom);
1578 setTruncStoreAction(MVT::v2i32, MVT::v2i8, Custom);
1579 setTruncStoreAction(MVT::v2i32, MVT::v2i16, Custom);
1580
1581 for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1582 MVT::nxv4f32, MVT::nxv2f64}) {
1583 setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1584 setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1585 setOperationAction(ISD::MLOAD, VT, Custom);
1586 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1587 setOperationAction(ISD::SELECT, VT, Custom);
1588 setOperationAction(ISD::SETCC, VT, Custom);
1589 setOperationAction(ISD::FADD, VT, Custom);
1590 setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1591 setOperationAction(ISD::FDIV, VT, Custom);
1592 setOperationAction(ISD::FMA, VT, Custom);
1593 setOperationAction(ISD::FMAXIMUM, VT, Custom);
1594 setOperationAction(ISD::FMAXNUM, VT, Custom);
1595 setOperationAction(ISD::FMINIMUM, VT, Custom);
1596 setOperationAction(ISD::FMINNUM, VT, Custom);
1597 setOperationAction(ISD::FMUL, VT, Custom);
1598 setOperationAction(ISD::FNEG, VT, Custom);
1599 setOperationAction(ISD::FSUB, VT, Custom);
1600 setOperationAction(ISD::FCEIL, VT, Custom);
1601 setOperationAction(ISD::FFLOOR, VT, Custom);
1602 setOperationAction(ISD::FNEARBYINT, VT, Custom);
1603 setOperationAction(ISD::FRINT, VT, Custom);
1604 setOperationAction(ISD::LRINT, VT, Custom);
1605 setOperationAction(ISD::LLRINT, VT, Custom);
1606 setOperationAction(ISD::FROUND, VT, Custom);
1607 setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1608 setOperationAction(ISD::FTRUNC, VT, Custom);
1609 setOperationAction(ISD::FSQRT, VT, Custom);
1610 setOperationAction(ISD::FABS, VT, Custom);
1611 setOperationAction(ISD::FP_EXTEND, VT, Custom);
1612 setOperationAction(ISD::FP_ROUND, VT, Custom);
1613 setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1614 setOperationAction(ISD::VECREDUCE_FMAX, VT, Custom);
1615 setOperationAction(ISD::VECREDUCE_FMIN, VT, Custom);
1616 setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Custom);
1617 setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Custom);
1618 setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1619 setOperationAction(ISD::VECTOR_DEINTERLEAVE, VT, Custom);
1620 setOperationAction(ISD::VECTOR_INTERLEAVE, VT, Custom);
1621
1622 setOperationAction(ISD::SELECT_CC, VT, Expand);
1623 setOperationAction(ISD::FREM, VT, Expand);
1624 setOperationAction(ISD::FPOW, VT, Expand);
1625 setOperationAction(ISD::FPOWI, VT, Expand);
1626 setOperationAction(ISD::FCOS, VT, Expand);
1627 setOperationAction(ISD::FSIN, VT, Expand);
1628 setOperationAction(ISD::FSINCOS, VT, Expand);
1629 setOperationAction(ISD::FTAN, VT, Expand);
1630 setOperationAction(ISD::FACOS, VT, Expand);
1631 setOperationAction(ISD::FASIN, VT, Expand);
1632 setOperationAction(ISD::FATAN, VT, Expand);
1633 setOperationAction(ISD::FCOSH, VT, Expand);
1634 setOperationAction(ISD::FSINH, VT, Expand);
1635 setOperationAction(ISD::FTANH, VT, Expand);
1636 setOperationAction(ISD::FEXP, VT, Expand);
1637 setOperationAction(ISD::FEXP2, VT, Expand);
1638 setOperationAction(ISD::FEXP10, VT, Expand);
1639 setOperationAction(ISD::FLOG, VT, Expand);
1640 setOperationAction(ISD::FLOG2, VT, Expand);
1641 setOperationAction(ISD::FLOG10, VT, Expand);
1642
1643 setCondCodeAction(ISD::SETO, VT, Expand);
1644 setCondCodeAction(ISD::SETOLT, VT, Expand);
1645 setCondCodeAction(ISD::SETLT, VT, Expand);
1646 setCondCodeAction(ISD::SETOLE, VT, Expand);
1647 setCondCodeAction(ISD::SETLE, VT, Expand);
1648 setCondCodeAction(ISD::SETULT, VT, Expand);
1649 setCondCodeAction(ISD::SETULE, VT, Expand);
1650 setCondCodeAction(ISD::SETUGE, VT, Expand);
1651 setCondCodeAction(ISD::SETUGT, VT, Expand);
1652 setCondCodeAction(ISD::SETUEQ, VT, Expand);
1653 setCondCodeAction(ISD::SETONE, VT, Expand);
1654
1655 if (!Subtarget->isLittleEndian())
1656 setOperationAction(ISD::BITCAST, VT, Expand);
1657 }
1658
1659 for (auto VT : {MVT::nxv2bf16, MVT::nxv4bf16, MVT::nxv8bf16}) {
1660 setOperationAction(ISD::CONCAT_VECTORS, VT, Custom);
1661 setOperationAction(ISD::MLOAD, VT, Custom);
1662 setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
1663 setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
1664 setOperationAction(ISD::VECTOR_SPLICE, VT, Custom);
1665
1666 if (!Subtarget->isLittleEndian())
1667 setOperationAction(ISD::BITCAST, VT, Expand);
1668 }
1669
1670 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i8, Custom);
1671 setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::i16, Custom);
1672
1673 // NEON doesn't support integer divides, but SVE does
1674 for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
1675 MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1676 setOperationAction(ISD::SDIV, VT, Custom);
1677 setOperationAction(ISD::UDIV, VT, Custom);
1678 }
1679
1680 // NEON doesn't support 64-bit vector integer muls, but SVE does.
1681 setOperationAction(ISD::MUL, MVT::v1i64, Custom);
1682 setOperationAction(ISD::MUL, MVT::v2i64, Custom);
1683
1684 // NOTE: Currently this has to happen after computeRegisterProperties rather
1685 // than the preferred option of combining it with the addRegisterClass call.
1686 if (Subtarget->useSVEForFixedLengthVectors()) {
1687 for (MVT VT : MVT::integer_fixedlen_vector_valuetypes()) {
1688 if (useSVEForFixedLengthVectorVT(
1689 VT, /*OverrideNEON=*/!Subtarget->isNeonAvailable()))
1690 addTypeForFixedLengthSVE(VT);
1691 }
1692 for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
1693 if (useSVEForFixedLengthVectorVT(
1694 VT, /*OverrideNEON=*/!Subtarget->isNeonAvailable()))
1695 addTypeForFixedLengthSVE(VT);
1696 }
1697
1698     // A 64-bit result can come from an input wider than a NEON vector.
1699 for (auto VT : {MVT::v8i8, MVT::v4i16})
1700 setOperationAction(ISD::TRUNCATE, VT, Custom);
1701 setOperationAction(ISD::FP_ROUND, MVT::v4f16, Custom);
1702
1703     // A 128-bit result implies an input wider than a NEON vector.
1704 for (auto VT : {MVT::v16i8, MVT::v8i16, MVT::v4i32})
1705 setOperationAction(ISD::TRUNCATE, VT, Custom);
1706 for (auto VT : {MVT::v8f16, MVT::v4f32})
1707 setOperationAction(ISD::FP_ROUND, VT, Custom);
1708
1709 // These operations are not supported on NEON but SVE can do them.
1710 setOperationAction(ISD::BITREVERSE, MVT::v1i64, Custom);
1711 setOperationAction(ISD::CTLZ, MVT::v1i64, Custom);
1712 setOperationAction(ISD::CTLZ, MVT::v2i64, Custom);
1713 setOperationAction(ISD::CTTZ, MVT::v1i64, Custom);
1714 setOperationAction(ISD::MULHS, MVT::v1i64, Custom);
1715 setOperationAction(ISD::MULHS, MVT::v2i64, Custom);
1716 setOperationAction(ISD::MULHU, MVT::v1i64, Custom);
1717 setOperationAction(ISD::MULHU, MVT::v2i64, Custom);
1718 setOperationAction(ISD::SMAX, MVT::v1i64, Custom);
1719 setOperationAction(ISD::SMAX, MVT::v2i64, Custom);
1720 setOperationAction(ISD::SMIN, MVT::v1i64, Custom);
1721 setOperationAction(ISD::SMIN, MVT::v2i64, Custom);
1722 setOperationAction(ISD::UMAX, MVT::v1i64, Custom);
1723 setOperationAction(ISD::UMAX, MVT::v2i64, Custom);
1724 setOperationAction(ISD::UMIN, MVT::v1i64, Custom);
1725 setOperationAction(ISD::UMIN, MVT::v2i64, Custom);
1726 setOperationAction(ISD::VECREDUCE_SMAX, MVT::v2i64, Custom);
1727 setOperationAction(ISD::VECREDUCE_SMIN, MVT::v2i64, Custom);
1728 setOperationAction(ISD::VECREDUCE_UMAX, MVT::v2i64, Custom);
1729 setOperationAction(ISD::VECREDUCE_UMIN, MVT::v2i64, Custom);
1730
1731 // Int operations with no NEON support.
1732 for (auto VT : {MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
1733 MVT::v2i32, MVT::v4i32, MVT::v2i64}) {
1734 setOperationAction(ISD::BITREVERSE, VT, Custom);
1735 setOperationAction(ISD::CTTZ, VT, Custom);
1736 setOperationAction(ISD::VECREDUCE_AND, VT, Custom);
1737 setOperationAction(ISD::VECREDUCE_OR, VT, Custom);
1738 setOperationAction(ISD::VECREDUCE_XOR, VT, Custom);
1739 setOperationAction(ISD::MULHS, VT, Custom);
1740 setOperationAction(ISD::MULHU, VT, Custom);
1741 }
1742
1743 // Use SVE for vectors with more than 2 elements.
1744 for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
1745 setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
1746 }
1747
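    // VECTOR_SPLICE of predicate types is promoted to the integer vector type
    // with the same element count, where the splice can be performed on data
    // registers.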
1748 setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv2i1, MVT::nxv2i64);
1749 setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv4i1, MVT::nxv4i32);
1750 setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv8i1, MVT::nxv8i16);
1751 setOperationPromotedToType(ISD::VECTOR_SPLICE, MVT::nxv16i1, MVT::nxv16i8);
1752
1753 setOperationAction(ISD::VSCALE, MVT::i32, Custom);
1754
1755 for (auto VT : {MVT::v16i1, MVT::v8i1, MVT::v4i1, MVT::v2i1})
1756 setOperationAction(ISD::INTRINSIC_WO_CHAIN, VT, Custom);
1757 }
1758
1759 // Handle operations that are only available in non-streaming SVE mode.
1760 if (Subtarget->isSVEAvailable()) {
1761 for (auto VT : {MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32, MVT::nxv2i64,
1762 MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1763 MVT::nxv4f32, MVT::nxv2f64, MVT::nxv2bf16, MVT::nxv4bf16,
1764 MVT::nxv8bf16, MVT::v4f16, MVT::v8f16, MVT::v2f32,
1765 MVT::v4f32, MVT::v1f64, MVT::v2f64, MVT::v8i8,
1766 MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
1767 MVT::v4i32, MVT::v1i64, MVT::v2i64}) {
1768 setOperationAction(ISD::MGATHER, VT, Custom);
1769 setOperationAction(ISD::MSCATTER, VT, Custom);
1770 }
1771
1772 for (auto VT : {MVT::nxv2f16, MVT::nxv4f16, MVT::nxv8f16, MVT::nxv2f32,
1773 MVT::nxv4f32, MVT::nxv2f64, MVT::v4f16, MVT::v8f16,
1774 MVT::v2f32, MVT::v4f32, MVT::v2f64})
1775 setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
1776
1777 // Histcnt is SVE2 only
1778 if (Subtarget->hasSVE2())
1779 setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
1780 Custom);
1781 }
1782
1783
1784 if (Subtarget->hasMOPS() && Subtarget->hasMTE()) {
1785 // Only required for llvm.aarch64.mops.memset.tag
1786 setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
1787 }
1788
1789 setOperationAction(ISD::INTRINSIC_VOID, MVT::Other, Custom);
1790
1791 if (Subtarget->hasSVE()) {
1792 setOperationAction(ISD::FLDEXP, MVT::f64, Custom);
1793 setOperationAction(ISD::FLDEXP, MVT::f32, Custom);
1794 setOperationAction(ISD::FLDEXP, MVT::f16, Custom);
1795 setOperationAction(ISD::FLDEXP, MVT::bf16, Custom);
1796 }
1797
1798 PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
1799
1800 IsStrictFPEnabled = true;
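  // 128-bit atomic operations are handled inline (LDXP/STXP, or CASP/LDP/STP
  // with LSE/LSE2) rather than through atomic libcalls.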
1801 setMaxAtomicSizeInBitsSupported(128);
1802
1803 // On MSVC, both 32-bit and 64-bit, ldexpf(f32) is not defined. MinGW has
1804 // it, but it's just a wrapper around ldexp.
1805 if (Subtarget->isTargetWindows()) {
1806 for (ISD::NodeType Op : {ISD::FLDEXP, ISD::STRICT_FLDEXP, ISD::FFREXP})
1807 if (isOperationExpand(Op, MVT::f32))
1808 setOperationAction(Op, MVT::f32, Promote);
1809 }
1810
1811 // LegalizeDAG currently can't expand fp16 LDEXP/FREXP on targets where i16
1812 // isn't legal.
1813 for (ISD::NodeType Op : {ISD::FLDEXP, ISD::STRICT_FLDEXP, ISD::FFREXP})
1814 if (isOperationExpand(Op, MVT::f16))
1815 setOperationAction(Op, MVT::f16, Promote);
1816
1817 if (Subtarget->isWindowsArm64EC()) {
1818 // FIXME: are there intrinsics we need to exclude from this?
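    // Arm64EC mangles symbol names with a leading '#', so redirect every
    // libcall to its '#'-prefixed entry point.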
1819 for (int i = 0; i < RTLIB::UNKNOWN_LIBCALL; ++i) {
1820 auto code = static_cast<RTLIB::Libcall>(i);
1821 auto libcallName = getLibcallName(code);
1822 if ((libcallName != nullptr) && (libcallName[0] != '#')) {
1823 setLibcallName(code, Saver.save(Twine("#") + libcallName).data());
1824 }
1825 }
1826 }
1827 }
1828
1829 void AArch64TargetLowering::addTypeForNEON(MVT VT) {
1830 assert(VT.isVector() && "VT should be a vector type");
1831
1832 if (VT.isFloatingPoint()) {
1833 MVT PromoteTo = EVT(VT).changeVectorElementTypeToInteger().getSimpleVT();
1834 setOperationPromotedToType(ISD::LOAD, VT, PromoteTo);
1835 setOperationPromotedToType(ISD::STORE, VT, PromoteTo);
1836 }
1837
1838 // Mark vector float intrinsics as expand.
1839 if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64) {
1840 setOperationAction(ISD::FSIN, VT, Expand);
1841 setOperationAction(ISD::FCOS, VT, Expand);
1842 setOperationAction(ISD::FTAN, VT, Expand);
1843 setOperationAction(ISD::FASIN, VT, Expand);
1844 setOperationAction(ISD::FACOS, VT, Expand);
1845 setOperationAction(ISD::FATAN, VT, Expand);
1846 setOperationAction(ISD::FSINH, VT, Expand);
1847 setOperationAction(ISD::FCOSH, VT, Expand);
1848 setOperationAction(ISD::FTANH, VT, Expand);
1849 setOperationAction(ISD::FPOW, VT, Expand);
1850 setOperationAction(ISD::FLOG, VT, Expand);
1851 setOperationAction(ISD::FLOG2, VT, Expand);
1852 setOperationAction(ISD::FLOG10, VT, Expand);
1853 setOperationAction(ISD::FEXP, VT, Expand);
1854 setOperationAction(ISD::FEXP2, VT, Expand);
1855 setOperationAction(ISD::FEXP10, VT, Expand);
1856 }
1857
1858 // But we do support custom-lowering for FCOPYSIGN.
1859 if (VT == MVT::v2f32 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
1860 ((VT == MVT::v4bf16 || VT == MVT::v8bf16 || VT == MVT::v4f16 ||
1861 VT == MVT::v8f16) &&
1862 Subtarget->hasFullFP16()))
1863 setOperationAction(ISD::FCOPYSIGN, VT, Custom);
1864
1865 setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
1866 setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
1867 setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
1868 setOperationAction(ISD::ZERO_EXTEND_VECTOR_INREG, VT, Custom);
1869 setOperationAction(ISD::VECTOR_SHUFFLE, VT, Custom);
1870 setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
1871 setOperationAction(ISD::SRA, VT, Custom);
1872 setOperationAction(ISD::SRL, VT, Custom);
1873 setOperationAction(ISD::SHL, VT, Custom);
1874 setOperationAction(ISD::OR, VT, Custom);
1875 setOperationAction(ISD::SETCC, VT, Custom);
1876 setOperationAction(ISD::CONCAT_VECTORS, VT, Legal);
1877
1878 setOperationAction(ISD::SELECT, VT, Expand);
1879 setOperationAction(ISD::SELECT_CC, VT, Expand);
1880 setOperationAction(ISD::VSELECT, VT, Expand);
1881 for (MVT InnerVT : MVT::all_valuetypes())
1882 setLoadExtAction(ISD::EXTLOAD, InnerVT, VT, Expand);
1883
1884   // CNT only supports B element sizes; for wider elements, count bytes and
1884   // widen with UADDLP.
1885 if (VT != MVT::v8i8 && VT != MVT::v16i8)
1886 setOperationAction(ISD::CTPOP, VT, Custom);
1887
1888 setOperationAction(ISD::UDIV, VT, Expand);
1889 setOperationAction(ISD::SDIV, VT, Expand);
1890 setOperationAction(ISD::UREM, VT, Expand);
1891 setOperationAction(ISD::SREM, VT, Expand);
1892 setOperationAction(ISD::FREM, VT, Expand);
1893
1894 for (unsigned Opcode :
1895 {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
1896 ISD::FP_TO_UINT_SAT, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
1897 setOperationAction(Opcode, VT, Custom);
1898
1899 if (!VT.isFloatingPoint())
1900 setOperationAction(ISD::ABS, VT, Legal);
1901
1902 // [SU][MIN|MAX] are available for all NEON types apart from i64.
1903 if (!VT.isFloatingPoint() && VT != MVT::v2i64 && VT != MVT::v1i64)
1904 for (unsigned Opcode : {ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX})
1905 setOperationAction(Opcode, VT, Legal);
1906
1907 // F[MIN|MAX][NUM|NAN] and simple strict operations are available for all FP
1908 // NEON types.
1909 if (VT.isFloatingPoint() &&
1910 VT.getVectorElementType() != MVT::bf16 &&
1911 (VT.getVectorElementType() != MVT::f16 || Subtarget->hasFullFP16()))
1912 for (unsigned Opcode :
1913 {ISD::FMINIMUM, ISD::FMAXIMUM, ISD::FMINNUM, ISD::FMAXNUM,
1914 ISD::STRICT_FMINIMUM, ISD::STRICT_FMAXIMUM, ISD::STRICT_FMINNUM,
1915 ISD::STRICT_FMAXNUM, ISD::STRICT_FADD, ISD::STRICT_FSUB,
1916 ISD::STRICT_FMUL, ISD::STRICT_FDIV, ISD::STRICT_FMA,
1917 ISD::STRICT_FSQRT})
1918 setOperationAction(Opcode, VT, Legal);
1919
1920 // Strict fp extend and trunc are legal
1921 if (VT.isFloatingPoint() && VT.getScalarSizeInBits() != 16)
1922 setOperationAction(ISD::STRICT_FP_EXTEND, VT, Legal);
1923 if (VT.isFloatingPoint() && VT.getScalarSizeInBits() != 64)
1924 setOperationAction(ISD::STRICT_FP_ROUND, VT, Legal);
1925
1926 // FIXME: We could potentially make use of the vector comparison instructions
1927 // for STRICT_FSETCC and STRICT_FSETCCS, but there are a number of
1928 // complications:
1929 // * FCMPEQ/NE are quiet comparisons, the rest are signalling comparisons,
1930 // so we would need to expand when the condition code doesn't match the
1931 // kind of comparison.
1932 // * Some kinds of comparison require more than one FCMXY instruction so
1933 // would need to be expanded instead.
1934 // * The lowering of the non-strict versions involves target-specific ISD
1935 // nodes so we would likely need to add strict versions of all of them and
1936 // handle them appropriately.
1937 setOperationAction(ISD::STRICT_FSETCC, VT, Expand);
1938 setOperationAction(ISD::STRICT_FSETCCS, VT, Expand);
1939
1940 if (Subtarget->isLittleEndian()) {
1941 for (unsigned im = (unsigned)ISD::PRE_INC;
1942 im != (unsigned)ISD::LAST_INDEXED_MODE; ++im) {
1943 setIndexedLoadAction(im, VT, Legal);
1944 setIndexedStoreAction(im, VT, Legal);
1945 }
1946 }
1947
1948 if (Subtarget->hasD128()) {
1949 setOperationAction(ISD::READ_REGISTER, MVT::i128, Custom);
1950 setOperationAction(ISD::WRITE_REGISTER, MVT::i128, Custom);
1951 }
1952 }
1953
1954 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
1955 EVT OpVT) const {
1956 // Only SVE has a 1:1 mapping from intrinsic -> instruction (whilelo).
1957 if (!Subtarget->hasSVE())
1958 return true;
1959
1960 // We can only support legal predicate result types. We can use the SVE
1961 // whilelo instruction for generating fixed-width predicates too.
1962 if (ResVT != MVT::nxv2i1 && ResVT != MVT::nxv4i1 && ResVT != MVT::nxv8i1 &&
1963 ResVT != MVT::nxv16i1 && ResVT != MVT::v2i1 && ResVT != MVT::v4i1 &&
1964 ResVT != MVT::v8i1 && ResVT != MVT::v16i1)
1965 return true;
1966
1967 // The whilelo instruction only works with i32 or i64 scalar inputs.
1968 if (OpVT != MVT::i32 && OpVT != MVT::i64)
1969 return true;
1970
1971 return false;
1972 }
1973
1974 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
1975 if (!Subtarget->isSVEorStreamingSVEAvailable())
1976 return true;
1977
1978 // We can only use the BRKB + CNTP sequence with legal predicate types. We can
1979 // also support fixed-width predicates.
1980 return VT != MVT::nxv16i1 && VT != MVT::nxv8i1 && VT != MVT::nxv4i1 &&
1981 VT != MVT::nxv2i1 && VT != MVT::v16i1 && VT != MVT::v8i1 &&
1982 VT != MVT::v4i1 && VT != MVT::v2i1;
1983 }
1984
1985 void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
1986 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
1987
1988 // By default everything must be expanded.
1989 for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
1990 setOperationAction(Op, VT, Expand);
1991
1992 if (VT.isFloatingPoint()) {
1993 setCondCodeAction(ISD::SETO, VT, Expand);
1994 setCondCodeAction(ISD::SETOLT, VT, Expand);
1995 setCondCodeAction(ISD::SETOLE, VT, Expand);
1996 setCondCodeAction(ISD::SETULT, VT, Expand);
1997 setCondCodeAction(ISD::SETULE, VT, Expand);
1998 setCondCodeAction(ISD::SETUGE, VT, Expand);
1999 setCondCodeAction(ISD::SETUGT, VT, Expand);
2000 setCondCodeAction(ISD::SETUEQ, VT, Expand);
2001 setCondCodeAction(ISD::SETONE, VT, Expand);
2002 }
2003
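  // Most operations below are custom lowered to their scalable equivalents;
  // v1f64 is the exception and is expanded instead.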
2004 TargetLoweringBase::LegalizeAction Default =
2005 VT == MVT::v1f64 ? Expand : Custom;
2006
2007 // Mark integer truncating stores/extending loads as having custom lowering
2008 if (VT.isInteger()) {
2009 MVT InnerVT = VT.changeVectorElementType(MVT::i8);
2010 while (InnerVT != VT) {
2011 setTruncStoreAction(VT, InnerVT, Default);
2012 setLoadExtAction(ISD::ZEXTLOAD, VT, InnerVT, Default);
2013 setLoadExtAction(ISD::SEXTLOAD, VT, InnerVT, Default);
2014 setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Default);
2015 InnerVT = InnerVT.changeVectorElementType(
2016 MVT::getIntegerVT(2 * InnerVT.getScalarSizeInBits()));
2017 }
2018 }
2019
2020 // Mark floating-point truncating stores/extending loads as having custom
2021 // lowering
2022 if (VT.isFloatingPoint()) {
2023 MVT InnerVT = VT.changeVectorElementType(MVT::f16);
2024 while (InnerVT != VT) {
2025 setTruncStoreAction(VT, InnerVT, Custom);
2026 setLoadExtAction(ISD::EXTLOAD, VT, InnerVT, Default);
2027 InnerVT = InnerVT.changeVectorElementType(
2028 MVT::getFloatingPointVT(2 * InnerVT.getScalarSizeInBits()));
2029 }
2030 }
2031
2032 bool PreferNEON = VT.is64BitVector() || VT.is128BitVector();
2033 bool PreferSVE = !PreferNEON && Subtarget->isSVEAvailable();
2034
2035 // Lower fixed length vector operations to scalable equivalents.
2036 setOperationAction(ISD::ABS, VT, Default);
2037 setOperationAction(ISD::ADD, VT, Default);
2038 setOperationAction(ISD::AND, VT, Default);
2039 setOperationAction(ISD::ANY_EXTEND, VT, Default);
2040 setOperationAction(ISD::BITCAST, VT, PreferNEON ? Legal : Default);
2041 setOperationAction(ISD::BITREVERSE, VT, Default);
2042 setOperationAction(ISD::BSWAP, VT, Default);
2043 setOperationAction(ISD::BUILD_VECTOR, VT, Default);
2044 setOperationAction(ISD::CONCAT_VECTORS, VT, Default);
2045 setOperationAction(ISD::CTLZ, VT, Default);
2046 setOperationAction(ISD::CTPOP, VT, Default);
2047 setOperationAction(ISD::CTTZ, VT, Default);
2048 setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Default);
2049 setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Default);
2050 setOperationAction(ISD::FABS, VT, Default);
2051 setOperationAction(ISD::FADD, VT, Default);
2052 setOperationAction(ISD::FCEIL, VT, Default);
2053 setOperationAction(ISD::FCOPYSIGN, VT, Default);
2054 setOperationAction(ISD::FDIV, VT, Default);
2055 setOperationAction(ISD::FFLOOR, VT, Default);
2056 setOperationAction(ISD::FMA, VT, Default);
2057 setOperationAction(ISD::FMAXIMUM, VT, Default);
2058 setOperationAction(ISD::FMAXNUM, VT, Default);
2059 setOperationAction(ISD::FMINIMUM, VT, Default);
2060 setOperationAction(ISD::FMINNUM, VT, Default);
2061 setOperationAction(ISD::FMUL, VT, Default);
2062 setOperationAction(ISD::FNEARBYINT, VT, Default);
2063 setOperationAction(ISD::FNEG, VT, Default);
2064 setOperationAction(ISD::FP_EXTEND, VT, Default);
2065 setOperationAction(ISD::FP_ROUND, VT, Default);
2066 setOperationAction(ISD::FP_TO_SINT, VT, Default);
2067 setOperationAction(ISD::FP_TO_UINT, VT, Default);
2068 setOperationAction(ISD::FRINT, VT, Default);
2069 setOperationAction(ISD::LRINT, VT, Default);
2070 setOperationAction(ISD::LLRINT, VT, Default);
2071 setOperationAction(ISD::FROUND, VT, Default);
2072 setOperationAction(ISD::FROUNDEVEN, VT, Default);
2073 setOperationAction(ISD::FSQRT, VT, Default);
2074 setOperationAction(ISD::FSUB, VT, Default);
2075 setOperationAction(ISD::FTRUNC, VT, Default);
2076 setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Default);
2077 setOperationAction(ISD::LOAD, VT, PreferNEON ? Legal : Default);
2078 setOperationAction(ISD::MGATHER, VT, PreferSVE ? Default : Expand);
2079 setOperationAction(ISD::MLOAD, VT, Default);
2080 setOperationAction(ISD::MSCATTER, VT, PreferSVE ? Default : Expand);
2081 setOperationAction(ISD::MSTORE, VT, Default);
2082 setOperationAction(ISD::MUL, VT, Default);
2083 setOperationAction(ISD::MULHS, VT, Default);
2084 setOperationAction(ISD::MULHU, VT, Default);
2085 setOperationAction(ISD::OR, VT, Default);
2086 setOperationAction(ISD::SCALAR_TO_VECTOR, VT, PreferNEON ? Legal : Expand);
2087 setOperationAction(ISD::SDIV, VT, Default);
2088 setOperationAction(ISD::SELECT, VT, Default);
2089 setOperationAction(ISD::SETCC, VT, Default);
2090 setOperationAction(ISD::SHL, VT, Default);
2091 setOperationAction(ISD::SIGN_EXTEND, VT, Default);
2092 setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Default);
2093 setOperationAction(ISD::SINT_TO_FP, VT, Default);
2094 setOperationAction(ISD::SMAX, VT, Default);
2095 setOperationAction(ISD::SMIN, VT, Default);
2096 setOperationAction(ISD::SPLAT_VECTOR, VT, Default);
2097 setOperationAction(ISD::SRA, VT, Default);
2098 setOperationAction(ISD::SRL, VT, Default);
2099 setOperationAction(ISD::STORE, VT, PreferNEON ? Legal : Default);
2100 setOperationAction(ISD::SUB, VT, Default);
2101 setOperationAction(ISD::TRUNCATE, VT, Default);
2102 setOperationAction(ISD::UDIV, VT, Default);
2103 setOperationAction(ISD::UINT_TO_FP, VT, Default);
2104 setOperationAction(ISD::UMAX, VT, Default);
2105 setOperationAction(ISD::UMIN, VT, Default);
2106 setOperationAction(ISD::VECREDUCE_ADD, VT, Default);
2107 setOperationAction(ISD::VECREDUCE_AND, VT, Default);
2108 setOperationAction(ISD::VECREDUCE_FADD, VT, Default);
2109 setOperationAction(ISD::VECREDUCE_FMAX, VT, Default);
2110 setOperationAction(ISD::VECREDUCE_FMIN, VT, Default);
2111 setOperationAction(ISD::VECREDUCE_FMAXIMUM, VT, Default);
2112 setOperationAction(ISD::VECREDUCE_FMINIMUM, VT, Default);
2113 setOperationAction(ISD::VECREDUCE_OR, VT, Default);
2114 setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, PreferSVE ? Default : Expand);
2115 setOperationAction(ISD::VECREDUCE_SMAX, VT, Default);
2116 setOperationAction(ISD::VECREDUCE_SMIN, VT, Default);
2117 setOperationAction(ISD::VECREDUCE_UMAX, VT, Default);
2118 setOperationAction(ISD::VECREDUCE_UMIN, VT, Default);
2119 setOperationAction(ISD::VECREDUCE_XOR, VT, Default);
2120 setOperationAction(ISD::VECTOR_SHUFFLE, VT, Default);
2121 setOperationAction(ISD::VECTOR_SPLICE, VT, Default);
2122 setOperationAction(ISD::VSELECT, VT, Default);
2123 setOperationAction(ISD::XOR, VT, Default);
2124 setOperationAction(ISD::ZERO_EXTEND, VT, Default);
2125 }
2126
2127 void AArch64TargetLowering::addDRType(MVT VT) {
2128 addRegisterClass(VT, &AArch64::FPR64RegClass);
2129 if (Subtarget->isNeonAvailable())
2130 addTypeForNEON(VT);
2131 }
2132
2133 void AArch64TargetLowering::addQRType(MVT VT) {
2134 addRegisterClass(VT, &AArch64::FPR128RegClass);
2135 if (Subtarget->isNeonAvailable())
2136 addTypeForNEON(VT);
2137 }
2138
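// Comparisons of scalable vectors produce a scalable predicate of i1 elements;
// fixed-length vectors produce an integer vector of the same shape, and
// scalars produce an i32.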
2139 EVT AArch64TargetLowering::getSetCCResultType(const DataLayout &,
2140 LLVMContext &C, EVT VT) const {
2141 if (!VT.isVector())
2142 return MVT::i32;
2143 if (VT.isScalableVector())
2144 return EVT::getVectorVT(C, MVT::i1, VT.getVectorElementCount());
2145 return VT.changeVectorElementTypeToInteger();
2146 }
2147
2148 // isIntImmediate - This method tests to see if the node is a constant
2149 // operand. If so, Imm receives the value.
2150 static bool isIntImmediate(const SDNode *N, uint64_t &Imm) {
2151 if (const ConstantSDNode *C = dyn_cast<const ConstantSDNode>(N)) {
2152 Imm = C->getZExtValue();
2153 return true;
2154 }
2155 return false;
2156 }
2157
2158 // isOpcWithIntImmediate - This method tests to see if the node is a specific
2159 // opcode and that it has an immediate integer right operand.
2160 // If so, Imm receives the value.
2161 static bool isOpcWithIntImmediate(const SDNode *N, unsigned Opc,
2162 uint64_t &Imm) {
2163 return N->getOpcode() == Opc &&
2164 isIntImmediate(N->getOperand(1).getNode(), Imm);
2165 }
2166
2167 static bool optimizeLogicalImm(SDValue Op, unsigned Size, uint64_t Imm,
2168 const APInt &Demanded,
2169 TargetLowering::TargetLoweringOpt &TLO,
2170 unsigned NewOpc) {
2171 uint64_t OldImm = Imm, NewImm, Enc;
2172 uint64_t Mask = ((uint64_t)(-1LL) >> (64 - Size)), OrigMask = Mask;
2173
2174 // Return if the immediate is already all zeros, all ones, a bimm32 or a
2175 // bimm64.
2176 if (Imm == 0 || Imm == Mask ||
2177 AArch64_AM::isLogicalImmediate(Imm & Mask, Size))
2178 return false;
2179
2180 unsigned EltSize = Size;
2181 uint64_t DemandedBits = Demanded.getZExtValue();
2182
2183 // Clear bits that are not demanded.
2184 Imm &= DemandedBits;
2185
2186 while (true) {
2187 // The goal here is to set the non-demanded bits in a way that minimizes
2188 // the number of switching between 0 and 1. In order to achieve this goal,
2189 // we set the non-demanded bits to the value of the preceding demanded bits.
2190 // For example, if we have an immediate 0bx10xx0x1 ('x' indicates a
2191 // non-demanded bit), we copy bit0 (1) to the least significant 'x',
2192 // bit2 (0) to 'xx', and bit6 (1) to the most significant 'x'.
2193 // The final result is 0b11000011.
2194 uint64_t NonDemandedBits = ~DemandedBits;
2195 uint64_t InvertedImm = ~Imm & DemandedBits;
2196 uint64_t RotatedImm =
2197 ((InvertedImm << 1) | (InvertedImm >> (EltSize - 1) & 1)) &
2198 NonDemandedBits;
2199 uint64_t Sum = RotatedImm + NonDemandedBits;
2200 bool Carry = NonDemandedBits & ~Sum & (1ULL << (EltSize - 1));
2201 uint64_t Ones = (Sum + Carry) & NonDemandedBits;
2202 NewImm = (Imm | Ones) & Mask;
2203
2204 // If NewImm or its bitwise NOT is a shifted mask, it is a bitmask immediate
2205 // or all-ones or all-zeros, in which case we can stop searching. Otherwise,
2206 // we halve the element size and continue the search.
2207 if (isShiftedMask_64(NewImm) || isShiftedMask_64(~(NewImm | ~Mask)))
2208 break;
2209
2210 // We cannot shrink the element size any further if it is 2-bits.
2211 if (EltSize == 2)
2212 return false;
2213
2214 EltSize /= 2;
2215 Mask >>= EltSize;
2216 uint64_t Hi = Imm >> EltSize, DemandedBitsHi = DemandedBits >> EltSize;
2217
2218 // Return if there is mismatch in any of the demanded bits of Imm and Hi.
2219 if (((Imm ^ Hi) & (DemandedBits & DemandedBitsHi) & Mask) != 0)
2220 return false;
2221
2222 // Merge the upper and lower halves of Imm and DemandedBits.
2223 Imm |= Hi;
2224 DemandedBits |= DemandedBitsHi;
2225 }
2226
2227 ++NumOptimizedImms;
2228
2229 // Replicate the element across the register width.
2230 while (EltSize < Size) {
2231 NewImm |= NewImm << EltSize;
2232 EltSize *= 2;
2233 }
2234
2235 (void)OldImm;
2236 assert(((OldImm ^ NewImm) & Demanded.getZExtValue()) == 0 &&
2237 "demanded bits should never be altered");
2238 assert(OldImm != NewImm && "the new imm shouldn't be equal to the old imm");
2239
2240 // Create the new constant immediate node.
2241 EVT VT = Op.getValueType();
2242 SDLoc DL(Op);
2243 SDValue New;
2244
2245 // If the new constant immediate is all-zeros or all-ones, let the target
2246 // independent DAG combine optimize this node.
2247 if (NewImm == 0 || NewImm == OrigMask) {
2248 New = TLO.DAG.getNode(Op.getOpcode(), DL, VT, Op.getOperand(0),
2249 TLO.DAG.getConstant(NewImm, DL, VT));
2250 // Otherwise, create a machine node so that target independent DAG combine
2251 // doesn't undo this optimization.
2252 } else {
2253 Enc = AArch64_AM::encodeLogicalImmediate(NewImm, Size);
2254 SDValue EncConst = TLO.DAG.getTargetConstant(Enc, DL, VT);
2255 New = SDValue(
2256 TLO.DAG.getMachineNode(NewOpc, DL, VT, Op.getOperand(0), EncConst), 0);
2257 }
2258
2259 return TLO.CombineTo(Op, New);
2260 }
2261
2262 bool AArch64TargetLowering::targetShrinkDemandedConstant(
2263 SDValue Op, const APInt &DemandedBits, const APInt &DemandedElts,
2264 TargetLoweringOpt &TLO) const {
2265 // Delay this optimization to as late as possible.
2266 if (!TLO.LegalOps)
2267 return false;
2268
2269 if (!EnableOptimizeLogicalImm)
2270 return false;
2271
2272 EVT VT = Op.getValueType();
2273 if (VT.isVector())
2274 return false;
2275
2276 unsigned Size = VT.getSizeInBits();
2277 assert((Size == 32 || Size == 64) &&
2278 "i32 or i64 is expected after legalization.");
2279
2280 // Exit early if we demand all bits.
2281 if (DemandedBits.popcount() == Size)
2282 return false;
2283
2284 unsigned NewOpc;
2285 switch (Op.getOpcode()) {
2286 default:
2287 return false;
2288 case ISD::AND:
2289 NewOpc = Size == 32 ? AArch64::ANDWri : AArch64::ANDXri;
2290 break;
2291 case ISD::OR:
2292 NewOpc = Size == 32 ? AArch64::ORRWri : AArch64::ORRXri;
2293 break;
2294 case ISD::XOR:
2295 NewOpc = Size == 32 ? AArch64::EORWri : AArch64::EORXri;
2296 break;
2297 }
2298 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op.getOperand(1));
2299 if (!C)
2300 return false;
2301 uint64_t Imm = C->getZExtValue();
2302 return optimizeLogicalImm(Op, Size, Imm, DemandedBits, TLO, NewOpc);
2303 }
2304
2305 /// computeKnownBitsForTargetNode - Determine which of the bits specified in
2306 /// Mask are known to be either zero or one and return them in Known.
2307 void AArch64TargetLowering::computeKnownBitsForTargetNode(
2308 const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
2309 const SelectionDAG &DAG, unsigned Depth) const {
2310 switch (Op.getOpcode()) {
2311 default:
2312 break;
2313 case AArch64ISD::DUP: {
2314 SDValue SrcOp = Op.getOperand(0);
2315 Known = DAG.computeKnownBits(SrcOp, Depth + 1);
2316 if (SrcOp.getValueSizeInBits() != Op.getScalarValueSizeInBits()) {
2317 assert(SrcOp.getValueSizeInBits() > Op.getScalarValueSizeInBits() &&
2318 "Expected DUP implicit truncation");
2319 Known = Known.trunc(Op.getScalarValueSizeInBits());
2320 }
2321 break;
2322 }
2323 case AArch64ISD::CSEL: {
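    // CSEL selects between its first two operands, so only bits known in both
    // are known in the result.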
2324 KnownBits Known2;
2325 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2326 Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2327 Known = Known.intersectWith(Known2);
2328 break;
2329 }
2330 case AArch64ISD::BICi: {
2331 // Compute the bit cleared value.
2332 uint64_t Mask =
2333 ~(Op->getConstantOperandVal(1) << Op->getConstantOperandVal(2));
2334 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2335 Known &= KnownBits::makeConstant(APInt(Known.getBitWidth(), Mask));
2336 break;
2337 }
2338 case AArch64ISD::VLSHR: {
2339 KnownBits Known2;
2340 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2341 Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2342 Known = KnownBits::lshr(Known, Known2);
2343 break;
2344 }
2345 case AArch64ISD::VASHR: {
2346 KnownBits Known2;
2347 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2348 Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2349 Known = KnownBits::ashr(Known, Known2);
2350 break;
2351 }
2352 case AArch64ISD::VSHL: {
2353 KnownBits Known2;
2354 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2355 Known2 = DAG.computeKnownBits(Op->getOperand(1), Depth + 1);
2356 Known = KnownBits::shl(Known, Known2);
2357 break;
2358 }
2359 case AArch64ISD::MOVI: {
2360 Known = KnownBits::makeConstant(
2361 APInt(Known.getBitWidth(), Op->getConstantOperandVal(0)));
2362 break;
2363 }
2364 case AArch64ISD::LOADgot:
2365 case AArch64ISD::ADDlow: {
2366 if (!Subtarget->isTargetILP32())
2367 break;
2368 // In ILP32 mode all valid pointers are in the low 4GB of the address-space.
2369 Known.Zero = APInt::getHighBitsSet(64, 32);
2370 break;
2371 }
2372 case AArch64ISD::ASSERT_ZEXT_BOOL: {
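    // The operand is asserted to be a zero-extended boolean, so at least bits
    // [7:1] of the result are known to be zero.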
2373 Known = DAG.computeKnownBits(Op->getOperand(0), Depth + 1);
2374 Known.Zero |= APInt(Known.getBitWidth(), 0xFE);
2375 break;
2376 }
2377 case ISD::INTRINSIC_W_CHAIN: {
2378 Intrinsic::ID IntID =
2379 static_cast<Intrinsic::ID>(Op->getConstantOperandVal(1));
2380 switch (IntID) {
2381 default: return;
2382 case Intrinsic::aarch64_ldaxr:
2383 case Intrinsic::aarch64_ldxr: {
2384 unsigned BitWidth = Known.getBitWidth();
2385 EVT VT = cast<MemIntrinsicSDNode>(Op)->getMemoryVT();
2386 unsigned MemBits = VT.getScalarSizeInBits();
2387 Known.Zero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits);
2388 return;
2389 }
2390 }
2391 break;
2392 }
2393 case ISD::INTRINSIC_WO_CHAIN:
2394 case ISD::INTRINSIC_VOID: {
2395 unsigned IntNo = Op.getConstantOperandVal(0);
2396 switch (IntNo) {
2397 default:
2398 break;
2399 case Intrinsic::aarch64_neon_uaddlv: {
2400 MVT VT = Op.getOperand(1).getValueType().getSimpleVT();
2401 unsigned BitWidth = Known.getBitWidth();
2402 if (VT == MVT::v8i8 || VT == MVT::v16i8) {
2403 unsigned Bound = (VT == MVT::v8i8) ? 11 : 12;
2404 assert(BitWidth >= Bound && "Unexpected width!");
2405 APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - Bound);
2406 Known.Zero |= Mask;
2407 }
2408 break;
2409 }
2410 case Intrinsic::aarch64_neon_umaxv:
2411 case Intrinsic::aarch64_neon_uminv: {
2412 // Figure out the datatype of the vector operand. The UMINV instruction
2413 // will zero extend the result, so we can mark as known zero all the
2414       // bits larger than the element datatype. 32-bit or larger doesn't need
2415 // this as those are legal types and will be handled by isel directly.
2416 MVT VT = Op.getOperand(1).getValueType().getSimpleVT();
2417 unsigned BitWidth = Known.getBitWidth();
2418 if (VT == MVT::v8i8 || VT == MVT::v16i8) {
2419 assert(BitWidth >= 8 && "Unexpected width!");
2420 APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 8);
2421 Known.Zero |= Mask;
2422 } else if (VT == MVT::v4i16 || VT == MVT::v8i16) {
2423 assert(BitWidth >= 16 && "Unexpected width!");
2424 APInt Mask = APInt::getHighBitsSet(BitWidth, BitWidth - 16);
2425 Known.Zero |= Mask;
2426 }
2427 break;
2428     }
2429 }
2430 }
2431 }
2432 }
2433
2434 unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
2435 SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
2436 unsigned Depth) const {
2437 EVT VT = Op.getValueType();
2438 unsigned VTBits = VT.getScalarSizeInBits();
2439 unsigned Opcode = Op.getOpcode();
2440 switch (Opcode) {
2441 case AArch64ISD::CMEQ:
2442 case AArch64ISD::CMGE:
2443 case AArch64ISD::CMGT:
2444 case AArch64ISD::CMHI:
2445 case AArch64ISD::CMHS:
2446 case AArch64ISD::FCMEQ:
2447 case AArch64ISD::FCMGE:
2448 case AArch64ISD::FCMGT:
2449 case AArch64ISD::CMEQz:
2450 case AArch64ISD::CMGEz:
2451 case AArch64ISD::CMGTz:
2452 case AArch64ISD::CMLEz:
2453 case AArch64ISD::CMLTz:
2454 case AArch64ISD::FCMEQz:
2455 case AArch64ISD::FCMGEz:
2456 case AArch64ISD::FCMGTz:
2457 case AArch64ISD::FCMLEz:
2458 case AArch64ISD::FCMLTz:
2459 // Compares return either 0 or all-ones
2460 return VTBits;
2461 }
2462
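  // Conservative default: only the sign bit itself is known to be a sign bit.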
2463 return 1;
2464 }
2465
2466 MVT AArch64TargetLowering::getScalarShiftAmountTy(const DataLayout &DL,
2467 EVT) const {
2468 return MVT::i64;
2469 }
2470
2471 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
2472 EVT VT, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
2473 unsigned *Fast) const {
2474 if (Subtarget->requiresStrictAlign())
2475 return false;
2476
2477 if (Fast) {
2478 // Some CPUs are fine with unaligned stores except for 128-bit ones.
2479 *Fast = !Subtarget->isMisaligned128StoreSlow() || VT.getStoreSize() != 16 ||
2480 // See comments in performSTORECombine() for more details about
2481 // these conditions.
2482
2483 // Code that uses clang vector extensions can mark that it
2484 // wants unaligned accesses to be treated as fast by
2485 // underspecifying alignment to be 1 or 2.
2486 Alignment <= 2 ||
2487
2488 // Disregard v2i64. Memcpy lowering produces those and splitting
2489 // them regresses performance on micro-benchmarks and olden/bh.
2490 VT == MVT::v2i64;
2491 }
2492 return true;
2493 }
2494
2495 // Same as above but handling LLTs instead.
2496 bool AArch64TargetLowering::allowsMisalignedMemoryAccesses(
2497 LLT Ty, unsigned AddrSpace, Align Alignment, MachineMemOperand::Flags Flags,
2498 unsigned *Fast) const {
2499 if (Subtarget->requiresStrictAlign())
2500 return false;
2501
2502 if (Fast) {
2503 // Some CPUs are fine with unaligned stores except for 128-bit ones.
2504 *Fast = !Subtarget->isMisaligned128StoreSlow() ||
2505 Ty.getSizeInBytes() != 16 ||
2506 // See comments in performSTORECombine() for more details about
2507 // these conditions.
2508
2509 // Code that uses clang vector extensions can mark that it
2510 // wants unaligned accesses to be treated as fast by
2511 // underspecifying alignment to be 1 or 2.
2512 Alignment <= 2 ||
2513
2514 // Disregard v2i64. Memcpy lowering produces those and splitting
2515 // them regresses performance on micro-benchmarks and olden/bh.
2516 Ty == LLT::fixed_vector(2, 64);
2517 }
2518 return true;
2519 }
2520
2521 FastISel *
2522 AArch64TargetLowering::createFastISel(FunctionLoweringInfo &funcInfo,
2523 const TargetLibraryInfo *libInfo) const {
2524 return AArch64::createFastISel(funcInfo, libInfo);
2525 }
2526
2527 const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
2528 #define MAKE_CASE(V) \
2529 case V: \
2530 return #V;
2531 switch ((AArch64ISD::NodeType)Opcode) {
2532 case AArch64ISD::FIRST_NUMBER:
2533 break;
2534 MAKE_CASE(AArch64ISD::ALLOCATE_ZA_BUFFER)
2535 MAKE_CASE(AArch64ISD::INIT_TPIDR2OBJ)
2536 MAKE_CASE(AArch64ISD::COALESCER_BARRIER)
2537 MAKE_CASE(AArch64ISD::VG_SAVE)
2538 MAKE_CASE(AArch64ISD::VG_RESTORE)
2539 MAKE_CASE(AArch64ISD::SMSTART)
2540 MAKE_CASE(AArch64ISD::SMSTOP)
2541 MAKE_CASE(AArch64ISD::RESTORE_ZA)
2542 MAKE_CASE(AArch64ISD::RESTORE_ZT)
2543 MAKE_CASE(AArch64ISD::SAVE_ZT)
2544 MAKE_CASE(AArch64ISD::CALL)
2545 MAKE_CASE(AArch64ISD::ADRP)
2546 MAKE_CASE(AArch64ISD::ADR)
2547 MAKE_CASE(AArch64ISD::ADDlow)
2548 MAKE_CASE(AArch64ISD::AUTH_CALL)
2549 MAKE_CASE(AArch64ISD::AUTH_TC_RETURN)
2550 MAKE_CASE(AArch64ISD::AUTH_CALL_RVMARKER)
2551 MAKE_CASE(AArch64ISD::LOADgot)
2552 MAKE_CASE(AArch64ISD::RET_GLUE)
2553 MAKE_CASE(AArch64ISD::BRCOND)
2554 MAKE_CASE(AArch64ISD::CSEL)
2555 MAKE_CASE(AArch64ISD::CSINV)
2556 MAKE_CASE(AArch64ISD::CSNEG)
2557 MAKE_CASE(AArch64ISD::CSINC)
2558 MAKE_CASE(AArch64ISD::THREAD_POINTER)
2559 MAKE_CASE(AArch64ISD::TLSDESC_CALLSEQ)
2560 MAKE_CASE(AArch64ISD::PROBED_ALLOCA)
2561 MAKE_CASE(AArch64ISD::ABDS_PRED)
2562 MAKE_CASE(AArch64ISD::ABDU_PRED)
2563 MAKE_CASE(AArch64ISD::HADDS_PRED)
2564 MAKE_CASE(AArch64ISD::HADDU_PRED)
2565 MAKE_CASE(AArch64ISD::MUL_PRED)
2566 MAKE_CASE(AArch64ISD::MULHS_PRED)
2567 MAKE_CASE(AArch64ISD::MULHU_PRED)
2568 MAKE_CASE(AArch64ISD::RHADDS_PRED)
2569 MAKE_CASE(AArch64ISD::RHADDU_PRED)
2570 MAKE_CASE(AArch64ISD::SDIV_PRED)
2571 MAKE_CASE(AArch64ISD::SHL_PRED)
2572 MAKE_CASE(AArch64ISD::SMAX_PRED)
2573 MAKE_CASE(AArch64ISD::SMIN_PRED)
2574 MAKE_CASE(AArch64ISD::SRA_PRED)
2575 MAKE_CASE(AArch64ISD::SRL_PRED)
2576 MAKE_CASE(AArch64ISD::UDIV_PRED)
2577 MAKE_CASE(AArch64ISD::UMAX_PRED)
2578 MAKE_CASE(AArch64ISD::UMIN_PRED)
2579 MAKE_CASE(AArch64ISD::SRAD_MERGE_OP1)
2580 MAKE_CASE(AArch64ISD::FNEG_MERGE_PASSTHRU)
2581 MAKE_CASE(AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU)
2582 MAKE_CASE(AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU)
2583 MAKE_CASE(AArch64ISD::FCEIL_MERGE_PASSTHRU)
2584 MAKE_CASE(AArch64ISD::FFLOOR_MERGE_PASSTHRU)
2585 MAKE_CASE(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU)
2586 MAKE_CASE(AArch64ISD::FRINT_MERGE_PASSTHRU)
2587 MAKE_CASE(AArch64ISD::FROUND_MERGE_PASSTHRU)
2588 MAKE_CASE(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU)
2589 MAKE_CASE(AArch64ISD::FTRUNC_MERGE_PASSTHRU)
2590 MAKE_CASE(AArch64ISD::FP_ROUND_MERGE_PASSTHRU)
2591 MAKE_CASE(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU)
2592 MAKE_CASE(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU)
2593 MAKE_CASE(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU)
2594 MAKE_CASE(AArch64ISD::FCVTZU_MERGE_PASSTHRU)
2595 MAKE_CASE(AArch64ISD::FCVTZS_MERGE_PASSTHRU)
2596 MAKE_CASE(AArch64ISD::FSQRT_MERGE_PASSTHRU)
2597 MAKE_CASE(AArch64ISD::FRECPX_MERGE_PASSTHRU)
2598 MAKE_CASE(AArch64ISD::FABS_MERGE_PASSTHRU)
2599 MAKE_CASE(AArch64ISD::ABS_MERGE_PASSTHRU)
2600 MAKE_CASE(AArch64ISD::NEG_MERGE_PASSTHRU)
2601 MAKE_CASE(AArch64ISD::SETCC_MERGE_ZERO)
2602 MAKE_CASE(AArch64ISD::ADC)
2603 MAKE_CASE(AArch64ISD::SBC)
2604 MAKE_CASE(AArch64ISD::ADDS)
2605 MAKE_CASE(AArch64ISD::SUBS)
2606 MAKE_CASE(AArch64ISD::ADCS)
2607 MAKE_CASE(AArch64ISD::SBCS)
2608 MAKE_CASE(AArch64ISD::ANDS)
2609 MAKE_CASE(AArch64ISD::CCMP)
2610 MAKE_CASE(AArch64ISD::CCMN)
2611 MAKE_CASE(AArch64ISD::FCCMP)
2612 MAKE_CASE(AArch64ISD::FCMP)
2613 MAKE_CASE(AArch64ISD::STRICT_FCMP)
2614 MAKE_CASE(AArch64ISD::STRICT_FCMPE)
2615 MAKE_CASE(AArch64ISD::FCVTXN)
2616 MAKE_CASE(AArch64ISD::SME_ZA_LDR)
2617 MAKE_CASE(AArch64ISD::SME_ZA_STR)
2618 MAKE_CASE(AArch64ISD::DUP)
2619 MAKE_CASE(AArch64ISD::DUPLANE8)
2620 MAKE_CASE(AArch64ISD::DUPLANE16)
2621 MAKE_CASE(AArch64ISD::DUPLANE32)
2622 MAKE_CASE(AArch64ISD::DUPLANE64)
2623 MAKE_CASE(AArch64ISD::DUPLANE128)
2624 MAKE_CASE(AArch64ISD::MOVI)
2625 MAKE_CASE(AArch64ISD::MOVIshift)
2626 MAKE_CASE(AArch64ISD::MOVIedit)
2627 MAKE_CASE(AArch64ISD::MOVImsl)
2628 MAKE_CASE(AArch64ISD::FMOV)
2629 MAKE_CASE(AArch64ISD::MVNIshift)
2630 MAKE_CASE(AArch64ISD::MVNImsl)
2631 MAKE_CASE(AArch64ISD::BICi)
2632 MAKE_CASE(AArch64ISD::ORRi)
2633 MAKE_CASE(AArch64ISD::BSP)
2634 MAKE_CASE(AArch64ISD::ZIP1)
2635 MAKE_CASE(AArch64ISD::ZIP2)
2636 MAKE_CASE(AArch64ISD::UZP1)
2637 MAKE_CASE(AArch64ISD::UZP2)
2638 MAKE_CASE(AArch64ISD::TRN1)
2639 MAKE_CASE(AArch64ISD::TRN2)
2640 MAKE_CASE(AArch64ISD::REV16)
2641 MAKE_CASE(AArch64ISD::REV32)
2642 MAKE_CASE(AArch64ISD::REV64)
2643 MAKE_CASE(AArch64ISD::EXT)
2644 MAKE_CASE(AArch64ISD::SPLICE)
2645 MAKE_CASE(AArch64ISD::VSHL)
2646 MAKE_CASE(AArch64ISD::VLSHR)
2647 MAKE_CASE(AArch64ISD::VASHR)
2648 MAKE_CASE(AArch64ISD::VSLI)
2649 MAKE_CASE(AArch64ISD::VSRI)
2650 MAKE_CASE(AArch64ISD::CMEQ)
2651 MAKE_CASE(AArch64ISD::CMGE)
2652 MAKE_CASE(AArch64ISD::CMGT)
2653 MAKE_CASE(AArch64ISD::CMHI)
2654 MAKE_CASE(AArch64ISD::CMHS)
2655 MAKE_CASE(AArch64ISD::FCMEQ)
2656 MAKE_CASE(AArch64ISD::FCMGE)
2657 MAKE_CASE(AArch64ISD::FCMGT)
2658 MAKE_CASE(AArch64ISD::CMEQz)
2659 MAKE_CASE(AArch64ISD::CMGEz)
2660 MAKE_CASE(AArch64ISD::CMGTz)
2661 MAKE_CASE(AArch64ISD::CMLEz)
2662 MAKE_CASE(AArch64ISD::CMLTz)
2663 MAKE_CASE(AArch64ISD::FCMEQz)
2664 MAKE_CASE(AArch64ISD::FCMGEz)
2665 MAKE_CASE(AArch64ISD::FCMGTz)
2666 MAKE_CASE(AArch64ISD::FCMLEz)
2667 MAKE_CASE(AArch64ISD::FCMLTz)
2668 MAKE_CASE(AArch64ISD::SADDV)
2669 MAKE_CASE(AArch64ISD::UADDV)
2670 MAKE_CASE(AArch64ISD::UADDLV)
2671 MAKE_CASE(AArch64ISD::SADDLV)
2672 MAKE_CASE(AArch64ISD::SDOT)
2673 MAKE_CASE(AArch64ISD::UDOT)
2674 MAKE_CASE(AArch64ISD::SMINV)
2675 MAKE_CASE(AArch64ISD::UMINV)
2676 MAKE_CASE(AArch64ISD::SMAXV)
2677 MAKE_CASE(AArch64ISD::UMAXV)
2678 MAKE_CASE(AArch64ISD::SADDV_PRED)
2679 MAKE_CASE(AArch64ISD::UADDV_PRED)
2680 MAKE_CASE(AArch64ISD::SMAXV_PRED)
2681 MAKE_CASE(AArch64ISD::UMAXV_PRED)
2682 MAKE_CASE(AArch64ISD::SMINV_PRED)
2683 MAKE_CASE(AArch64ISD::UMINV_PRED)
2684 MAKE_CASE(AArch64ISD::ORV_PRED)
2685 MAKE_CASE(AArch64ISD::EORV_PRED)
2686 MAKE_CASE(AArch64ISD::ANDV_PRED)
2687 MAKE_CASE(AArch64ISD::CLASTA_N)
2688 MAKE_CASE(AArch64ISD::CLASTB_N)
2689 MAKE_CASE(AArch64ISD::LASTA)
2690 MAKE_CASE(AArch64ISD::LASTB)
2691 MAKE_CASE(AArch64ISD::REINTERPRET_CAST)
2692 MAKE_CASE(AArch64ISD::LS64_BUILD)
2693 MAKE_CASE(AArch64ISD::LS64_EXTRACT)
2694 MAKE_CASE(AArch64ISD::TBL)
2695 MAKE_CASE(AArch64ISD::FADD_PRED)
2696 MAKE_CASE(AArch64ISD::FADDA_PRED)
2697 MAKE_CASE(AArch64ISD::FADDV_PRED)
2698 MAKE_CASE(AArch64ISD::FDIV_PRED)
2699 MAKE_CASE(AArch64ISD::FMA_PRED)
2700 MAKE_CASE(AArch64ISD::FMAX_PRED)
2701 MAKE_CASE(AArch64ISD::FMAXV_PRED)
2702 MAKE_CASE(AArch64ISD::FMAXNM_PRED)
2703 MAKE_CASE(AArch64ISD::FMAXNMV_PRED)
2704 MAKE_CASE(AArch64ISD::FMIN_PRED)
2705 MAKE_CASE(AArch64ISD::FMINV_PRED)
2706 MAKE_CASE(AArch64ISD::FMINNM_PRED)
2707 MAKE_CASE(AArch64ISD::FMINNMV_PRED)
2708 MAKE_CASE(AArch64ISD::FMUL_PRED)
2709 MAKE_CASE(AArch64ISD::FSUB_PRED)
2710 MAKE_CASE(AArch64ISD::RDSVL)
2711 MAKE_CASE(AArch64ISD::BIC)
2712 MAKE_CASE(AArch64ISD::CBZ)
2713 MAKE_CASE(AArch64ISD::CBNZ)
2714 MAKE_CASE(AArch64ISD::TBZ)
2715 MAKE_CASE(AArch64ISD::TBNZ)
2716 MAKE_CASE(AArch64ISD::TC_RETURN)
2717 MAKE_CASE(AArch64ISD::PREFETCH)
2718 MAKE_CASE(AArch64ISD::SITOF)
2719 MAKE_CASE(AArch64ISD::UITOF)
2720 MAKE_CASE(AArch64ISD::NVCAST)
2721 MAKE_CASE(AArch64ISD::MRS)
2722 MAKE_CASE(AArch64ISD::SQSHL_I)
2723 MAKE_CASE(AArch64ISD::UQSHL_I)
2724 MAKE_CASE(AArch64ISD::SRSHR_I)
2725 MAKE_CASE(AArch64ISD::URSHR_I)
2726 MAKE_CASE(AArch64ISD::SQSHLU_I)
2727 MAKE_CASE(AArch64ISD::WrapperLarge)
2728 MAKE_CASE(AArch64ISD::LD2post)
2729 MAKE_CASE(AArch64ISD::LD3post)
2730 MAKE_CASE(AArch64ISD::LD4post)
2731 MAKE_CASE(AArch64ISD::ST2post)
2732 MAKE_CASE(AArch64ISD::ST3post)
2733 MAKE_CASE(AArch64ISD::ST4post)
2734 MAKE_CASE(AArch64ISD::LD1x2post)
2735 MAKE_CASE(AArch64ISD::LD1x3post)
2736 MAKE_CASE(AArch64ISD::LD1x4post)
2737 MAKE_CASE(AArch64ISD::ST1x2post)
2738 MAKE_CASE(AArch64ISD::ST1x3post)
2739 MAKE_CASE(AArch64ISD::ST1x4post)
2740 MAKE_CASE(AArch64ISD::LD1DUPpost)
2741 MAKE_CASE(AArch64ISD::LD2DUPpost)
2742 MAKE_CASE(AArch64ISD::LD3DUPpost)
2743 MAKE_CASE(AArch64ISD::LD4DUPpost)
2744 MAKE_CASE(AArch64ISD::LD1LANEpost)
2745 MAKE_CASE(AArch64ISD::LD2LANEpost)
2746 MAKE_CASE(AArch64ISD::LD3LANEpost)
2747 MAKE_CASE(AArch64ISD::LD4LANEpost)
2748 MAKE_CASE(AArch64ISD::ST2LANEpost)
2749 MAKE_CASE(AArch64ISD::ST3LANEpost)
2750 MAKE_CASE(AArch64ISD::ST4LANEpost)
2751 MAKE_CASE(AArch64ISD::SMULL)
2752 MAKE_CASE(AArch64ISD::UMULL)
2753 MAKE_CASE(AArch64ISD::PMULL)
2754 MAKE_CASE(AArch64ISD::FRECPE)
2755 MAKE_CASE(AArch64ISD::FRECPS)
2756 MAKE_CASE(AArch64ISD::FRSQRTE)
2757 MAKE_CASE(AArch64ISD::FRSQRTS)
2758 MAKE_CASE(AArch64ISD::STG)
2759 MAKE_CASE(AArch64ISD::STZG)
2760 MAKE_CASE(AArch64ISD::ST2G)
2761 MAKE_CASE(AArch64ISD::STZ2G)
2762 MAKE_CASE(AArch64ISD::SUNPKHI)
2763 MAKE_CASE(AArch64ISD::SUNPKLO)
2764 MAKE_CASE(AArch64ISD::UUNPKHI)
2765 MAKE_CASE(AArch64ISD::UUNPKLO)
2766 MAKE_CASE(AArch64ISD::INSR)
2767 MAKE_CASE(AArch64ISD::PTEST)
2768 MAKE_CASE(AArch64ISD::PTEST_ANY)
2769 MAKE_CASE(AArch64ISD::PTRUE)
2770 MAKE_CASE(AArch64ISD::LD1_MERGE_ZERO)
2771 MAKE_CASE(AArch64ISD::LD1S_MERGE_ZERO)
2772 MAKE_CASE(AArch64ISD::LDNF1_MERGE_ZERO)
2773 MAKE_CASE(AArch64ISD::LDNF1S_MERGE_ZERO)
2774 MAKE_CASE(AArch64ISD::LDFF1_MERGE_ZERO)
2775 MAKE_CASE(AArch64ISD::LDFF1S_MERGE_ZERO)
2776 MAKE_CASE(AArch64ISD::LD1RQ_MERGE_ZERO)
2777 MAKE_CASE(AArch64ISD::LD1RO_MERGE_ZERO)
2778 MAKE_CASE(AArch64ISD::SVE_LD2_MERGE_ZERO)
2779 MAKE_CASE(AArch64ISD::SVE_LD3_MERGE_ZERO)
2780 MAKE_CASE(AArch64ISD::SVE_LD4_MERGE_ZERO)
2781 MAKE_CASE(AArch64ISD::GLD1_MERGE_ZERO)
2782 MAKE_CASE(AArch64ISD::GLD1_SCALED_MERGE_ZERO)
2783 MAKE_CASE(AArch64ISD::GLD1_SXTW_MERGE_ZERO)
2784 MAKE_CASE(AArch64ISD::GLD1_UXTW_MERGE_ZERO)
2785 MAKE_CASE(AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO)
2786 MAKE_CASE(AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO)
2787 MAKE_CASE(AArch64ISD::GLD1_IMM_MERGE_ZERO)
2788 MAKE_CASE(AArch64ISD::GLD1Q_MERGE_ZERO)
2789 MAKE_CASE(AArch64ISD::GLD1Q_INDEX_MERGE_ZERO)
2790 MAKE_CASE(AArch64ISD::GLD1S_MERGE_ZERO)
2791 MAKE_CASE(AArch64ISD::GLD1S_SCALED_MERGE_ZERO)
2792 MAKE_CASE(AArch64ISD::GLD1S_SXTW_MERGE_ZERO)
2793 MAKE_CASE(AArch64ISD::GLD1S_UXTW_MERGE_ZERO)
2794 MAKE_CASE(AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO)
2795 MAKE_CASE(AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO)
2796 MAKE_CASE(AArch64ISD::GLD1S_IMM_MERGE_ZERO)
2797 MAKE_CASE(AArch64ISD::GLDFF1_MERGE_ZERO)
2798 MAKE_CASE(AArch64ISD::GLDFF1_SCALED_MERGE_ZERO)
2799 MAKE_CASE(AArch64ISD::GLDFF1_SXTW_MERGE_ZERO)
2800 MAKE_CASE(AArch64ISD::GLDFF1_UXTW_MERGE_ZERO)
2801 MAKE_CASE(AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO)
2802 MAKE_CASE(AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO)
2803 MAKE_CASE(AArch64ISD::GLDFF1_IMM_MERGE_ZERO)
2804 MAKE_CASE(AArch64ISD::GLDFF1S_MERGE_ZERO)
2805 MAKE_CASE(AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO)
2806 MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO)
2807 MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO)
2808 MAKE_CASE(AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO)
2809 MAKE_CASE(AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO)
2810 MAKE_CASE(AArch64ISD::GLDFF1S_IMM_MERGE_ZERO)
2811 MAKE_CASE(AArch64ISD::GLDNT1_MERGE_ZERO)
2812 MAKE_CASE(AArch64ISD::GLDNT1_INDEX_MERGE_ZERO)
2813 MAKE_CASE(AArch64ISD::GLDNT1S_MERGE_ZERO)
2814 MAKE_CASE(AArch64ISD::SST1Q_PRED)
2815 MAKE_CASE(AArch64ISD::SST1Q_INDEX_PRED)
2816 MAKE_CASE(AArch64ISD::ST1_PRED)
2817 MAKE_CASE(AArch64ISD::SST1_PRED)
2818 MAKE_CASE(AArch64ISD::SST1_SCALED_PRED)
2819 MAKE_CASE(AArch64ISD::SST1_SXTW_PRED)
2820 MAKE_CASE(AArch64ISD::SST1_UXTW_PRED)
2821 MAKE_CASE(AArch64ISD::SST1_SXTW_SCALED_PRED)
2822 MAKE_CASE(AArch64ISD::SST1_UXTW_SCALED_PRED)
2823 MAKE_CASE(AArch64ISD::SST1_IMM_PRED)
2824 MAKE_CASE(AArch64ISD::SSTNT1_PRED)
2825 MAKE_CASE(AArch64ISD::SSTNT1_INDEX_PRED)
2826 MAKE_CASE(AArch64ISD::LDP)
2827 MAKE_CASE(AArch64ISD::LDIAPP)
2828 MAKE_CASE(AArch64ISD::LDNP)
2829 MAKE_CASE(AArch64ISD::STP)
2830 MAKE_CASE(AArch64ISD::STILP)
2831 MAKE_CASE(AArch64ISD::STNP)
2832 MAKE_CASE(AArch64ISD::BITREVERSE_MERGE_PASSTHRU)
2833 MAKE_CASE(AArch64ISD::BSWAP_MERGE_PASSTHRU)
2834 MAKE_CASE(AArch64ISD::REVH_MERGE_PASSTHRU)
2835 MAKE_CASE(AArch64ISD::REVW_MERGE_PASSTHRU)
2836 MAKE_CASE(AArch64ISD::REVD_MERGE_PASSTHRU)
2837 MAKE_CASE(AArch64ISD::CTLZ_MERGE_PASSTHRU)
2838 MAKE_CASE(AArch64ISD::CTPOP_MERGE_PASSTHRU)
2839 MAKE_CASE(AArch64ISD::DUP_MERGE_PASSTHRU)
2840 MAKE_CASE(AArch64ISD::INDEX_VECTOR)
2841 MAKE_CASE(AArch64ISD::ADDP)
2842 MAKE_CASE(AArch64ISD::SADDLP)
2843 MAKE_CASE(AArch64ISD::UADDLP)
2844 MAKE_CASE(AArch64ISD::CALL_RVMARKER)
2845 MAKE_CASE(AArch64ISD::ASSERT_ZEXT_BOOL)
2846 MAKE_CASE(AArch64ISD::MOPS_MEMSET)
2847 MAKE_CASE(AArch64ISD::MOPS_MEMSET_TAGGING)
2848 MAKE_CASE(AArch64ISD::MOPS_MEMCOPY)
2849 MAKE_CASE(AArch64ISD::MOPS_MEMMOVE)
2850 MAKE_CASE(AArch64ISD::CALL_BTI)
2851 MAKE_CASE(AArch64ISD::MRRS)
2852 MAKE_CASE(AArch64ISD::MSRR)
2853 MAKE_CASE(AArch64ISD::RSHRNB_I)
2854 MAKE_CASE(AArch64ISD::CTTZ_ELTS)
2855 MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
2856 MAKE_CASE(AArch64ISD::URSHR_I_PRED)
2857 }
2858 #undef MAKE_CASE
2859 return nullptr;
2860 }
2861
2862 MachineBasicBlock *
2863 AArch64TargetLowering::EmitF128CSEL(MachineInstr &MI,
2864 MachineBasicBlock *MBB) const {
2865 // We materialise the F128CSEL pseudo-instruction as some control flow and a
2866 // phi node:
2867
2868 // OrigBB:
2869 // [... previous instrs leading to comparison ...]
2870 // b.ne TrueBB
2871 // b EndBB
2872 // TrueBB:
2873 // ; Fallthrough
2874 // EndBB:
2875 // Dest = PHI [IfTrue, TrueBB], [IfFalse, OrigBB]
2876
2877 MachineFunction *MF = MBB->getParent();
2878 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2879 const BasicBlock *LLVM_BB = MBB->getBasicBlock();
2880 DebugLoc DL = MI.getDebugLoc();
2881 MachineFunction::iterator It = ++MBB->getIterator();
2882
2883 Register DestReg = MI.getOperand(0).getReg();
2884 Register IfTrueReg = MI.getOperand(1).getReg();
2885 Register IfFalseReg = MI.getOperand(2).getReg();
2886 unsigned CondCode = MI.getOperand(3).getImm();
2887 bool NZCVKilled = MI.getOperand(4).isKill();
2888
2889 MachineBasicBlock *TrueBB = MF->CreateMachineBasicBlock(LLVM_BB);
2890 MachineBasicBlock *EndBB = MF->CreateMachineBasicBlock(LLVM_BB);
2891 MF->insert(It, TrueBB);
2892 MF->insert(It, EndBB);
2893
2894 // Transfer the rest of the current basic block to EndBB.
2895 EndBB->splice(EndBB->begin(), MBB, std::next(MachineBasicBlock::iterator(MI)),
2896 MBB->end());
2897 EndBB->transferSuccessorsAndUpdatePHIs(MBB);
2898
2899 BuildMI(MBB, DL, TII->get(AArch64::Bcc)).addImm(CondCode).addMBB(TrueBB);
2900 BuildMI(MBB, DL, TII->get(AArch64::B)).addMBB(EndBB);
2901 MBB->addSuccessor(TrueBB);
2902 MBB->addSuccessor(EndBB);
2903
2904 // TrueBB falls through to the end.
2905 TrueBB->addSuccessor(EndBB);
2906
2907 if (!NZCVKilled) {
2908 TrueBB->addLiveIn(AArch64::NZCV);
2909 EndBB->addLiveIn(AArch64::NZCV);
2910 }
2911
2912 BuildMI(*EndBB, EndBB->begin(), DL, TII->get(AArch64::PHI), DestReg)
2913 .addReg(IfTrueReg)
2914 .addMBB(TrueBB)
2915 .addReg(IfFalseReg)
2916 .addMBB(MBB);
2917
2918 MI.eraseFromParent();
2919 return EndBB;
2920 }
2921
2922 MachineBasicBlock *AArch64TargetLowering::EmitLoweredCatchRet(
2923 MachineInstr &MI, MachineBasicBlock *BB) const {
2924 assert(!isAsynchronousEHPersonality(classifyEHPersonality(
2925 BB->getParent()->getFunction().getPersonalityFn())) &&
2926 "SEH does not use catchret!");
2927 return BB;
2928 }
2929
2930 MachineBasicBlock *
2931 AArch64TargetLowering::EmitDynamicProbedAlloc(MachineInstr &MI,
2932 MachineBasicBlock *MBB) const {
2933 MachineFunction &MF = *MBB->getParent();
2934 MachineBasicBlock::iterator MBBI = MI.getIterator();
2935 DebugLoc DL = MBB->findDebugLoc(MBBI);
2936 const AArch64InstrInfo &TII =
2937 *MF.getSubtarget<AArch64Subtarget>().getInstrInfo();
2938 Register TargetReg = MI.getOperand(0).getReg();
2939 MachineBasicBlock::iterator NextInst =
2940 TII.probedStackAlloc(MBBI, TargetReg, false);
2941
2942 MI.eraseFromParent();
2943 return NextInst->getParent();
2944 }
2945
2946 MachineBasicBlock *
2947 AArch64TargetLowering::EmitTileLoad(unsigned Opc, unsigned BaseReg,
2948 MachineInstr &MI,
2949 MachineBasicBlock *BB) const {
2950 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2951 MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
2952
2953 MIB.addReg(BaseReg + MI.getOperand(0).getImm(), RegState::Define);
2954 MIB.add(MI.getOperand(1)); // slice index register
2955 MIB.add(MI.getOperand(2)); // slice index offset
2956 MIB.add(MI.getOperand(3)); // pg
2957 MIB.add(MI.getOperand(4)); // base
2958 MIB.add(MI.getOperand(5)); // offset
2959
2960 MI.eraseFromParent(); // The pseudo is gone now.
2961 return BB;
2962 }
2963
2964 MachineBasicBlock *
2965 AArch64TargetLowering::EmitFill(MachineInstr &MI, MachineBasicBlock *BB) const {
2966 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2967 MachineInstrBuilder MIB =
2968 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::LDR_ZA));
2969
2970 MIB.addReg(AArch64::ZA, RegState::Define);
2971 MIB.add(MI.getOperand(0)); // Vector select register
2972 MIB.add(MI.getOperand(1)); // Vector select offset
2973 MIB.add(MI.getOperand(2)); // Base
2974 MIB.add(MI.getOperand(1)); // Offset, same as vector select offset
2975
2976 MI.eraseFromParent(); // The pseudo is gone now.
2977 return BB;
2978 }
2979
2980 MachineBasicBlock *AArch64TargetLowering::EmitZTInstr(MachineInstr &MI,
2981 MachineBasicBlock *BB,
2982 unsigned Opcode,
2983 bool Op0IsDef) const {
2984 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
2985 MachineInstrBuilder MIB;
2986
2987 MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opcode))
2988 .addReg(MI.getOperand(0).getReg(), Op0IsDef ? RegState::Define : 0);
2989 for (unsigned I = 1; I < MI.getNumOperands(); ++I)
2990 MIB.add(MI.getOperand(I));
2991
2992 MI.eraseFromParent(); // The pseudo is gone now.
2993 return BB;
2994 }
2995
2996 MachineBasicBlock *
2997 AArch64TargetLowering::EmitZAInstr(unsigned Opc, unsigned BaseReg,
2998 MachineInstr &MI,
2999 MachineBasicBlock *BB) const {
3000 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3001 MachineInstrBuilder MIB = BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(Opc));
3002 unsigned StartIdx = 0;
3003
3004 bool HasTile = BaseReg != AArch64::ZA;
3005 bool HasZPROut = HasTile && MI.getOperand(0).isReg();
3006 if (HasZPROut) {
3007 MIB.add(MI.getOperand(StartIdx)); // Output ZPR
3008 ++StartIdx;
3009 }
3010 if (HasTile) {
3011 MIB.addReg(BaseReg + MI.getOperand(StartIdx).getImm(),
3012 RegState::Define); // Output ZA Tile
3013 MIB.addReg(BaseReg + MI.getOperand(StartIdx).getImm()); // Input Za Tile
3014 StartIdx++;
3015 } else {
3016 // Avoids all instructions with mnemonic za.<sz>[Reg, Imm,
3017 if (MI.getOperand(0).isReg() && !MI.getOperand(1).isImm()) {
3018 MIB.add(MI.getOperand(StartIdx)); // Output ZPR
3019 ++StartIdx;
3020 }
3021 MIB.addReg(BaseReg, RegState::Define).addReg(BaseReg);
3022 }
3023 for (unsigned I = StartIdx; I < MI.getNumOperands(); ++I)
3024 MIB.add(MI.getOperand(I));
3025
3026 MI.eraseFromParent(); // The pseudo is gone now.
3027 return BB;
3028 }
3029
3030 MachineBasicBlock *
3031 AArch64TargetLowering::EmitZero(MachineInstr &MI, MachineBasicBlock *BB) const {
3032 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3033 MachineInstrBuilder MIB =
3034 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::ZERO_M));
3035 MIB.add(MI.getOperand(0)); // Mask
3036
3037 unsigned Mask = MI.getOperand(0).getImm();
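// Each bit I of the mask selects tile ZAD<I>: for example, a mask of
// 0b00000101 marks ZAD0 and ZAD2 as implicitly defined by the ZERO_M below.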
3038 for (unsigned I = 0; I < 8; I++) {
3039 if (Mask & (1 << I))
3040 MIB.addDef(AArch64::ZAD0 + I, RegState::ImplicitDefine);
3041 }
3042
3043 MI.eraseFromParent(); // The pseudo is gone now.
3044 return BB;
3045 }
3046
3047 MachineBasicBlock *
3048 AArch64TargetLowering::EmitInitTPIDR2Object(MachineInstr &MI,
3049 MachineBasicBlock *BB) const {
3050 MachineFunction *MF = BB->getParent();
3051 MachineFrameInfo &MFI = MF->getFrameInfo();
3052 AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3053 TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
3054 if (TPIDR2.Uses > 0) {
3055 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3056 // Store the buffer pointer to the TPIDR2 stack object.
3057 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRXui))
3058 .addReg(MI.getOperand(0).getReg())
3059 .addFrameIndex(TPIDR2.FrameIndex)
3060 .addImm(0);
3061 // Set the reserved bytes (10-15) to zero
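// (STRHHui and STRWui take scaled unsigned offsets, so immediate 5 writes
// bytes 10-11 and immediate 3 writes bytes 12-15 of the TPIDR2 block.)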
3062 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRHHui))
3063 .addReg(AArch64::WZR)
3064 .addFrameIndex(TPIDR2.FrameIndex)
3065 .addImm(5);
3066 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::STRWui))
3067 .addReg(AArch64::WZR)
3068 .addFrameIndex(TPIDR2.FrameIndex)
3069 .addImm(3);
3070 } else
3071 MFI.RemoveStackObject(TPIDR2.FrameIndex);
3072
3073 BB->remove_instr(&MI);
3074 return BB;
3075 }
3076
3077 MachineBasicBlock *
3078 AArch64TargetLowering::EmitAllocateZABuffer(MachineInstr &MI,
3079 MachineBasicBlock *BB) const {
3080 MachineFunction *MF = BB->getParent();
3081 MachineFrameInfo &MFI = MF->getFrameInfo();
3082 AArch64FunctionInfo *FuncInfo = MF->getInfo<AArch64FunctionInfo>();
3083 // TODO This function grows the stack with a subtraction, which doesn't work
3084 // on Windows. Some refactoring to share the functionality in
3085 // LowerWindowsDYNAMIC_STACKALLOC will be required once the Windows ABI
3086 // supports SME
3087 assert(!MF->getSubtarget<AArch64Subtarget>().isTargetWindows() &&
3088 "Lazy ZA save is not yet supported on Windows");
3089
3090 TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
3091
3092 if (TPIDR2.Uses > 0) {
3093 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3094 MachineRegisterInfo &MRI = MF->getRegInfo();
3095
3096 // The MSUBXrrr below cannot take SP directly as a source operand, so copy
3097 // SP into a virtual register first.
3098 Register SP = MRI.createVirtualRegister(&AArch64::GPR64RegClass);
3099 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY), SP)
3100 .addReg(AArch64::SP);
3101
3102 // Allocate a lazy-save buffer object of the size given, normally SVL * SVL
3103 auto Size = MI.getOperand(1).getReg();
3104 auto Dest = MI.getOperand(0).getReg();
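// MSUB computes Dest = SP - Size * Size, i.e. the stack pointer moved down
// by the number of bytes required for the lazy-save buffer.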
3105 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(AArch64::MSUBXrrr), Dest)
3106 .addReg(Size)
3107 .addReg(Size)
3108 .addReg(SP);
3109 BuildMI(*BB, MI, MI.getDebugLoc(), TII->get(TargetOpcode::COPY),
3110 AArch64::SP)
3111 .addReg(Dest);
3112
3113 // We have just allocated a variable sized object, tell this to PEI.
3114 MFI.CreateVariableSizedObject(Align(16), nullptr);
3115 }
3116
3117 BB->remove_instr(&MI);
3118 return BB;
3119 }
3120
3121 MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
3122 MachineInstr &MI, MachineBasicBlock *BB) const {
3123
3124 int SMEOrigInstr = AArch64::getSMEPseudoMap(MI.getOpcode());
3125 if (SMEOrigInstr != -1) {
3126 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
3127 uint64_t SMEMatrixType =
3128 TII->get(MI.getOpcode()).TSFlags & AArch64::SMEMatrixTypeMask;
3129 switch (SMEMatrixType) {
3130 case (AArch64::SMEMatrixArray):
3131 return EmitZAInstr(SMEOrigInstr, AArch64::ZA, MI, BB);
3132 case (AArch64::SMEMatrixTileB):
3133 return EmitZAInstr(SMEOrigInstr, AArch64::ZAB0, MI, BB);
3134 case (AArch64::SMEMatrixTileH):
3135 return EmitZAInstr(SMEOrigInstr, AArch64::ZAH0, MI, BB);
3136 case (AArch64::SMEMatrixTileS):
3137 return EmitZAInstr(SMEOrigInstr, AArch64::ZAS0, MI, BB);
3138 case (AArch64::SMEMatrixTileD):
3139 return EmitZAInstr(SMEOrigInstr, AArch64::ZAD0, MI, BB);
3140 case (AArch64::SMEMatrixTileQ):
3141 return EmitZAInstr(SMEOrigInstr, AArch64::ZAQ0, MI, BB);
3142 }
3143 }
3144
3145 switch (MI.getOpcode()) {
3146 default:
3147 #ifndef NDEBUG
3148 MI.dump();
3149 #endif
3150 llvm_unreachable("Unexpected instruction for custom inserter!");
3151 case AArch64::InitTPIDR2Obj:
3152 return EmitInitTPIDR2Object(MI, BB);
3153 case AArch64::AllocateZABuffer:
3154 return EmitAllocateZABuffer(MI, BB);
3155 case AArch64::F128CSEL:
3156 return EmitF128CSEL(MI, BB);
3157 case TargetOpcode::STATEPOINT:
3158 // STATEPOINT is a pseudo instruction which has no implicit defs/uses
3159 // while the BL call instruction (to which the statepoint is lowered in the
3160 // end) has an implicit def. This def is early-clobber as it will be set at
3161 // the moment of the call, earlier than any use is read.
3162 // Add this implicit dead def here as a workaround.
3163 MI.addOperand(*MI.getMF(),
3164 MachineOperand::CreateReg(
3165 AArch64::LR, /*isDef*/ true,
3166 /*isImp*/ true, /*isKill*/ false, /*isDead*/ true,
3167 /*isUndef*/ false, /*isEarlyClobber*/ true));
3168 [[fallthrough]];
3169 case TargetOpcode::STACKMAP:
3170 case TargetOpcode::PATCHPOINT:
3171 return emitPatchPoint(MI, BB);
3172
3173 case TargetOpcode::PATCHABLE_EVENT_CALL:
3174 case TargetOpcode::PATCHABLE_TYPED_EVENT_CALL:
3175 return BB;
3176
3177 case AArch64::CATCHRET:
3178 return EmitLoweredCatchRet(MI, BB);
3179
3180 case AArch64::PROBED_STACKALLOC_DYN:
3181 return EmitDynamicProbedAlloc(MI, BB);
3182
3183 case AArch64::LD1_MXIPXX_H_PSEUDO_B:
3184 return EmitTileLoad(AArch64::LD1_MXIPXX_H_B, AArch64::ZAB0, MI, BB);
3185 case AArch64::LD1_MXIPXX_H_PSEUDO_H:
3186 return EmitTileLoad(AArch64::LD1_MXIPXX_H_H, AArch64::ZAH0, MI, BB);
3187 case AArch64::LD1_MXIPXX_H_PSEUDO_S:
3188 return EmitTileLoad(AArch64::LD1_MXIPXX_H_S, AArch64::ZAS0, MI, BB);
3189 case AArch64::LD1_MXIPXX_H_PSEUDO_D:
3190 return EmitTileLoad(AArch64::LD1_MXIPXX_H_D, AArch64::ZAD0, MI, BB);
3191 case AArch64::LD1_MXIPXX_H_PSEUDO_Q:
3192 return EmitTileLoad(AArch64::LD1_MXIPXX_H_Q, AArch64::ZAQ0, MI, BB);
3193 case AArch64::LD1_MXIPXX_V_PSEUDO_B:
3194 return EmitTileLoad(AArch64::LD1_MXIPXX_V_B, AArch64::ZAB0, MI, BB);
3195 case AArch64::LD1_MXIPXX_V_PSEUDO_H:
3196 return EmitTileLoad(AArch64::LD1_MXIPXX_V_H, AArch64::ZAH0, MI, BB);
3197 case AArch64::LD1_MXIPXX_V_PSEUDO_S:
3198 return EmitTileLoad(AArch64::LD1_MXIPXX_V_S, AArch64::ZAS0, MI, BB);
3199 case AArch64::LD1_MXIPXX_V_PSEUDO_D:
3200 return EmitTileLoad(AArch64::LD1_MXIPXX_V_D, AArch64::ZAD0, MI, BB);
3201 case AArch64::LD1_MXIPXX_V_PSEUDO_Q:
3202 return EmitTileLoad(AArch64::LD1_MXIPXX_V_Q, AArch64::ZAQ0, MI, BB);
3203 case AArch64::LDR_ZA_PSEUDO:
3204 return EmitFill(MI, BB);
3205 case AArch64::LDR_TX_PSEUDO:
3206 return EmitZTInstr(MI, BB, AArch64::LDR_TX, /*Op0IsDef=*/true);
3207 case AArch64::STR_TX_PSEUDO:
3208 return EmitZTInstr(MI, BB, AArch64::STR_TX, /*Op0IsDef=*/false);
3209 case AArch64::ZERO_M_PSEUDO:
3210 return EmitZero(MI, BB);
3211 case AArch64::ZERO_T_PSEUDO:
3212 return EmitZTInstr(MI, BB, AArch64::ZERO_T, /*Op0IsDef=*/true);
3213 }
3214 }
3215
3216 //===----------------------------------------------------------------------===//
3217 // AArch64 Lowering private implementation.
3218 //===----------------------------------------------------------------------===//
3219
3220 //===----------------------------------------------------------------------===//
3221 // Lowering Code
3222 //===----------------------------------------------------------------------===//
3223
3224 // Forward declarations of SVE fixed length lowering helpers
3225 static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT);
3226 static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
3227 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V);
3228 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
3229 SelectionDAG &DAG);
3230 static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT);
3231 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
3232 EVT VT);
3233
3234 /// isZerosVector - Check whether SDNode N is a zero-filled vector.
3235 static bool isZerosVector(const SDNode *N) {
3236 // Look through a bit convert.
3237 while (N->getOpcode() == ISD::BITCAST)
3238 N = N->getOperand(0).getNode();
3239
3240 if (ISD::isConstantSplatVectorAllZeros(N))
3241 return true;
3242
3243 if (N->getOpcode() != AArch64ISD::DUP)
3244 return false;
3245
3246 auto Opnd0 = N->getOperand(0);
3247 return isNullConstant(Opnd0) || isNullFPConstant(Opnd0);
3248 }
3249
3250 /// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
3251 /// CC
3252 static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
3253 switch (CC) {
3254 default:
3255 llvm_unreachable("Unknown condition code!");
3256 case ISD::SETNE:
3257 return AArch64CC::NE;
3258 case ISD::SETEQ:
3259 return AArch64CC::EQ;
3260 case ISD::SETGT:
3261 return AArch64CC::GT;
3262 case ISD::SETGE:
3263 return AArch64CC::GE;
3264 case ISD::SETLT:
3265 return AArch64CC::LT;
3266 case ISD::SETLE:
3267 return AArch64CC::LE;
3268 case ISD::SETUGT:
3269 return AArch64CC::HI;
3270 case ISD::SETUGE:
3271 return AArch64CC::HS;
3272 case ISD::SETULT:
3273 return AArch64CC::LO;
3274 case ISD::SETULE:
3275 return AArch64CC::LS;
3276 }
3277 }
3278
3279 /// changeFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64 CC.
3280 static void changeFPCCToAArch64CC(ISD::CondCode CC,
3281 AArch64CC::CondCode &CondCode,
3282 AArch64CC::CondCode &CondCode2) {
3283 CondCode2 = AArch64CC::AL;
3284 switch (CC) {
3285 default:
3286 llvm_unreachable("Unknown FP condition!");
3287 case ISD::SETEQ:
3288 case ISD::SETOEQ:
3289 CondCode = AArch64CC::EQ;
3290 break;
3291 case ISD::SETGT:
3292 case ISD::SETOGT:
3293 CondCode = AArch64CC::GT;
3294 break;
3295 case ISD::SETGE:
3296 case ISD::SETOGE:
3297 CondCode = AArch64CC::GE;
3298 break;
3299 case ISD::SETOLT:
3300 CondCode = AArch64CC::MI;
3301 break;
3302 case ISD::SETOLE:
3303 CondCode = AArch64CC::LS;
3304 break;
3305 case ISD::SETONE:
3306 CondCode = AArch64CC::MI;
3307 CondCode2 = AArch64CC::GT;
3308 break;
3309 case ISD::SETO:
3310 CondCode = AArch64CC::VC;
3311 break;
3312 case ISD::SETUO:
3313 CondCode = AArch64CC::VS;
3314 break;
3315 case ISD::SETUEQ:
3316 CondCode = AArch64CC::EQ;
3317 CondCode2 = AArch64CC::VS;
3318 break;
3319 case ISD::SETUGT:
3320 CondCode = AArch64CC::HI;
3321 break;
3322 case ISD::SETUGE:
3323 CondCode = AArch64CC::PL;
3324 break;
3325 case ISD::SETLT:
3326 case ISD::SETULT:
3327 CondCode = AArch64CC::LT;
3328 break;
3329 case ISD::SETLE:
3330 case ISD::SETULE:
3331 CondCode = AArch64CC::LE;
3332 break;
3333 case ISD::SETNE:
3334 case ISD::SETUNE:
3335 CondCode = AArch64CC::NE;
3336 break;
3337 }
3338 }
3339
3340 /// Convert a DAG fp condition code to an AArch64 CC.
3341 /// This differs from changeFPCCToAArch64CC in that it returns cond codes that
3342 /// should be AND'ed instead of OR'ed.
3343 static void changeFPCCToANDAArch64CC(ISD::CondCode CC,
3344 AArch64CC::CondCode &CondCode,
3345 AArch64CC::CondCode &CondCode2) {
3346 CondCode2 = AArch64CC::AL;
3347 switch (CC) {
3348 default:
3349 changeFPCCToAArch64CC(CC, CondCode, CondCode2);
3350 assert(CondCode2 == AArch64CC::AL);
3351 break;
3352 case ISD::SETONE:
3353 // (a one b)
3354 // == ((a olt b) || (a ogt b))
3355 // == ((a ord b) && (a une b))
3356 CondCode = AArch64CC::VC;
3357 CondCode2 = AArch64CC::NE;
3358 break;
3359 case ISD::SETUEQ:
3360 // (a ueq b)
3361 // == ((a uno b) || (a oeq b))
3362 // == ((a ule b) && (a uge b))
3363 CondCode = AArch64CC::PL;
3364 CondCode2 = AArch64CC::LE;
3365 break;
3366 }
3367 }
3368
3369 /// changeVectorFPCCToAArch64CC - Convert a DAG fp condition code to an AArch64
3370 /// CC usable with the vector instructions. Fewer operations are available
3371 /// without a real NZCV register, so we have to use less efficient combinations
3372 /// to get the same effect.
3373 static void changeVectorFPCCToAArch64CC(ISD::CondCode CC,
3374 AArch64CC::CondCode &CondCode,
3375 AArch64CC::CondCode &CondCode2,
3376 bool &Invert) {
3377 Invert = false;
3378 switch (CC) {
3379 default:
3380 // Mostly the scalar mappings work fine.
3381 changeFPCCToAArch64CC(CC, CondCode, CondCode2);
3382 break;
3383 case ISD::SETUO:
3384 Invert = true;
3385 [[fallthrough]];
3386 case ISD::SETO:
3387 CondCode = AArch64CC::MI;
3388 CondCode2 = AArch64CC::GE;
3389 break;
3390 case ISD::SETUEQ:
3391 case ISD::SETULT:
3392 case ISD::SETULE:
3393 case ISD::SETUGT:
3394 case ISD::SETUGE:
3395 // All of the compare-mask comparisons are ordered, but we can switch
3396 // between the two by a double inversion. E.g. ULE == !OGT.
3397 Invert = true;
3398 changeFPCCToAArch64CC(getSetCCInverse(CC, /* FP inverse */ MVT::f32),
3399 CondCode, CondCode2);
3400 break;
3401 }
3402 }
3403
3404 static bool isLegalArithImmed(uint64_t C) {
3405 // Matches AArch64DAGToDAGISel::SelectArithImmed().
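// A legal immediate is a 12-bit value, optionally shifted left by 12 bits:
// e.g. 0xFFF and 0xFFF000 are legal, 0x1001 is not.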
3406 bool IsLegal = (C >> 12 == 0) || ((C & 0xFFFULL) == 0 && C >> 24 == 0);
3407 LLVM_DEBUG(dbgs() << "Is imm " << C
3408 << " legal: " << (IsLegal ? "yes\n" : "no\n"));
3409 return IsLegal;
3410 }
3411
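// Negating the signed minimum value overflows (-INT_MIN == INT_MIN), so the
// CMN-based signed comparisons below must prove the operand is not INT_MIN.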
3412 static bool cannotBeIntMin(SDValue CheckedVal, SelectionDAG &DAG) {
3413 KnownBits KnownSrc = DAG.computeKnownBits(CheckedVal);
3414 return !KnownSrc.getSignedMinValue().isMinSignedValue();
3415 }
3416
3417 // Can a (CMP op1, (sub 0, op2)) be turned into a CMN instruction on
3418 // the grounds that "op1 - (-op2) == op1 + op2"? Not always: the C and V flags
3419 // can be set differently by this operation. It comes down to whether
3420 // "SInt(~op2)+1 == SInt(~op2+1)" (and the same for UInt). If they are then
3421 // everything is fine. If not then the optimization is wrong. Thus general
3422 // comparisons are only valid if op2 != 0.
3423 //
3424 // So, finally, only the equality comparisons (which inspect just Z) can
3425 // always use CMN. Unsigned and signed comparisons may use CMN only when op2
3426 // is known to be non-zero or known not to be INT_MIN, respectively.
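// For example, an equality test (seteq x, (sub 0, y)) can always be emitted
// as "cmn x, y", because the Z flag of "adds x, y" is set exactly when x is
// equal to 0 - y modulo 2^n.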
3427 static bool isCMN(SDValue Op, ISD::CondCode CC, SelectionDAG &DAG) {
3428 return Op.getOpcode() == ISD::SUB && isNullConstant(Op.getOperand(0)) &&
3429 (isIntEqualitySetCC(CC) ||
3430 (isUnsignedIntSetCC(CC) && DAG.isKnownNeverZero(Op.getOperand(1))) ||
3431 (isSignedIntSetCC(CC) && cannotBeIntMin(Op.getOperand(1), DAG)));
3432 }
3433
3434 static SDValue emitStrictFPComparison(SDValue LHS, SDValue RHS, const SDLoc &dl,
3435 SelectionDAG &DAG, SDValue Chain,
3436 bool IsSignaling) {
3437 EVT VT = LHS.getValueType();
3438 assert(VT != MVT::f128);
3439
3440 const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3441
3442 if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
3443 LHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
3444 {Chain, LHS});
3445 RHS = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
3446 {LHS.getValue(1), RHS});
3447 Chain = RHS.getValue(1);
3448 VT = MVT::f32;
3449 }
3450 unsigned Opcode =
3451 IsSignaling ? AArch64ISD::STRICT_FCMPE : AArch64ISD::STRICT_FCMP;
3452 return DAG.getNode(Opcode, dl, {VT, MVT::Other}, {Chain, LHS, RHS});
3453 }
3454
3455 static SDValue emitComparison(SDValue LHS, SDValue RHS, ISD::CondCode CC,
3456 const SDLoc &dl, SelectionDAG &DAG) {
3457 EVT VT = LHS.getValueType();
3458 const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3459
3460 if (VT.isFloatingPoint()) {
3461 assert(VT != MVT::f128);
3462 if ((VT == MVT::f16 && !FullFP16) || VT == MVT::bf16) {
3463 LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
3464 RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
3465 VT = MVT::f32;
3466 }
3467 return DAG.getNode(AArch64ISD::FCMP, dl, VT, LHS, RHS);
3468 }
3469
3470 // The CMP instruction is just an alias for SUBS, and representing it as
3471 // SUBS means that it's possible to get CSE with subtract operations.
3472 // A later phase can perform the optimization of setting the destination
3473 // register to WZR/XZR if it ends up being unused.
3474 unsigned Opcode = AArch64ISD::SUBS;
3475
3476 if (isCMN(RHS, CC, DAG)) {
3477 // Can we combine a (CMP op1, (sub 0, op2)) into a CMN instruction?
3478 Opcode = AArch64ISD::ADDS;
3479 RHS = RHS.getOperand(1);
3480 } else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
3481 isIntEqualitySetCC(CC)) {
3482 // As we are looking for EQ/NE compares, the operands can be commuted; can
3483 // we combine a (CMP (sub 0, op1), op2) into a CMN instruction?
3484 Opcode = AArch64ISD::ADDS;
3485 LHS = LHS.getOperand(1);
3486 } else if (isNullConstant(RHS) && !isUnsignedIntSetCC(CC)) {
3487 if (LHS.getOpcode() == ISD::AND) {
3488 // Similarly, (CMP (and X, Y), 0) can be implemented with a TST
3489 // (a.k.a. ANDS) except that the flags are only guaranteed to work for one
3490 // of the signed comparisons.
3491 const SDValue ANDSNode = DAG.getNode(AArch64ISD::ANDS, dl,
3492 DAG.getVTList(VT, MVT_CC),
3493 LHS.getOperand(0),
3494 LHS.getOperand(1));
3495 // Replace all users of (and X, Y) with newly generated (ands X, Y)
3496 DAG.ReplaceAllUsesWith(LHS, ANDSNode);
3497 return ANDSNode.getValue(1);
3498 } else if (LHS.getOpcode() == AArch64ISD::ANDS) {
3499 // Use result of ANDS
3500 return LHS.getValue(1);
3501 }
3502 }
3503
3504 return DAG.getNode(Opcode, dl, DAG.getVTList(VT, MVT_CC), LHS, RHS)
3505 .getValue(1);
3506 }
3507
3508 /// \defgroup AArch64CCMP CMP;CCMP matching
3509 ///
3510 /// These functions deal with the formation of CMP;CCMP;... sequences.
3511 /// The CCMP/CCMN/FCCMP/FCCMPE instructions allow the conditional execution of
3512 /// a comparison. They set the NZCV flags to a predefined value if their
3513 /// predicate is false. This makes it possible to express arbitrary
3514 /// conjunctions, for example "cmp 0 (and (setCA (cmp A)) (setCB (cmp B)))"
3515 /// expressed as:
3516 /// cmp A
3517 /// ccmp B, inv(CB), CA
3518 /// check for CB flags
3519 ///
3520 /// This naturally lets us implement chains of AND operations with SETCC
3521 /// operands. And we can even implement some other situations by transforming
3522 /// them:
3523 /// - We can implement (NEG SETCC) i.e. negating a single comparison by
3524 /// negating the flags used in a CCMP/FCCMP operations.
3525 /// - We can negate the result of a whole chain of CMP/CCMP/FCCMP operations
3526 /// by negating the flags we test for afterwards. i.e.
3527 /// NEG (CMP CCMP CCCMP ...) can be implemented.
3528 /// - Note that we can only ever negate all previously processed results.
3529 /// What we can not implement by flipping the flags to test is a negation
3530 /// of two sub-trees (because the negation affects all sub-trees emitted so
3531 /// far, so the 2nd sub-tree we emit would also affect the first).
3532 /// With those tools we can implement some OR operations:
3533 /// - (OR (SETCC A) (SETCC B)) can be implemented via:
3534 /// NEG (AND (NEG (SETCC A)) (NEG (SETCC B)))
3535 /// - After transforming OR to NEG/AND combinations we may be able to use NEG
3536 /// elimination rules from earlier to implement the whole thing as a
3537 /// CCMP/FCCMP chain.
3538 ///
3539 /// As a complete example:
3540 /// or (or (setCA (cmp A)) (setCB (cmp B)))
3541 /// (and (setCC (cmp C)) (setCD (cmp D)))
3542 /// can be reassociated to:
3543 /// or (and (setCC (cmp C)) (setCD (cmp D)))
3544 /// (or (setCA (cmp A)) (setCB (cmp B)))
3545 /// can be transformed to:
3546 /// not (and (not (and (setCC (cmp C)) (setCD (cmp D))))
3547 /// (and (not (setCA (cmp A))) (not (setCB (cmp B)))))
3548 /// which can be implemented as:
3549 /// cmp C
3550 /// ccmp D, inv(CD), CC
3551 /// ccmp A, CA, inv(CD)
3552 /// ccmp B, CB, inv(CA)
3553 /// check for CB flags
3554 ///
3555 /// A counterexample is "or (and A B) (and C D)" which translates to
3556 /// not (and (not (and (not A) (not B))) (not (and (not C) (not D)))); we
3557 /// can implement only one of the two inner (not) operations, not both!
3558 /// @{
3559
3560 /// Create a conditional comparison; Use CCMP, CCMN or FCCMP as appropriate.
3561 static SDValue emitConditionalComparison(SDValue LHS, SDValue RHS,
3562 ISD::CondCode CC, SDValue CCOp,
3563 AArch64CC::CondCode Predicate,
3564 AArch64CC::CondCode OutCC,
3565 const SDLoc &DL, SelectionDAG &DAG) {
3566 unsigned Opcode = 0;
3567 const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
3568
3569 if (LHS.getValueType().isFloatingPoint()) {
3570 assert(LHS.getValueType() != MVT::f128);
3571 if ((LHS.getValueType() == MVT::f16 && !FullFP16) ||
3572 LHS.getValueType() == MVT::bf16) {
3573 LHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, LHS);
3574 RHS = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, RHS);
3575 }
3576 Opcode = AArch64ISD::FCCMP;
3577 } else if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(RHS)) {
3578 APInt Imm = Const->getAPIntValue();
3579 if (Imm.isNegative() && Imm.sgt(-32)) {
3580 Opcode = AArch64ISD::CCMN;
3581 RHS = DAG.getConstant(Imm.abs(), DL, Const->getValueType(0));
3582 }
3583 } else if (isCMN(RHS, CC, DAG)) {
3584 Opcode = AArch64ISD::CCMN;
3585 RHS = RHS.getOperand(1);
3586 } else if (LHS.getOpcode() == ISD::SUB && isNullConstant(LHS.getOperand(0)) &&
3587 isIntEqualitySetCC(CC)) {
3588 // As we are looking for EQ/NE compares, the operands can be commuted; can
3589 // we combine a (CCMP (sub 0, op1), op2) into a CCMN instruction?
3590 Opcode = AArch64ISD::CCMN;
3591 LHS = LHS.getOperand(1);
3592 }
3593 if (Opcode == 0)
3594 Opcode = AArch64ISD::CCMP;
3595
3596 SDValue Condition = DAG.getConstant(Predicate, DL, MVT_CC);
3597 AArch64CC::CondCode InvOutCC = AArch64CC::getInvertedCondCode(OutCC);
3598 unsigned NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvOutCC);
3599 SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
3600 return DAG.getNode(Opcode, DL, MVT_CC, LHS, RHS, NZCVOp, Condition, CCOp);
3601 }
3602
3603 /// Returns true if @p Val is a tree of AND/OR/SETCC operations that can be
3604 /// expressed as a conjunction. See \ref AArch64CCMP.
3605 /// \param CanNegate Set to true if we can negate the whole sub-tree just by
3606 /// changing the conditions on the SETCC tests.
3607 /// (this means we can call emitConjunctionRec() with
3608 /// Negate==true on this sub-tree)
3609 /// \param MustBeFirst Set to true if this subtree needs to be negated and we
3610 /// cannot do the negation naturally. We are required to
3611 /// emit the subtree first in this case.
3612 /// \param WillNegate Is true if we are called when the result of this
3613 /// subexpression must be negated. This happens when the
3614 /// outer expression is an OR. We can use this fact to know
3615 /// that we have a double negation (or (or ...) ...) that
3616 /// can be implemented for free.
3617 static bool canEmitConjunction(const SDValue Val, bool &CanNegate,
3618 bool &MustBeFirst, bool WillNegate,
3619 unsigned Depth = 0) {
3620 if (!Val.hasOneUse())
3621 return false;
3622 unsigned Opcode = Val->getOpcode();
3623 if (Opcode == ISD::SETCC) {
3624 if (Val->getOperand(0).getValueType() == MVT::f128)
3625 return false;
3626 CanNegate = true;
3627 MustBeFirst = false;
3628 return true;
3629 }
3630 // Protect against exponential runtime and stack overflow.
3631 if (Depth > 6)
3632 return false;
3633 if (Opcode == ISD::AND || Opcode == ISD::OR) {
3634 bool IsOR = Opcode == ISD::OR;
3635 SDValue O0 = Val->getOperand(0);
3636 SDValue O1 = Val->getOperand(1);
3637 bool CanNegateL;
3638 bool MustBeFirstL;
3639 if (!canEmitConjunction(O0, CanNegateL, MustBeFirstL, IsOR, Depth+1))
3640 return false;
3641 bool CanNegateR;
3642 bool MustBeFirstR;
3643 if (!canEmitConjunction(O1, CanNegateR, MustBeFirstR, IsOR, Depth+1))
3644 return false;
3645
3646 if (MustBeFirstL && MustBeFirstR)
3647 return false;
3648
3649 if (IsOR) {
3650 // For an OR expression we need to be able to naturally negate at least
3651 // one side or we cannot do the transformation at all.
3652 if (!CanNegateL && !CanNegateR)
3653 return false;
3654 // If the result of the OR will be negated and we can naturally negate
3655 // the leaves, then this sub-tree as a whole negates naturally.
3656 CanNegate = WillNegate && CanNegateL && CanNegateR;
3657 // If we cannot naturally negate the whole sub-tree, then this must be
3658 // emitted first.
3659 MustBeFirst = !CanNegate;
3660 } else {
3661 assert(Opcode == ISD::AND && "Must be OR or AND");
3662 // We cannot naturally negate an AND operation.
3663 CanNegate = false;
3664 MustBeFirst = MustBeFirstL || MustBeFirstR;
3665 }
3666 return true;
3667 }
3668 return false;
3669 }
3670
3671 /// Emit conjunction or disjunction tree with the CMP/FCMP followed by a chain
3672 /// of CCMP/FCCMP ops. See @ref AArch64CCMP.
3673 /// Tries to transform the given i1-producing node @p Val to a series of compare
3674 /// and conditional compare operations. @returns an NZCV flags producing node
3675 /// and sets @p OutCC to the flags that should be tested or returns SDValue() if
3676 /// transformation was not possible.
3677 /// \p Negate is true if we want this sub-tree being negated just by changing
3678 /// SETCC conditions.
3679 static SDValue emitConjunctionRec(SelectionDAG &DAG, SDValue Val,
3680 AArch64CC::CondCode &OutCC, bool Negate, SDValue CCOp,
3681 AArch64CC::CondCode Predicate) {
3682 // We're at a tree leaf, produce a conditional comparison operation.
3683 unsigned Opcode = Val->getOpcode();
3684 if (Opcode == ISD::SETCC) {
3685 SDValue LHS = Val->getOperand(0);
3686 SDValue RHS = Val->getOperand(1);
3687 ISD::CondCode CC = cast<CondCodeSDNode>(Val->getOperand(2))->get();
3688 bool isInteger = LHS.getValueType().isInteger();
3689 if (Negate)
3690 CC = getSetCCInverse(CC, LHS.getValueType());
3691 SDLoc DL(Val);
3692 // Determine OutCC and handle FP special case.
3693 if (isInteger) {
3694 OutCC = changeIntCCToAArch64CC(CC);
3695 } else {
3696 assert(LHS.getValueType().isFloatingPoint());
3697 AArch64CC::CondCode ExtraCC;
3698 changeFPCCToANDAArch64CC(CC, OutCC, ExtraCC);
3699 // Some floating point conditions can't be tested with a single condition
3700 // code. Construct an additional comparison in this case.
3701 if (ExtraCC != AArch64CC::AL) {
3702 SDValue ExtraCmp;
3703 if (!CCOp.getNode())
3704 ExtraCmp = emitComparison(LHS, RHS, CC, DL, DAG);
3705 else
3706 ExtraCmp = emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate,
3707 ExtraCC, DL, DAG);
3708 CCOp = ExtraCmp;
3709 Predicate = ExtraCC;
3710 }
3711 }
3712
3713 // Produce a normal comparison if we are first in the chain
3714 if (!CCOp)
3715 return emitComparison(LHS, RHS, CC, DL, DAG);
3716 // Otherwise produce a ccmp.
3717 return emitConditionalComparison(LHS, RHS, CC, CCOp, Predicate, OutCC, DL,
3718 DAG);
3719 }
3720 assert(Val->hasOneUse() && "Valid conjunction/disjunction tree");
3721
3722 bool IsOR = Opcode == ISD::OR;
3723
3724 SDValue LHS = Val->getOperand(0);
3725 bool CanNegateL;
3726 bool MustBeFirstL;
3727 bool ValidL = canEmitConjunction(LHS, CanNegateL, MustBeFirstL, IsOR);
3728 assert(ValidL && "Valid conjunction/disjunction tree");
3729 (void)ValidL;
3730
3731 SDValue RHS = Val->getOperand(1);
3732 bool CanNegateR;
3733 bool MustBeFirstR;
3734 bool ValidR = canEmitConjunction(RHS, CanNegateR, MustBeFirstR, IsOR);
3735 assert(ValidR && "Valid conjunction/disjunction tree");
3736 (void)ValidR;
3737
3738 // Swap sub-tree that must come first to the right side.
3739 if (MustBeFirstL) {
3740 assert(!MustBeFirstR && "Valid conjunction/disjunction tree");
3741 std::swap(LHS, RHS);
3742 std::swap(CanNegateL, CanNegateR);
3743 std::swap(MustBeFirstL, MustBeFirstR);
3744 }
3745
3746 bool NegateR;
3747 bool NegateAfterR;
3748 bool NegateL;
3749 bool NegateAfterAll;
3750 if (Opcode == ISD::OR) {
3751 // Swap the sub-tree that we can negate naturally to the left.
3752 if (!CanNegateL) {
3753 assert(CanNegateR && "at least one side must be negatable");
3754 assert(!MustBeFirstR && "invalid conjunction/disjunction tree");
3755 assert(!Negate);
3756 std::swap(LHS, RHS);
3757 NegateR = false;
3758 NegateAfterR = true;
3759 } else {
3760 // Negate the left sub-tree if possible, otherwise negate the result.
3761 NegateR = CanNegateR;
3762 NegateAfterR = !CanNegateR;
3763 }
3764 NegateL = true;
3765 NegateAfterAll = !Negate;
3766 } else {
3767 assert(Opcode == ISD::AND && "Valid conjunction/disjunction tree");
3768 assert(!Negate && "Valid conjunction/disjunction tree");
3769
3770 NegateL = false;
3771 NegateR = false;
3772 NegateAfterR = false;
3773 NegateAfterAll = false;
3774 }
3775
3776 // Emit sub-trees.
3777 AArch64CC::CondCode RHSCC;
3778 SDValue CmpR = emitConjunctionRec(DAG, RHS, RHSCC, NegateR, CCOp, Predicate);
3779 if (NegateAfterR)
3780 RHSCC = AArch64CC::getInvertedCondCode(RHSCC);
3781 SDValue CmpL = emitConjunctionRec(DAG, LHS, OutCC, NegateL, CmpR, RHSCC);
3782 if (NegateAfterAll)
3783 OutCC = AArch64CC::getInvertedCondCode(OutCC);
3784 return CmpL;
3785 }
3786
3787 /// Emit expression as a conjunction (a series of CCMP/FCCMP ops).
3788 /// In some cases this is even possible with OR operations in the expression.
3789 /// See \ref AArch64CCMP.
3790 /// \see emitConjunctionRec().
3791 static SDValue emitConjunction(SelectionDAG &DAG, SDValue Val,
3792 AArch64CC::CondCode &OutCC) {
3793 bool DummyCanNegate;
3794 bool DummyMustBeFirst;
3795 if (!canEmitConjunction(Val, DummyCanNegate, DummyMustBeFirst, false))
3796 return SDValue();
3797
3798 return emitConjunctionRec(DAG, Val, OutCC, false, SDValue(), AArch64CC::AL);
3799 }
3800
3801 /// @}
3802
3803 /// Returns how profitable it is to fold a comparison's operand's shift and/or
3804 /// extension operations.
3805 static unsigned getCmpOperandFoldingProfit(SDValue Op) {
3806 auto isSupportedExtend = [&](SDValue V) {
3807 if (V.getOpcode() == ISD::SIGN_EXTEND_INREG)
3808 return true;
3809
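// An AND with 0xFF, 0xFFFF or 0xFFFFFFFF corresponds to a uxtb, uxth or
// uxtw extended-register operand on the compare.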
3810 if (V.getOpcode() == ISD::AND)
3811 if (ConstantSDNode *MaskCst = dyn_cast<ConstantSDNode>(V.getOperand(1))) {
3812 uint64_t Mask = MaskCst->getZExtValue();
3813 return (Mask == 0xFF || Mask == 0xFFFF || Mask == 0xFFFFFFFF);
3814 }
3815
3816 return false;
3817 };
3818
3819 if (!Op.hasOneUse())
3820 return 0;
3821
3822 if (isSupportedExtend(Op))
3823 return 1;
3824
3825 unsigned Opc = Op.getOpcode();
3826 if (Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA)
3827 if (ConstantSDNode *ShiftCst = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
3828 uint64_t Shift = ShiftCst->getZExtValue();
3829 if (isSupportedExtend(Op.getOperand(0)))
3830 return (Shift <= 4) ? 2 : 1;
3831 EVT VT = Op.getValueType();
3832 if ((VT == MVT::i32 && Shift <= 31) || (VT == MVT::i64 && Shift <= 63))
3833 return 1;
3834 }
3835
3836 return 0;
3837 }
3838
3839 static SDValue getAArch64Cmp(SDValue LHS, SDValue RHS, ISD::CondCode CC,
3840 SDValue &AArch64cc, SelectionDAG &DAG,
3841 const SDLoc &dl) {
3842 if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS.getNode())) {
3843 EVT VT = RHS.getValueType();
3844 uint64_t C = RHSC->getZExtValue();
3845 if (!isLegalArithImmed(C)) {
3846 // Constant does not fit, try adjusting it by one?
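// For example, (setlt x, 0x1001) cannot be encoded directly, but the
// equivalent (setle x, 0x1000) can.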
3847 switch (CC) {
3848 default:
3849 break;
3850 case ISD::SETLT:
3851 case ISD::SETGE:
3852 if ((VT == MVT::i32 && C != 0x80000000 &&
3853 isLegalArithImmed((uint32_t)(C - 1))) ||
3854 (VT == MVT::i64 && C != 0x80000000ULL &&
3855 isLegalArithImmed(C - 1ULL))) {
3856 CC = (CC == ISD::SETLT) ? ISD::SETLE : ISD::SETGT;
3857 C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
3858 RHS = DAG.getConstant(C, dl, VT);
3859 }
3860 break;
3861 case ISD::SETULT:
3862 case ISD::SETUGE:
3863 if ((VT == MVT::i32 && C != 0 &&
3864 isLegalArithImmed((uint32_t)(C - 1))) ||
3865 (VT == MVT::i64 && C != 0ULL && isLegalArithImmed(C - 1ULL))) {
3866 CC = (CC == ISD::SETULT) ? ISD::SETULE : ISD::SETUGT;
3867 C = (VT == MVT::i32) ? (uint32_t)(C - 1) : C - 1;
3868 RHS = DAG.getConstant(C, dl, VT);
3869 }
3870 break;
3871 case ISD::SETLE:
3872 case ISD::SETGT:
3873 if ((VT == MVT::i32 && C != INT32_MAX &&
3874 isLegalArithImmed((uint32_t)(C + 1))) ||
3875 (VT == MVT::i64 && C != INT64_MAX &&
3876 isLegalArithImmed(C + 1ULL))) {
3877 CC = (CC == ISD::SETLE) ? ISD::SETLT : ISD::SETGE;
3878 C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
3879 RHS = DAG.getConstant(C, dl, VT);
3880 }
3881 break;
3882 case ISD::SETULE:
3883 case ISD::SETUGT:
3884 if ((VT == MVT::i32 && C != UINT32_MAX &&
3885 isLegalArithImmed((uint32_t)(C + 1))) ||
3886 (VT == MVT::i64 && C != UINT64_MAX &&
3887 isLegalArithImmed(C + 1ULL))) {
3888 CC = (CC == ISD::SETULE) ? ISD::SETULT : ISD::SETUGE;
3889 C = (VT == MVT::i32) ? (uint32_t)(C + 1) : C + 1;
3890 RHS = DAG.getConstant(C, dl, VT);
3891 }
3892 break;
3893 }
3894 }
3895 }
3896
3897 // Comparisons are canonicalized so that the RHS operand is simpler than the
3898 // LHS one, the extreme case being when RHS is an immediate. However, AArch64
3899 // can fold some shift+extend operations on the RHS operand, so swap the
3900 // operands if that can be done.
3901 //
3902 // For example:
3903 // lsl w13, w11, #1
3904 // cmp w13, w12
3905 // can be turned into:
3906 // cmp w12, w11, lsl #1
3907 if (!isa<ConstantSDNode>(RHS) ||
3908 !isLegalArithImmed(RHS->getAsAPIntVal().abs().getZExtValue())) {
3909 bool LHSIsCMN = isCMN(LHS, CC, DAG);
3910 bool RHSIsCMN = isCMN(RHS, CC, DAG);
3911 SDValue TheLHS = LHSIsCMN ? LHS.getOperand(1) : LHS;
3912 SDValue TheRHS = RHSIsCMN ? RHS.getOperand(1) : RHS;
3913
3914 if (getCmpOperandFoldingProfit(TheLHS) + (LHSIsCMN ? 1 : 0) >
3915 getCmpOperandFoldingProfit(TheRHS) + (RHSIsCMN ? 1 : 0)) {
3916 std::swap(LHS, RHS);
3917 CC = ISD::getSetCCSwappedOperands(CC);
3918 }
3919 }
3920
3921 SDValue Cmp;
3922 AArch64CC::CondCode AArch64CC;
3923 if (isIntEqualitySetCC(CC) && isa<ConstantSDNode>(RHS)) {
3924 const ConstantSDNode *RHSC = cast<ConstantSDNode>(RHS);
3925
3926 // The imm operand of ADDS is an unsigned immediate, in the range 0 to 4095.
3927 // For the i8 operand, the largest immediate is 255, so this can be easily
3928 // encoded in the compare instruction. For the i16 operand, however, the
3929 // largest immediate cannot be encoded in the compare.
3930 // Therefore, use a sign extending load and cmn to avoid materializing the
3931 // -1 constant. For example,
3932 // movz w1, #65535
3933 // ldrh w0, [x0, #0]
3934 // cmp w0, w1
3935 // >
3936 // ldrsh w0, [x0, #0]
3937 // cmn w0, #1
3938 // Fundamentally, we're relying on the property that (zext LHS) == (zext RHS)
3939 // if and only if (sext LHS) == (sext RHS). The checks are in place to
3940 // ensure both the LHS and RHS are truly zero extended and to make sure the
3941 // transformation is profitable.
3942 if ((RHSC->getZExtValue() >> 16 == 0) && isa<LoadSDNode>(LHS) &&
3943 cast<LoadSDNode>(LHS)->getExtensionType() == ISD::ZEXTLOAD &&
3944 cast<LoadSDNode>(LHS)->getMemoryVT() == MVT::i16 &&
3945 LHS.getNode()->hasNUsesOfValue(1, 0)) {
3946 int16_t ValueofRHS = RHS->getAsZExtVal();
3947 if (ValueofRHS < 0 && isLegalArithImmed(-ValueofRHS)) {
3948 SDValue SExt =
3949 DAG.getNode(ISD::SIGN_EXTEND_INREG, dl, LHS.getValueType(), LHS,
3950 DAG.getValueType(MVT::i16));
3951 Cmp = emitComparison(SExt, DAG.getConstant(ValueofRHS, dl,
3952 RHS.getValueType()),
3953 CC, dl, DAG);
3954 AArch64CC = changeIntCCToAArch64CC(CC);
3955 }
3956 }
3957
3958 if (!Cmp && (RHSC->isZero() || RHSC->isOne())) {
3959 if ((Cmp = emitConjunction(DAG, LHS, AArch64CC))) {
3960 if ((CC == ISD::SETNE) ^ RHSC->isZero())
3961 AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
3962 }
3963 }
3964 }
3965
3966 if (!Cmp) {
3967 Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
3968 AArch64CC = changeIntCCToAArch64CC(CC);
3969 }
3970 AArch64cc = DAG.getConstant(AArch64CC, dl, MVT_CC);
3971 return Cmp;
3972 }
3973
3974 static std::pair<SDValue, SDValue>
3975 getAArch64XALUOOp(AArch64CC::CondCode &CC, SDValue Op, SelectionDAG &DAG) {
3976 assert((Op.getValueType() == MVT::i32 || Op.getValueType() == MVT::i64) &&
3977 "Unsupported value type");
3978 SDValue Value, Overflow;
3979 SDLoc DL(Op);
3980 SDValue LHS = Op.getOperand(0);
3981 SDValue RHS = Op.getOperand(1);
3982 unsigned Opc = 0;
3983 switch (Op.getOpcode()) {
3984 default:
3985 llvm_unreachable("Unknown overflow instruction!");
3986 case ISD::SADDO:
3987 Opc = AArch64ISD::ADDS;
3988 CC = AArch64CC::VS;
3989 break;
3990 case ISD::UADDO:
3991 Opc = AArch64ISD::ADDS;
3992 CC = AArch64CC::HS;
3993 break;
3994 case ISD::SSUBO:
3995 Opc = AArch64ISD::SUBS;
3996 CC = AArch64CC::VS;
3997 break;
3998 case ISD::USUBO:
3999 Opc = AArch64ISD::SUBS;
4000 CC = AArch64CC::LO;
4001 break;
4002 // Multiply needs a little bit of extra work.
4003 case ISD::SMULO:
4004 case ISD::UMULO: {
4005 CC = AArch64CC::NE;
4006 bool IsSigned = Op.getOpcode() == ISD::SMULO;
4007 if (Op.getValueType() == MVT::i32) {
4008 // Extend to 64-bits, then perform a 64-bit multiply.
4009 unsigned ExtendOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4010 LHS = DAG.getNode(ExtendOpc, DL, MVT::i64, LHS);
4011 RHS = DAG.getNode(ExtendOpc, DL, MVT::i64, RHS);
4012 SDValue Mul = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
4013 Value = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, Mul);
4014
4015 // Check that the result fits into a 32-bit integer.
4016 SDVTList VTs = DAG.getVTList(MVT::i64, MVT_CC);
4017 if (IsSigned) {
4018 // cmp xreg, wreg, sxtw
4019 SDValue SExtMul = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Value);
4020 Overflow =
4021 DAG.getNode(AArch64ISD::SUBS, DL, VTs, Mul, SExtMul).getValue(1);
4022 } else {
4023 // tst xreg, #0xffffffff00000000
4024 SDValue UpperBits = DAG.getConstant(0xFFFFFFFF00000000, DL, MVT::i64);
4025 Overflow =
4026 DAG.getNode(AArch64ISD::ANDS, DL, VTs, Mul, UpperBits).getValue(1);
4027 }
4028 break;
4029 }
4030 assert(Op.getValueType() == MVT::i64 && "Expected an i64 value type");
4031 // For the 64 bit multiply
4032 Value = DAG.getNode(ISD::MUL, DL, MVT::i64, LHS, RHS);
4033 if (IsSigned) {
4034 SDValue UpperBits = DAG.getNode(ISD::MULHS, DL, MVT::i64, LHS, RHS);
4035 SDValue LowerBits = DAG.getNode(ISD::SRA, DL, MVT::i64, Value,
4036 DAG.getConstant(63, DL, MVT::i64));
4037 // It is important that LowerBits is last, otherwise the arithmetic
4038 // shift will not be folded into the compare (SUBS).
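// (Only the second source operand of SUBS accepts a shifted register, e.g.
// "cmp x_upper, x_value, asr #63"; the register names are illustrative.)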
4039 SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
4040 Overflow = DAG.getNode(AArch64ISD::SUBS, DL, VTs, UpperBits, LowerBits)
4041 .getValue(1);
4042 } else {
4043 SDValue UpperBits = DAG.getNode(ISD::MULHU, DL, MVT::i64, LHS, RHS);
4044 SDVTList VTs = DAG.getVTList(MVT::i64, MVT::i32);
4045 Overflow =
4046 DAG.getNode(AArch64ISD::SUBS, DL, VTs,
4047 DAG.getConstant(0, DL, MVT::i64),
4048 UpperBits).getValue(1);
4049 }
4050 break;
4051 }
4052 } // switch (...)
4053
4054 if (Opc) {
4055 SDVTList VTs = DAG.getVTList(Op->getValueType(0), MVT::i32);
4056
4057 // Emit the AArch64 operation with overflow check.
4058 Value = DAG.getNode(Opc, DL, VTs, LHS, RHS);
4059 Overflow = Value.getValue(1);
4060 }
4061 return std::make_pair(Value, Overflow);
4062 }
4063
4064 SDValue AArch64TargetLowering::LowerXOR(SDValue Op, SelectionDAG &DAG) const {
4065 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
4066 !Subtarget->isNeonAvailable()))
4067 return LowerToScalableOp(Op, DAG);
4068
4069 SDValue Sel = Op.getOperand(0);
4070 SDValue Other = Op.getOperand(1);
4071 SDLoc dl(Sel);
4072
4073 // If the operand is an overflow checking operation, invert the condition
4074 // code and kill the Not operation. I.e., transform:
4075 // (xor (overflow_op_bool, 1))
4076 // -->
4077 // (csel 1, 0, invert(cc), overflow_op_bool)
4078 // ... which later gets transformed to just a cset instruction with an
4079 // inverted condition code, rather than a cset + eor sequence.
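// For example, (xor (uaddo.overflow a, b), 1) becomes
// (csel 1, 0, lo, (adds a, b)), which can be selected to something like
// "adds; cset wN, lo" rather than "adds; cset wN, hs; eor wN, wN, #1".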
4080 if (isOneConstant(Other) && ISD::isOverflowIntrOpRes(Sel)) {
4081 // Only lower legal XALUO ops.
4082 if (!DAG.getTargetLoweringInfo().isTypeLegal(Sel->getValueType(0)))
4083 return SDValue();
4084
4085 SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
4086 SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
4087 AArch64CC::CondCode CC;
4088 SDValue Value, Overflow;
4089 std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Sel.getValue(0), DAG);
4090 SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
4091 return DAG.getNode(AArch64ISD::CSEL, dl, Op.getValueType(), TVal, FVal,
4092 CCVal, Overflow);
4093 }
4094 // If neither operand is a SELECT_CC, give up.
4095 if (Sel.getOpcode() != ISD::SELECT_CC)
4096 std::swap(Sel, Other);
4097 if (Sel.getOpcode() != ISD::SELECT_CC)
4098 return Op;
4099
4100 // The folding we want to perform is:
4101 // (xor x, (select_cc a, b, cc, 0, -1) )
4102 // -->
4103 // (csel x, (xor x, -1), cc ...)
4104 //
4105 // The latter will get matched to a CSINV instruction.
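// For example, with cc == seteq the result is
// (csel x, (xor x, -1), eq, (subs a, b)), which the instruction selector can
// match as "cmp a, b; csinv xN, x, x, eq" instead of materializing the 0/-1
// mask and XORing with it.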
4106
4107 ISD::CondCode CC = cast<CondCodeSDNode>(Sel.getOperand(4))->get();
4108 SDValue LHS = Sel.getOperand(0);
4109 SDValue RHS = Sel.getOperand(1);
4110 SDValue TVal = Sel.getOperand(2);
4111 SDValue FVal = Sel.getOperand(3);
4112
4113 // FIXME: This could be generalized to non-integer comparisons.
4114 if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
4115 return Op;
4116
4117 ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
4118 ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
4119
4120 // The values aren't constants, this isn't the pattern we're looking for.
4121 if (!CFVal || !CTVal)
4122 return Op;
4123
4124 // We can commute the SELECT_CC by inverting the condition. This
4125 // might be needed to make this fit into a CSINV pattern.
4126 if (CTVal->isAllOnes() && CFVal->isZero()) {
4127 std::swap(TVal, FVal);
4128 std::swap(CTVal, CFVal);
4129 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
4130 }
4131
4132 // If the constants line up, perform the transform!
4133 if (CTVal->isZero() && CFVal->isAllOnes()) {
4134 SDValue CCVal;
4135 SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
4136
4137 FVal = Other;
4138 TVal = DAG.getNode(ISD::XOR, dl, Other.getValueType(), Other,
4139 DAG.getConstant(-1ULL, dl, Other.getValueType()));
4140
4141 return DAG.getNode(AArch64ISD::CSEL, dl, Sel.getValueType(), FVal, TVal,
4142 CCVal, Cmp);
4143 }
4144
4145 return Op;
4146 }
4147
4148 // If Invert is false, sets 'C' bit of NZCV to 0 if value is 0, else sets 'C'
4149 // bit to 1. If Invert is true, sets 'C' bit of NZCV to 1 if value is 0, else
4150 // sets 'C' bit to 0.
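// For example, with Invert == false this is effectively "cmp value, #1": the
// subtraction borrows (clearing C) only when value is 0, so C ends up set iff
// value != 0.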
4151 static SDValue valueToCarryFlag(SDValue Value, SelectionDAG &DAG, bool Invert) {
4152 SDLoc DL(Value);
4153 EVT VT = Value.getValueType();
4154 SDValue Op0 = Invert ? DAG.getConstant(0, DL, VT) : Value;
4155 SDValue Op1 = Invert ? Value : DAG.getConstant(1, DL, VT);
4156 SDValue Cmp =
4157 DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::Glue), Op0, Op1);
4158 return Cmp.getValue(1);
4159 }
4160
4161 // If Invert is false, value is 1 if 'C' bit of NZCV is 1, else 0.
4162 // If Invert is true, value is 0 if 'C' bit of NZCV is 1, else 1.
4163 static SDValue carryFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG,
4164 bool Invert) {
4165 assert(Glue.getResNo() == 1);
4166 SDLoc DL(Glue);
4167 SDValue Zero = DAG.getConstant(0, DL, VT);
4168 SDValue One = DAG.getConstant(1, DL, VT);
4169 unsigned Cond = Invert ? AArch64CC::LO : AArch64CC::HS;
4170 SDValue CC = DAG.getConstant(Cond, DL, MVT::i32);
4171 return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue);
4172 }
4173
4174 // Value is 1 if 'V' bit of NZCV is 1, else 0
4175 static SDValue overflowFlagToValue(SDValue Glue, EVT VT, SelectionDAG &DAG) {
4176 assert(Glue.getResNo() == 1);
4177 SDLoc DL(Glue);
4178 SDValue Zero = DAG.getConstant(0, DL, VT);
4179 SDValue One = DAG.getConstant(1, DL, VT);
4180 SDValue CC = DAG.getConstant(AArch64CC::VS, DL, MVT::i32);
4181 return DAG.getNode(AArch64ISD::CSEL, DL, VT, One, Zero, CC, Glue);
4182 }
4183
4184 // This lowering is inefficient, but it will get cleaned up by
4185 // `foldOverflowCheck`
4186 static SDValue lowerADDSUBO_CARRY(SDValue Op, SelectionDAG &DAG,
4187 unsigned Opcode, bool IsSigned) {
4188 EVT VT0 = Op.getValue(0).getValueType();
4189 EVT VT1 = Op.getValue(1).getValueType();
4190
4191 if (VT0 != MVT::i32 && VT0 != MVT::i64)
4192 return SDValue();
4193
4194 bool InvertCarry = Opcode == AArch64ISD::SBCS;
4195 SDValue OpLHS = Op.getOperand(0);
4196 SDValue OpRHS = Op.getOperand(1);
4197 SDValue OpCarryIn = valueToCarryFlag(Op.getOperand(2), DAG, InvertCarry);
4198
4199 SDLoc DL(Op);
4200 SDVTList VTs = DAG.getVTList(VT0, VT1);
4201
4202 SDValue Sum = DAG.getNode(Opcode, DL, DAG.getVTList(VT0, MVT::Glue), OpLHS,
4203 OpRHS, OpCarryIn);
4204
4205 SDValue OutFlag =
4206 IsSigned ? overflowFlagToValue(Sum.getValue(1), VT1, DAG)
4207 : carryFlagToValue(Sum.getValue(1), VT1, DAG, InvertCarry);
4208
4209 return DAG.getNode(ISD::MERGE_VALUES, DL, VTs, Sum, OutFlag);
4210 }
4211
4212 static SDValue LowerXALUO(SDValue Op, SelectionDAG &DAG) {
4213 // Let legalize expand this if it isn't a legal type yet.
4214 if (!DAG.getTargetLoweringInfo().isTypeLegal(Op.getValueType()))
4215 return SDValue();
4216
4217 SDLoc dl(Op);
4218 AArch64CC::CondCode CC;
4219 // The actual operation that sets the overflow or carry flag.
4220 SDValue Value, Overflow;
4221 std::tie(Value, Overflow) = getAArch64XALUOOp(CC, Op, DAG);
4222
4223 // We use 0 and 1 as false and true values.
4224 SDValue TVal = DAG.getConstant(1, dl, MVT::i32);
4225 SDValue FVal = DAG.getConstant(0, dl, MVT::i32);
4226
4227 // We use an inverted condition, because the conditional select is inverted
4228 // too. This will allow it to be selected to a single instruction:
4229 // CSINC Wd, WZR, WZR, invert(cond).
4230 SDValue CCVal = DAG.getConstant(getInvertedCondCode(CC), dl, MVT::i32);
4231 Overflow = DAG.getNode(AArch64ISD::CSEL, dl, MVT::i32, FVal, TVal,
4232 CCVal, Overflow);
4233
4234 SDVTList VTs = DAG.getVTList(Op.getValueType(), MVT::i32);
4235 return DAG.getNode(ISD::MERGE_VALUES, dl, VTs, Value, Overflow);
4236 }
4237
4238 // Prefetch operands are:
4239 // 1: Address to prefetch
4240 // 2: bool isWrite
4241 // 3: int locality (0 = no locality ... 3 = extreme locality)
4242 // 4: bool isDataCache
4243 static SDValue LowerPREFETCH(SDValue Op, SelectionDAG &DAG) {
4244 SDLoc DL(Op);
4245 unsigned IsWrite = Op.getConstantOperandVal(2);
4246 unsigned Locality = Op.getConstantOperandVal(3);
4247 unsigned IsData = Op.getConstantOperandVal(4);
4248
4249 bool IsStream = !Locality;
4250 // When the locality number is set
4251 if (Locality) {
4252 // The front-end should have filtered out the out-of-range values
4253 assert(Locality <= 3 && "Prefetch locality out-of-range");
4254 // The locality degree is the inverse of the cache level: IR locality 3
4255 // means the fastest (L1) cache, but the PRFM encoding starts at 0 for L1,
4256 // so flip the number.
4257 Locality = 3 - Locality;
4258 }
4259
4260 // Build the mask value encoding the expected behavior.
4261 unsigned PrfOp = (IsWrite << 4) | // Load/Store bit
4262 (!IsData << 3) | // IsDataCache bit
4263 (Locality << 1) | // Cache level bits
4264 (unsigned)IsStream; // Stream bit
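// For example, a read prefetch of the data cache with maximal locality
// (IsWrite = 0, Locality = 3, IsData = 1) encodes as PrfOp = 0b00000
// (PLDL1KEEP), while a write prefetch with Locality = 0 encodes as 0b10001
// (PSTL1STRM).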
4265 return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Op.getOperand(0),
4266 DAG.getTargetConstant(PrfOp, DL, MVT::i32),
4267 Op.getOperand(1));
4268 }
4269
4270 SDValue AArch64TargetLowering::LowerFP_EXTEND(SDValue Op,
4271 SelectionDAG &DAG) const {
4272 EVT VT = Op.getValueType();
4273 if (VT.isScalableVector())
4274 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_EXTEND_MERGE_PASSTHRU);
4275
4276 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
4277 return LowerFixedLengthFPExtendToSVE(Op, DAG);
4278
4279 assert(Op.getValueType() == MVT::f128 && "Unexpected lowering");
4280 return SDValue();
4281 }
4282
4283 SDValue AArch64TargetLowering::LowerFP_ROUND(SDValue Op,
4284 SelectionDAG &DAG) const {
4285 EVT VT = Op.getValueType();
4286 if (VT.isScalableVector())
4287 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FP_ROUND_MERGE_PASSTHRU);
4288
4289 bool IsStrict = Op->isStrictFPOpcode();
4290 SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4291 EVT SrcVT = SrcVal.getValueType();
4292 bool Trunc = Op.getConstantOperandVal(IsStrict ? 2 : 1) == 1;
4293
4294 if (useSVEForFixedLengthVectorVT(SrcVT, !Subtarget->isNeonAvailable()))
4295 return LowerFixedLengthFPRoundToSVE(Op, DAG);
4296
4297 // Expand cases where the result type is BF16 but we don't have hardware
4298 // instructions to lower it.
4299 if (VT.getScalarType() == MVT::bf16 &&
4300 !((Subtarget->hasNEON() || Subtarget->hasSME()) &&
4301 Subtarget->hasBF16())) {
4302 SDLoc dl(Op);
4303 SDValue Narrow = SrcVal;
4304 SDValue NaN;
4305 EVT I32 = SrcVT.changeElementType(MVT::i32);
4306 EVT F32 = SrcVT.changeElementType(MVT::f32);
4307 if (SrcVT.getScalarType() == MVT::f32) {
4308 bool NeverSNaN = DAG.isKnownNeverSNaN(Narrow);
4309 Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
4310 if (!NeverSNaN) {
4311 // Set the quiet bit.
4312 NaN = DAG.getNode(ISD::OR, dl, I32, Narrow,
4313 DAG.getConstant(0x400000, dl, I32));
4314 }
4315 } else if (SrcVT.getScalarType() == MVT::f64) {
4316 Narrow = DAG.getNode(AArch64ISD::FCVTXN, dl, F32, Narrow);
4317 Narrow = DAG.getNode(ISD::BITCAST, dl, I32, Narrow);
4318 } else {
4319 return SDValue();
4320 }
4321 if (!Trunc) {
4322 SDValue One = DAG.getConstant(1, dl, I32);
4323 SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4324 DAG.getShiftAmountConstant(16, I32, dl));
4325 Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
4326 SDValue RoundingBias =
4327 DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
4328 Narrow = DAG.getNode(ISD::ADD, dl, I32, Narrow, RoundingBias);
4329 }
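// As a worked example of the bias: for f32 bits 0x3F818000 (exactly halfway
// between two bf16 values), Lsb = 1, so RoundingBias = 0x8000 and Narrow
// becomes 0x3F820000; the later >> 16 yields the even neighbour 0x3F82. For
// 0x3F808000 the bias is 0x7FFF and the result rounds down to 0x3F80 instead.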
4330
4331 // Don't round if we had a NaN; we don't want to turn 0x7fffffff into
4332 // 0x80000000.
4333 if (NaN) {
4334 SDValue IsNaN = DAG.getSetCC(
4335 dl, getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), SrcVT),
4336 SrcVal, SrcVal, ISD::SETUO);
4337 Narrow = DAG.getSelect(dl, I32, IsNaN, NaN, Narrow);
4338 }
4339
4340 // Now that we have rounded, shift the bits into position.
4341 Narrow = DAG.getNode(ISD::SRL, dl, I32, Narrow,
4342 DAG.getShiftAmountConstant(16, I32, dl));
4343 if (VT.isVector()) {
4344 EVT I16 = I32.changeVectorElementType(MVT::i16);
4345 Narrow = DAG.getNode(ISD::TRUNCATE, dl, I16, Narrow);
4346 return DAG.getNode(ISD::BITCAST, dl, VT, Narrow);
4347 }
4348 Narrow = DAG.getNode(ISD::BITCAST, dl, F32, Narrow);
4349 SDValue Result = DAG.getTargetExtractSubreg(AArch64::hsub, dl, VT, Narrow);
4350 return IsStrict ? DAG.getMergeValues({Result, Op.getOperand(0)}, dl)
4351 : Result;
4352 }
4353
4354 if (SrcVT != MVT::f128) {
4355 // Expand cases where the input is a vector bigger than NEON.
4356 if (useSVEForFixedLengthVectorVT(SrcVT))
4357 return SDValue();
4358
4359 // It's legal except when f128 is involved
4360 return Op;
4361 }
4362
4363 return SDValue();
4364 }
4365
4366 SDValue AArch64TargetLowering::LowerVectorFP_TO_INT(SDValue Op,
4367 SelectionDAG &DAG) const {
4368 // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
4369 // Any additional optimization in this function should be recorded
4370 // in the cost tables.
4371 bool IsStrict = Op->isStrictFPOpcode();
4372 EVT InVT = Op.getOperand(IsStrict ? 1 : 0).getValueType();
4373 EVT VT = Op.getValueType();
4374
4375 if (VT.isScalableVector()) {
4376 unsigned Opcode = Op.getOpcode() == ISD::FP_TO_UINT
4377 ? AArch64ISD::FCVTZU_MERGE_PASSTHRU
4378 : AArch64ISD::FCVTZS_MERGE_PASSTHRU;
4379 return LowerToPredicatedOp(Op, DAG, Opcode);
4380 }
4381
4382 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
4383 useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
4384 return LowerFixedLengthFPToIntToSVE(Op, DAG);
4385
4386 unsigned NumElts = InVT.getVectorNumElements();
4387
4388 // f16 conversions are promoted to f32 when full fp16 is not supported.
4389 if ((InVT.getVectorElementType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4390 InVT.getVectorElementType() == MVT::bf16) {
4391 MVT NewVT = MVT::getVectorVT(MVT::f32, NumElts);
4392 SDLoc dl(Op);
4393 if (IsStrict) {
4394 SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {NewVT, MVT::Other},
4395 {Op.getOperand(0), Op.getOperand(1)});
4396 return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4397 {Ext.getValue(1), Ext.getValue(0)});
4398 }
4399 return DAG.getNode(
4400 Op.getOpcode(), dl, Op.getValueType(),
4401 DAG.getNode(ISD::FP_EXTEND, dl, NewVT, Op.getOperand(0)));
4402 }
4403
4404 uint64_t VTSize = VT.getFixedSizeInBits();
4405 uint64_t InVTSize = InVT.getFixedSizeInBits();
4406 if (VTSize < InVTSize) {
4407 SDLoc dl(Op);
4408 if (IsStrict) {
4409 InVT = InVT.changeVectorElementTypeToInteger();
4410 SDValue Cv = DAG.getNode(Op.getOpcode(), dl, {InVT, MVT::Other},
4411 {Op.getOperand(0), Op.getOperand(1)});
4412 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
4413 return DAG.getMergeValues({Trunc, Cv.getValue(1)}, dl);
4414 }
4415 SDValue Cv =
4416 DAG.getNode(Op.getOpcode(), dl, InVT.changeVectorElementTypeToInteger(),
4417 Op.getOperand(0));
4418 return DAG.getNode(ISD::TRUNCATE, dl, VT, Cv);
4419 }
4420
4421 if (VTSize > InVTSize) {
4422 SDLoc dl(Op);
4423 MVT ExtVT =
4424 MVT::getVectorVT(MVT::getFloatingPointVT(VT.getScalarSizeInBits()),
4425 VT.getVectorNumElements());
4426 if (IsStrict) {
4427 SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {ExtVT, MVT::Other},
4428 {Op.getOperand(0), Op.getOperand(1)});
4429 return DAG.getNode(Op.getOpcode(), dl, {VT, MVT::Other},
4430 {Ext.getValue(1), Ext.getValue(0)});
4431 }
4432 SDValue Ext = DAG.getNode(ISD::FP_EXTEND, dl, ExtVT, Op.getOperand(0));
4433 return DAG.getNode(Op.getOpcode(), dl, VT, Ext);
4434 }
4435
4436 // Use a scalar operation for conversions between single-element vectors of
4437 // the same size.
4438 if (NumElts == 1) {
4439 SDLoc dl(Op);
4440 SDValue Extract = DAG.getNode(
4441 ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
4442 Op.getOperand(IsStrict ? 1 : 0), DAG.getConstant(0, dl, MVT::i64));
4443 EVT ScalarVT = VT.getScalarType();
4444 if (IsStrict)
4445 return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
4446 {Op.getOperand(0), Extract});
4447 return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
4448 }
4449
4450 // Type changing conversions are illegal.
4451 return Op;
4452 }
4453
4454 SDValue AArch64TargetLowering::LowerFP_TO_INT(SDValue Op,
4455 SelectionDAG &DAG) const {
4456 bool IsStrict = Op->isStrictFPOpcode();
4457 SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4458
4459 if (SrcVal.getValueType().isVector())
4460 return LowerVectorFP_TO_INT(Op, DAG);
4461
4462 // f16 conversions are promoted to f32 when full fp16 is not supported.
4463 if ((SrcVal.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
4464 SrcVal.getValueType() == MVT::bf16) {
4465 SDLoc dl(Op);
4466 if (IsStrict) {
4467 SDValue Ext =
4468 DAG.getNode(ISD::STRICT_FP_EXTEND, dl, {MVT::f32, MVT::Other},
4469 {Op.getOperand(0), SrcVal});
4470 return DAG.getNode(Op.getOpcode(), dl, {Op.getValueType(), MVT::Other},
4471 {Ext.getValue(1), Ext.getValue(0)});
4472 }
4473 return DAG.getNode(
4474 Op.getOpcode(), dl, Op.getValueType(),
4475 DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, SrcVal));
4476 }
4477
4478 if (SrcVal.getValueType() != MVT::f128) {
4479 // It's legal except when f128 is involved
4480 return Op;
4481 }
4482
4483 return SDValue();
4484 }
4485
4486 SDValue
4487 AArch64TargetLowering::LowerVectorFP_TO_INT_SAT(SDValue Op,
4488 SelectionDAG &DAG) const {
4489 // AArch64 FP-to-int conversions saturate to the destination element size, so
4490 // we can lower common saturating conversions to simple instructions.
4491 SDValue SrcVal = Op.getOperand(0);
4492 EVT SrcVT = SrcVal.getValueType();
4493 EVT DstVT = Op.getValueType();
4494 EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
4495
4496 uint64_t SrcElementWidth = SrcVT.getScalarSizeInBits();
4497 uint64_t DstElementWidth = DstVT.getScalarSizeInBits();
4498 uint64_t SatWidth = SatVT.getScalarSizeInBits();
4499 assert(SatWidth <= DstElementWidth &&
4500 "Saturation width cannot exceed result width");
4501
4502 // TODO: Consider lowering to SVE operations, as in LowerVectorFP_TO_INT.
4503 // Currently, the `llvm.fpto[su]i.sat.*` intrinsics don't accept scalable
4504 // types, so this is hard to reach.
4505 if (DstVT.isScalableVector())
4506 return SDValue();
4507
4508 EVT SrcElementVT = SrcVT.getVectorElementType();
4509
4510 // In the absence of FP16 support, promote f16 to f32 and saturate the result.
4511 if ((SrcElementVT == MVT::f16 &&
4512 (!Subtarget->hasFullFP16() || DstElementWidth > 16)) ||
4513 SrcElementVT == MVT::bf16) {
4514 MVT F32VT = MVT::getVectorVT(MVT::f32, SrcVT.getVectorNumElements());
4515 SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), F32VT, SrcVal);
4516 SrcVT = F32VT;
4517 SrcElementVT = MVT::f32;
4518 SrcElementWidth = 32;
4519 } else if (SrcElementVT != MVT::f64 && SrcElementVT != MVT::f32 &&
4520 SrcElementVT != MVT::f16 && SrcElementVT != MVT::bf16)
4521 return SDValue();
4522
4523 SDLoc DL(Op);
4524 // Expand to f64 if we are saturating to i64, to help keep the lanes the
4525 // same width and produce an fcvtzu.
4526 if (SatWidth == 64 && SrcElementWidth < 64) {
4527 MVT F64VT = MVT::getVectorVT(MVT::f64, SrcVT.getVectorNumElements());
4528 SrcVal = DAG.getNode(ISD::FP_EXTEND, DL, F64VT, SrcVal);
4529 SrcVT = F64VT;
4530 SrcElementVT = MVT::f64;
4531 SrcElementWidth = 64;
4532 }
4533 // Cases that we can emit directly.
4534 if (SrcElementWidth == DstElementWidth && SrcElementWidth == SatWidth)
4535 return DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal,
4536 DAG.getValueType(DstVT.getScalarType()));
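// For example, llvm.fptosi.sat.v4i32.v4f32 takes this path and becomes a
// single fcvtzs (v4i32 from v4f32), since the conversion instruction itself
// saturates on overflow.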
4537
4538 // Otherwise we emit a cvt that saturates to a higher BW, and saturate the
4539 // result. This is only valid if the legal cvt is larger than the saturate
4540 // width. For double, as we don't have MIN/MAX, it can be simpler to scalarize
4541 // (at least until sqxtn is selected).
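// For example, llvm.fptosi.sat.v4i16.v4f32 is emitted below as an fcvtzs to
// v4i32, clamped with smin(32767)/smax(-32768), and then truncated to v4i16
// (a sequence the backend may later match to sqxtn).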
4542 if (SrcElementWidth < SatWidth || SrcElementVT == MVT::f64)
4543 return SDValue();
4544
4545 EVT IntVT = SrcVT.changeVectorElementTypeToInteger();
4546 SDValue NativeCvt = DAG.getNode(Op.getOpcode(), DL, IntVT, SrcVal,
4547 DAG.getValueType(IntVT.getScalarType()));
4548 SDValue Sat;
4549 if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
4550 SDValue MinC = DAG.getConstant(
4551 APInt::getSignedMaxValue(SatWidth).sext(SrcElementWidth), DL, IntVT);
4552 SDValue Min = DAG.getNode(ISD::SMIN, DL, IntVT, NativeCvt, MinC);
4553 SDValue MaxC = DAG.getConstant(
4554 APInt::getSignedMinValue(SatWidth).sext(SrcElementWidth), DL, IntVT);
4555 Sat = DAG.getNode(ISD::SMAX, DL, IntVT, Min, MaxC);
4556 } else {
4557 SDValue MinC = DAG.getConstant(
4558 APInt::getAllOnes(SatWidth).zext(SrcElementWidth), DL, IntVT);
4559 Sat = DAG.getNode(ISD::UMIN, DL, IntVT, NativeCvt, MinC);
4560 }
4561
4562 return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4563 }
4564
4565 SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
4566 SelectionDAG &DAG) const {
4567 // AArch64 FP-to-int conversions saturate to the destination register size, so
4568 // we can lower common saturating conversions to simple instructions.
4569 SDValue SrcVal = Op.getOperand(0);
4570 EVT SrcVT = SrcVal.getValueType();
4571
4572 if (SrcVT.isVector())
4573 return LowerVectorFP_TO_INT_SAT(Op, DAG);
4574
4575 EVT DstVT = Op.getValueType();
4576 EVT SatVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
4577 uint64_t SatWidth = SatVT.getScalarSizeInBits();
4578 uint64_t DstWidth = DstVT.getScalarSizeInBits();
4579 assert(SatWidth <= DstWidth && "Saturation width cannot exceed result width");
4580
4581 // In the absence of FP16 support, promote f16 to f32 and saturate the result.
4582 if ((SrcVT == MVT::f16 && !Subtarget->hasFullFP16()) || SrcVT == MVT::bf16) {
4583 SrcVal = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, SrcVal);
4584 SrcVT = MVT::f32;
4585 } else if (SrcVT != MVT::f64 && SrcVT != MVT::f32 && SrcVT != MVT::f16 &&
4586 SrcVT != MVT::bf16)
4587 return SDValue();
4588
4589 SDLoc DL(Op);
4590 // Cases that we can emit directly.
4591 if ((SrcVT == MVT::f64 || SrcVT == MVT::f32 ||
4592 (SrcVT == MVT::f16 && Subtarget->hasFullFP16())) &&
4593 DstVT == SatVT && (DstVT == MVT::i64 || DstVT == MVT::i32))
4594 return DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal,
4595 DAG.getValueType(DstVT));
4596
4597 // Otherwise we emit a cvt that saturates to a higher BW, and saturate the
4598 // result. This is only valid if the legal cvt is larger than the saturate
4599 // width.
4600 if (DstWidth < SatWidth)
4601 return SDValue();
4602
4603 SDValue NativeCvt =
4604 DAG.getNode(Op.getOpcode(), DL, DstVT, SrcVal, DAG.getValueType(DstVT));
4605 SDValue Sat;
4606 if (Op.getOpcode() == ISD::FP_TO_SINT_SAT) {
4607 SDValue MinC = DAG.getConstant(
4608 APInt::getSignedMaxValue(SatWidth).sext(DstWidth), DL, DstVT);
4609 SDValue Min = DAG.getNode(ISD::SMIN, DL, DstVT, NativeCvt, MinC);
4610 SDValue MaxC = DAG.getConstant(
4611 APInt::getSignedMinValue(SatWidth).sext(DstWidth), DL, DstVT);
4612 Sat = DAG.getNode(ISD::SMAX, DL, DstVT, Min, MaxC);
4613 } else {
4614 SDValue MinC = DAG.getConstant(
4615 APInt::getAllOnes(SatWidth).zext(DstWidth), DL, DstVT);
4616 Sat = DAG.getNode(ISD::UMIN, DL, DstVT, NativeCvt, MinC);
4617 }
4618
4619 return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4620 }
4621
4622 SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4623 SelectionDAG &DAG) const {
4624 EVT VT = Op.getValueType();
4625 SDValue Src = Op.getOperand(0);
4626 SDLoc DL(Op);
4627
4628 assert(VT.isVector() && "Expected vector type");
4629
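// For example, an llvm.lrint.v2i64.v2f64 source would be lowered here as
// fp_to_sint_sat(frint(x), i64): round in the current FP rounding mode first,
// then perform a saturating convert to integer.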
4630 EVT CastVT =
4631 VT.changeVectorElementType(Src.getValueType().getVectorElementType());
4632
4633 // Round the floating-point value into a floating-point register with the
4634 // current rounding mode.
4635 SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4636
4637 // Truncate the rounded floating point to an integer.
4638 return DAG.getNode(ISD::FP_TO_SINT_SAT, DL, VT, FOp,
4639 DAG.getValueType(VT.getVectorElementType()));
4640 }
4641
4642 SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
4643 SelectionDAG &DAG) const {
4644 // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
4645 // Any additional optimization in this function should be recorded
4646 // in the cost tables.
4647 bool IsStrict = Op->isStrictFPOpcode();
4648 EVT VT = Op.getValueType();
4649 SDLoc dl(Op);
4650 SDValue In = Op.getOperand(IsStrict ? 1 : 0);
4651 EVT InVT = In.getValueType();
4652 unsigned Opc = Op.getOpcode();
4653 bool IsSigned = Opc == ISD::SINT_TO_FP || Opc == ISD::STRICT_SINT_TO_FP;
4654
4655 if (VT.isScalableVector()) {
4656 if (InVT.getVectorElementType() == MVT::i1) {
4657 // We can't convert an SVE predicate directly; extend it to an integer vector first.
4658 unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4659 EVT CastVT = getPromotedVTForPredicate(InVT);
4660 In = DAG.getNode(CastOpc, dl, CastVT, In);
4661 return DAG.getNode(Opc, dl, VT, In);
4662 }
4663
4664 unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
4665 : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
4666 return LowerToPredicatedOp(Op, DAG, Opcode);
4667 }
4668
4669 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()) ||
4670 useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable()))
4671 return LowerFixedLengthIntToFPToSVE(Op, DAG);
4672
4673 // Promote bf16 conversions to f32.
4674 if (VT.getVectorElementType() == MVT::bf16) {
4675 EVT F32 = VT.changeElementType(MVT::f32);
4676 if (IsStrict) {
4677 SDValue Val = DAG.getNode(Op.getOpcode(), dl, {F32, MVT::Other},
4678 {Op.getOperand(0), In});
4679 return DAG.getNode(
4680 ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
4681 {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
4682 }
4683 return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4684 DAG.getNode(Op.getOpcode(), dl, F32, In),
4685 DAG.getIntPtrConstant(0, dl));
4686 }
4687
4688 uint64_t VTSize = VT.getFixedSizeInBits();
4689 uint64_t InVTSize = InVT.getFixedSizeInBits();
4690 if (VTSize < InVTSize) {
4691 MVT CastVT =
4692 MVT::getVectorVT(MVT::getFloatingPointVT(InVT.getScalarSizeInBits()),
4693 InVT.getVectorNumElements());
4694 if (IsStrict) {
4695 In = DAG.getNode(Opc, dl, {CastVT, MVT::Other},
4696 {Op.getOperand(0), In});
4697 return DAG.getNode(
4698 ISD::STRICT_FP_ROUND, dl, {VT, MVT::Other},
4699 {In.getValue(1), In.getValue(0), DAG.getIntPtrConstant(0, dl)});
4700 }
4701 In = DAG.getNode(Opc, dl, CastVT, In);
4702 return DAG.getNode(ISD::FP_ROUND, dl, VT, In,
4703 DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
4704 }
4705
4706 if (VTSize > InVTSize) {
4707 unsigned CastOpc = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
4708 EVT CastVT = VT.changeVectorElementTypeToInteger();
4709 In = DAG.getNode(CastOpc, dl, CastVT, In);
4710 if (IsStrict)
4711 return DAG.getNode(Opc, dl, {VT, MVT::Other}, {Op.getOperand(0), In});
4712 return DAG.getNode(Opc, dl, VT, In);
4713 }
4714
4715 // Use a scalar operation for conversions between single-element vectors of
4716 // the same size.
4717 if (VT.getVectorNumElements() == 1) {
4718 SDValue Extract = DAG.getNode(
4719 ISD::EXTRACT_VECTOR_ELT, dl, InVT.getScalarType(),
4720 In, DAG.getConstant(0, dl, MVT::i64));
4721 EVT ScalarVT = VT.getScalarType();
4722 if (IsStrict)
4723 return DAG.getNode(Op.getOpcode(), dl, {ScalarVT, MVT::Other},
4724 {Op.getOperand(0), Extract});
4725 return DAG.getNode(Op.getOpcode(), dl, ScalarVT, Extract);
4726 }
4727
4728 return Op;
4729 }
4730
4731 SDValue AArch64TargetLowering::LowerINT_TO_FP(SDValue Op,
4732 SelectionDAG &DAG) const {
4733 if (Op.getValueType().isVector())
4734 return LowerVectorINT_TO_FP(Op, DAG);
4735
4736 bool IsStrict = Op->isStrictFPOpcode();
4737 SDValue SrcVal = Op.getOperand(IsStrict ? 1 : 0);
4738
4739 bool IsSigned = Op->getOpcode() == ISD::STRICT_SINT_TO_FP ||
4740 Op->getOpcode() == ISD::SINT_TO_FP;
4741
4742 auto IntToFpViaPromotion = [&](EVT PromoteVT) {
4743 SDLoc dl(Op);
4744 if (IsStrict) {
4745 SDValue Val = DAG.getNode(Op.getOpcode(), dl, {PromoteVT, MVT::Other},
4746 {Op.getOperand(0), SrcVal});
4747 return DAG.getNode(
4748 ISD::STRICT_FP_ROUND, dl, {Op.getValueType(), MVT::Other},
4749 {Val.getValue(1), Val.getValue(0), DAG.getIntPtrConstant(0, dl)});
4750 }
4751 return DAG.getNode(ISD::FP_ROUND, dl, Op.getValueType(),
4752 DAG.getNode(Op.getOpcode(), dl, PromoteVT, SrcVal),
4753 DAG.getIntPtrConstant(0, dl));
4754 };
4755
4756 if (Op.getValueType() == MVT::bf16) {
4757 unsigned MaxWidth = IsSigned
4758 ? DAG.ComputeMaxSignificantBits(SrcVal)
4759 : DAG.computeKnownBits(SrcVal).countMaxActiveBits();
4760 // bf16 conversions are promoted to f32 when converting from i16.
4761 if (MaxWidth <= 24) {
4762 return IntToFpViaPromotion(MVT::f32);
4763 }
4764
4765 // bf16 conversions are promoted to f64 when converting from i32.
4766 if (MaxWidth <= 53) {
4767 return IntToFpViaPromotion(MVT::f64);
4768 }
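// The thresholds above follow from the significand widths: f32 holds 24
// significant bits and f64 holds 53, so an integer with no more active bits
// than that converts exactly, leaving the final fptrunc to bf16 as the only
// rounding step and avoiding double rounding.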
4769
4770 // We need to be careful about i64 -> bf16.
4771 // Consider an i32 22216703.
4772 // This number cannot be represented exactly as an f32, so an itofp will
4773 // turn it into 22216704.0; an fptrunc to bf16 then turns this into 22282240.0.
4774 // However, the correct bf16 result was supposed to be 22151168.0.
4775 // We need to use sticky rounding to get this correct.
4776 if (SrcVal.getValueType() == MVT::i64) {
4777 SDLoc DL(Op);
4778 // This algorithm is equivalent to the following:
4779 // uint64_t SrcHi = SrcVal & ~0xfffull;
4780 // uint64_t SrcLo = SrcVal & 0xfffull;
4781 // uint64_t Highest = SrcVal >> 53;
4782 // bool HasHighest = Highest != 0;
4783 // uint64_t ToRound = HasHighest ? SrcHi : SrcVal;
4784 // double Rounded = static_cast<double>(ToRound);
4785 // uint64_t RoundedBits = std::bit_cast<uint64_t>(Rounded);
4786 // uint64_t HasLo = SrcLo != 0;
4787 // bool NeedsAdjustment = HasHighest & HasLo;
4788 // uint64_t AdjustedBits = RoundedBits | uint64_t{NeedsAdjustment};
4789 // double Adjusted = std::bit_cast<double>(AdjustedBits);
4790 // return static_cast<__bf16>(Adjusted);
4791 //
4792 // Essentially, what happens is that SrcVal either fits perfectly in a
4793 // double-precision value or it is too big. If it is sufficiently small,
4794 // we should just go u64 -> double -> bf16 in a naive way. Otherwise, we
4795 // ensure that u64 -> double has no rounding error by only using the 52
4796 // MSB of the input. The low order bits will get merged into a sticky bit
4797 // which will avoid issues incurred by double rounding.
4798
4799 // Signed conversion is more or less like so:
4800 // copysign((__bf16)abs(SrcVal), SrcVal)
4801 SDValue SignBit;
4802 if (IsSigned) {
4803 SignBit = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4804 DAG.getConstant(1ull << 63, DL, MVT::i64));
4805 SrcVal = DAG.getNode(ISD::ABS, DL, MVT::i64, SrcVal);
4806 }
4807 SDValue SrcHi = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4808 DAG.getConstant(~0xfffull, DL, MVT::i64));
4809 SDValue SrcLo = DAG.getNode(ISD::AND, DL, MVT::i64, SrcVal,
4810 DAG.getConstant(0xfffull, DL, MVT::i64));
4811 SDValue Highest =
4812 DAG.getNode(ISD::SRL, DL, MVT::i64, SrcVal,
4813 DAG.getShiftAmountConstant(53, MVT::i64, DL));
4814 SDValue Zero64 = DAG.getConstant(0, DL, MVT::i64);
4815 SDValue ToRound =
4816 DAG.getSelectCC(DL, Highest, Zero64, SrcHi, SrcVal, ISD::SETNE);
4817 SDValue Rounded =
4818 IsStrict ? DAG.getNode(Op.getOpcode(), DL, {MVT::f64, MVT::Other},
4819 {Op.getOperand(0), ToRound})
4820 : DAG.getNode(Op.getOpcode(), DL, MVT::f64, ToRound);
4821
4822 SDValue RoundedBits = DAG.getNode(ISD::BITCAST, DL, MVT::i64, Rounded);
4823 if (SignBit) {
4824 RoundedBits = DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, SignBit);
4825 }
4826
4827 SDValue HasHighest = DAG.getSetCC(
4828 DL,
4829 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4830 Highest, Zero64, ISD::SETNE);
4831
4832 SDValue HasLo = DAG.getSetCC(
4833 DL,
4834 getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), MVT::i64),
4835 SrcLo, Zero64, ISD::SETNE);
4836
4837 SDValue NeedsAdjustment =
4838 DAG.getNode(ISD::AND, DL, HasLo.getValueType(), HasHighest, HasLo);
4839 NeedsAdjustment = DAG.getZExtOrTrunc(NeedsAdjustment, DL, MVT::i64);
4840
4841 SDValue AdjustedBits =
4842 DAG.getNode(ISD::OR, DL, MVT::i64, RoundedBits, NeedsAdjustment);
4843 SDValue Adjusted = DAG.getNode(ISD::BITCAST, DL, MVT::f64, AdjustedBits);
4844 return IsStrict
4845 ? DAG.getNode(ISD::STRICT_FP_ROUND, DL,
4846 {Op.getValueType(), MVT::Other},
4847 {Rounded.getValue(1), Adjusted,
4848 DAG.getIntPtrConstant(0, DL)})
4849 : DAG.getNode(ISD::FP_ROUND, DL, Op.getValueType(), Adjusted,
4850 DAG.getIntPtrConstant(0, DL, true));
4851 }
4852 }
4853
4854 // f16 conversions are promoted to f32 when full fp16 is not supported.
4855 if (Op.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) {
4856 return IntToFpViaPromotion(MVT::f32);
4857 }
4858
4859 // i128 conversions are libcalls.
4860 if (SrcVal.getValueType() == MVT::i128)
4861 return SDValue();
4862
4863 // Other conversions are legal, unless it's to the completely software-based
4864 // fp128.
4865 if (Op.getValueType() != MVT::f128)
4866 return Op;
4867 return SDValue();
4868 }
4869
4870 SDValue AArch64TargetLowering::LowerFSINCOS(SDValue Op,
4871 SelectionDAG &DAG) const {
4872 // For iOS, we want to call an alternative entry point: __sincos_stret,
4873 // which returns the values in two S / D registers.
4874 SDLoc dl(Op);
4875 SDValue Arg = Op.getOperand(0);
4876 EVT ArgVT = Arg.getValueType();
4877 Type *ArgTy = ArgVT.getTypeForEVT(*DAG.getContext());
4878
4879 ArgListTy Args;
4880 ArgListEntry Entry;
4881
4882 Entry.Node = Arg;
4883 Entry.Ty = ArgTy;
4884 Entry.IsSExt = false;
4885 Entry.IsZExt = false;
4886 Args.push_back(Entry);
4887
4888 RTLIB::Libcall LC = ArgVT == MVT::f64 ? RTLIB::SINCOS_STRET_F64
4889 : RTLIB::SINCOS_STRET_F32;
4890 const char *LibcallName = getLibcallName(LC);
4891 SDValue Callee =
4892 DAG.getExternalSymbol(LibcallName, getPointerTy(DAG.getDataLayout()));
4893
4894 StructType *RetTy = StructType::get(ArgTy, ArgTy);
4895 TargetLowering::CallLoweringInfo CLI(DAG);
4896 CLI.setDebugLoc(dl)
4897 .setChain(DAG.getEntryNode())
4898 .setLibCallee(CallingConv::Fast, RetTy, Callee, std::move(Args));
4899
4900 std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
4901 return CallResult.first;
4902 }
4903
4904 static MVT getSVEContainerType(EVT ContentTy);
4905
4906 SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op,
4907 SelectionDAG &DAG) const {
4908 EVT OpVT = Op.getValueType();
4909 EVT ArgVT = Op.getOperand(0).getValueType();
4910
4911 if (useSVEForFixedLengthVectorVT(OpVT))
4912 return LowerFixedLengthBitcastToSVE(Op, DAG);
4913
4914 if (OpVT.isScalableVector()) {
4915 // Bitcasting between unpacked vector types of different element counts is
4916 // not a NOP because the live elements are laid out differently.
4917 // 01234567
4918 // e.g. nxv2i32 = XX??XX??
4919 // nxv4f16 = X?X?X?X?
4920 if (OpVT.getVectorElementCount() != ArgVT.getVectorElementCount())
4921 return SDValue();
4922
4923 if (isTypeLegal(OpVT) && !isTypeLegal(ArgVT)) {
4924 assert(OpVT.isFloatingPoint() && !ArgVT.isFloatingPoint() &&
4925 "Expected int->fp bitcast!");
4926 SDValue ExtResult =
4927 DAG.getNode(ISD::ANY_EXTEND, SDLoc(Op), getSVEContainerType(ArgVT),
4928 Op.getOperand(0));
4929 return getSVESafeBitCast(OpVT, ExtResult, DAG);
4930 }
4931 return getSVESafeBitCast(OpVT, Op.getOperand(0), DAG);
4932 }
4933
4934 if (OpVT != MVT::f16 && OpVT != MVT::bf16)
4935 return SDValue();
4936
4937 // Bitcasts between f16 and bf16 are legal.
4938 if (ArgVT == MVT::f16 || ArgVT == MVT::bf16)
4939 return Op;
4940
4941 assert(ArgVT == MVT::i16);
4942 SDLoc DL(Op);
4943
4944 Op = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op.getOperand(0));
4945 Op = DAG.getNode(ISD::BITCAST, DL, MVT::f32, Op);
4946 return DAG.getTargetExtractSubreg(AArch64::hsub, DL, OpVT, Op);
4947 }
4948
4949 static EVT getExtensionTo64Bits(const EVT &OrigVT) {
4950 if (OrigVT.getSizeInBits() >= 64)
4951 return OrigVT;
4952
4953 assert(OrigVT.isSimple() && "Expecting a simple value type");
4954
4955 MVT::SimpleValueType OrigSimpleTy = OrigVT.getSimpleVT().SimpleTy;
4956 switch (OrigSimpleTy) {
4957 default: llvm_unreachable("Unexpected Vector Type");
4958 case MVT::v2i8:
4959 case MVT::v2i16:
4960 return MVT::v2i32;
4961 case MVT::v4i8:
4962 return MVT::v4i16;
4963 }
4964 }
4965
4966 static SDValue addRequiredExtensionForVectorMULL(SDValue N, SelectionDAG &DAG,
4967 const EVT &OrigTy,
4968 const EVT &ExtTy,
4969 unsigned ExtOpcode) {
4970 // The vector originally had a size of OrigTy. It was then extended to ExtTy.
4971 // We expect the ExtTy to be 128-bits total. If the OrigTy is less than
4972 // 64-bits we need to insert a new extension so that it will be 64-bits.
4973 assert(ExtTy.is128BitVector() && "Unexpected extension size");
4974 if (OrigTy.getSizeInBits() >= 64)
4975 return N;
4976
4977 // Must extend size to at least 64 bits to be used as an operand for VMULL.
4978 EVT NewVT = getExtensionTo64Bits(OrigTy);
4979
4980 return DAG.getNode(ExtOpcode, SDLoc(N), NewVT, N);
4981 }
4982
4983 // Returns lane if Op extracts from a two-element vector and lane is constant
4984 // (i.e., extractelt(<2 x Ty> %v, ConstantLane)), and std::nullopt otherwise.
4985 static std::optional<uint64_t>
4986 getConstantLaneNumOfExtractHalfOperand(SDValue &Op) {
4987 SDNode *OpNode = Op.getNode();
4988 if (OpNode->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
4989 return std::nullopt;
4990
4991 EVT VT = OpNode->getOperand(0).getValueType();
4992 ConstantSDNode *C = dyn_cast<ConstantSDNode>(OpNode->getOperand(1));
4993 if (!VT.isFixedLengthVector() || VT.getVectorNumElements() != 2 || !C)
4994 return std::nullopt;
4995
4996 return C->getZExtValue();
4997 }
4998
4999 static bool isExtendedBUILD_VECTOR(SDValue N, SelectionDAG &DAG,
5000 bool isSigned) {
5001 EVT VT = N.getValueType();
5002
5003 if (N.getOpcode() != ISD::BUILD_VECTOR)
5004 return false;
5005
5006 for (const SDValue &Elt : N->op_values()) {
5007 if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(Elt)) {
5008 unsigned EltSize = VT.getScalarSizeInBits();
5009 unsigned HalfSize = EltSize / 2;
5010 if (isSigned) {
5011 if (!isIntN(HalfSize, C->getSExtValue()))
5012 return false;
5013 } else {
5014 if (!isUIntN(HalfSize, C->getZExtValue()))
5015 return false;
5016 }
5017 continue;
5018 }
5019 return false;
5020 }
5021
5022 return true;
5023 }
5024
5025 static SDValue skipExtensionForVectorMULL(SDValue N, SelectionDAG &DAG) {
5026 EVT VT = N.getValueType();
5027 assert(VT.is128BitVector() && "Unexpected vector MULL size");
5028
5029 unsigned NumElts = VT.getVectorNumElements();
5030 unsigned OrigEltSize = VT.getScalarSizeInBits();
5031 unsigned EltSize = OrigEltSize / 2;
5032 MVT TruncVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
5033
5034 APInt HiBits = APInt::getHighBitsSet(OrigEltSize, EltSize);
5035 if (DAG.MaskedValueIsZero(N, HiBits))
5036 return DAG.getNode(ISD::TRUNCATE, SDLoc(N), TruncVT, N);
5037
5038 if (ISD::isExtOpcode(N.getOpcode()))
5039 return addRequiredExtensionForVectorMULL(N.getOperand(0), DAG,
5040 N.getOperand(0).getValueType(), VT,
5041 N.getOpcode());
5042
5043 assert(N.getOpcode() == ISD::BUILD_VECTOR && "expected BUILD_VECTOR");
5044 SDLoc dl(N);
5045 SmallVector<SDValue, 8> Ops;
5046 for (unsigned i = 0; i != NumElts; ++i) {
5047 const APInt &CInt = N.getConstantOperandAPInt(i);
5048 // Element types smaller than 32 bits are not legal, so use i32 elements.
5049 // The values are implicitly truncated so sext vs. zext doesn't matter.
5050 Ops.push_back(DAG.getConstant(CInt.zextOrTrunc(32), dl, MVT::i32));
5051 }
5052 return DAG.getBuildVector(TruncVT, dl, Ops);
5053 }
5054
5055 static bool isSignExtended(SDValue N, SelectionDAG &DAG) {
5056 return N.getOpcode() == ISD::SIGN_EXTEND ||
5057 N.getOpcode() == ISD::ANY_EXTEND ||
5058 isExtendedBUILD_VECTOR(N, DAG, true);
5059 }
5060
5061 static bool isZeroExtended(SDValue N, SelectionDAG &DAG) {
5062 return N.getOpcode() == ISD::ZERO_EXTEND ||
5063 N.getOpcode() == ISD::ANY_EXTEND ||
5064 isExtendedBUILD_VECTOR(N, DAG, false);
5065 }
5066
5067 static bool isAddSubSExt(SDValue N, SelectionDAG &DAG) {
5068 unsigned Opcode = N.getOpcode();
5069 if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
5070 SDValue N0 = N.getOperand(0);
5071 SDValue N1 = N.getOperand(1);
5072 return N0->hasOneUse() && N1->hasOneUse() &&
5073 isSignExtended(N0, DAG) && isSignExtended(N1, DAG);
5074 }
5075 return false;
5076 }
5077
5078 static bool isAddSubZExt(SDValue N, SelectionDAG &DAG) {
5079 unsigned Opcode = N.getOpcode();
5080 if (Opcode == ISD::ADD || Opcode == ISD::SUB) {
5081 SDValue N0 = N.getOperand(0);
5082 SDValue N1 = N.getOperand(1);
5083 return N0->hasOneUse() && N1->hasOneUse() &&
5084 isZeroExtended(N0, DAG) && isZeroExtended(N1, DAG);
5085 }
5086 return false;
5087 }
5088
5089 SDValue AArch64TargetLowering::LowerGET_ROUNDING(SDValue Op,
5090 SelectionDAG &DAG) const {
5091 // The rounding mode is in bits 23:22 of the FPCR.
5092 // The hardware rounding mode value to FLT_ROUNDS mapping is 0->1, 1->2, 2->3, 3->0.
5093 // The formula we use to implement this is ((FPCR + (1 << 22)) >> 22) & 3,
5094 // written so that the shift + and get folded into a bitfield extract.
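// For example, FPCR.RMode == 0b10 (round toward minus infinity) gives
// ((2 + 1) & 3) == 3, the FLT_ROUNDS "toward negative infinity" value, while
// 0b11 (round toward zero) wraps around to 0.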
5095 SDLoc dl(Op);
5096
5097 SDValue Chain = Op.getOperand(0);
5098 SDValue FPCR_64 = DAG.getNode(
5099 ISD::INTRINSIC_W_CHAIN, dl, {MVT::i64, MVT::Other},
5100 {Chain, DAG.getConstant(Intrinsic::aarch64_get_fpcr, dl, MVT::i64)});
5101 Chain = FPCR_64.getValue(1);
5102 SDValue FPCR_32 = DAG.getNode(ISD::TRUNCATE, dl, MVT::i32, FPCR_64);
5103 SDValue FltRounds = DAG.getNode(ISD::ADD, dl, MVT::i32, FPCR_32,
5104 DAG.getConstant(1U << 22, dl, MVT::i32));
5105 SDValue RMODE = DAG.getNode(ISD::SRL, dl, MVT::i32, FltRounds,
5106 DAG.getConstant(22, dl, MVT::i32));
5107 SDValue AND = DAG.getNode(ISD::AND, dl, MVT::i32, RMODE,
5108 DAG.getConstant(3, dl, MVT::i32));
5109 return DAG.getMergeValues({AND, Chain}, dl);
5110 }
5111
5112 SDValue AArch64TargetLowering::LowerSET_ROUNDING(SDValue Op,
5113 SelectionDAG &DAG) const {
5114 SDLoc DL(Op);
5115 SDValue Chain = Op->getOperand(0);
5116 SDValue RMValue = Op->getOperand(1);
5117
5118 // The rounding mode is in bits 23:22 of the FPCR.
5119 // The llvm.set.rounding argument value to the rounding mode in FPCR mapping
5120 // is 0->3, 1->0, 2->1, 3->2. The formula we use to implement this is
5121 // (((arg - 1) & 3) << 22).
5122 //
5123 // The argument of llvm.set.rounding must be within the segment [0, 3], so
5124 // NearestTiesToAway (4) is not handled here. It is the responsibility of the
5125 // code that generates llvm.set.rounding to ensure this condition.
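// For example, llvm.set.rounding(2) (round toward +infinity) computes
// ((2 - 1) & 3) == 1, the FPCR RP encoding, which is then shifted into
// bits 23:22 below.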
5126
5127 // Calculate new value of FPCR[23:22].
5128 RMValue = DAG.getNode(ISD::SUB, DL, MVT::i32, RMValue,
5129 DAG.getConstant(1, DL, MVT::i32));
5130 RMValue = DAG.getNode(ISD::AND, DL, MVT::i32, RMValue,
5131 DAG.getConstant(0x3, DL, MVT::i32));
5132 RMValue =
5133 DAG.getNode(ISD::SHL, DL, MVT::i32, RMValue,
5134 DAG.getConstant(AArch64::RoundingBitsPos, DL, MVT::i32));
5135 RMValue = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, RMValue);
5136
5137 // Get current value of FPCR.
5138 SDValue Ops[] = {
5139 Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5140 SDValue FPCR =
5141 DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5142 Chain = FPCR.getValue(1);
5143 FPCR = FPCR.getValue(0);
5144
5145 // Put the new rounding mode into FPCR[23:22].
5146 const int RMMask = ~(AArch64::Rounding::rmMask << AArch64::RoundingBitsPos);
5147 FPCR = DAG.getNode(ISD::AND, DL, MVT::i64, FPCR,
5148 DAG.getConstant(RMMask, DL, MVT::i64));
5149 FPCR = DAG.getNode(ISD::OR, DL, MVT::i64, FPCR, RMValue);
5150 SDValue Ops2[] = {
5151 Chain, DAG.getTargetConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
5152 FPCR};
5153 return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5154 }
5155
5156 SDValue AArch64TargetLowering::LowerGET_FPMODE(SDValue Op,
5157 SelectionDAG &DAG) const {
5158 SDLoc DL(Op);
5159 SDValue Chain = Op->getOperand(0);
5160
5161 // Get current value of FPCR.
5162 SDValue Ops[] = {
5163 Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5164 SDValue FPCR =
5165 DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5166 Chain = FPCR.getValue(1);
5167 FPCR = FPCR.getValue(0);
5168
5169 // Truncate FPCR to 32 bits.
5170 SDValue Result = DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, FPCR);
5171
5172 return DAG.getMergeValues({Result, Chain}, DL);
5173 }
5174
5175 SDValue AArch64TargetLowering::LowerSET_FPMODE(SDValue Op,
5176 SelectionDAG &DAG) const {
5177 SDLoc DL(Op);
5178 SDValue Chain = Op->getOperand(0);
5179 SDValue Mode = Op->getOperand(1);
5180
5181 // Extend the specified value to 64 bits.
5182 SDValue FPCR = DAG.getZExtOrTrunc(Mode, DL, MVT::i64);
5183
5184 // Set new value of FPCR.
5185 SDValue Ops2[] = {
5186 Chain, DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64), FPCR};
5187 return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5188 }
5189
5190 SDValue AArch64TargetLowering::LowerRESET_FPMODE(SDValue Op,
5191 SelectionDAG &DAG) const {
5192 SDLoc DL(Op);
5193 SDValue Chain = Op->getOperand(0);
5194
5195 // Get current value of FPCR.
5196 SDValue Ops[] = {
5197 Chain, DAG.getTargetConstant(Intrinsic::aarch64_get_fpcr, DL, MVT::i64)};
5198 SDValue FPCR =
5199 DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL, {MVT::i64, MVT::Other}, Ops);
5200 Chain = FPCR.getValue(1);
5201 FPCR = FPCR.getValue(0);
5202
5203 // Clear bits that are not reserved.
5204 SDValue FPSCRMasked = DAG.getNode(
5205 ISD::AND, DL, MVT::i64, FPCR,
5206 DAG.getConstant(AArch64::ReservedFPControlBits, DL, MVT::i64));
5207
5208 // Set new value of FPCR.
5209 SDValue Ops2[] = {Chain,
5210 DAG.getConstant(Intrinsic::aarch64_set_fpcr, DL, MVT::i64),
5211 FPSCRMasked};
5212 return DAG.getNode(ISD::INTRINSIC_VOID, DL, MVT::Other, Ops2);
5213 }
5214
5215 static unsigned selectUmullSmull(SDValue &N0, SDValue &N1, SelectionDAG &DAG,
5216 SDLoc DL, bool &IsMLA) {
5217 bool IsN0SExt = isSignExtended(N0, DAG);
5218 bool IsN1SExt = isSignExtended(N1, DAG);
5219 if (IsN0SExt && IsN1SExt)
5220 return AArch64ISD::SMULL;
5221
5222 bool IsN0ZExt = isZeroExtended(N0, DAG);
5223 bool IsN1ZExt = isZeroExtended(N1, DAG);
5224
5225 if (IsN0ZExt && IsN1ZExt)
5226 return AArch64ISD::UMULL;
5227
5228 // Select SMULL if we can replace zext with sext.
5229 if (((IsN0SExt && IsN1ZExt) || (IsN0ZExt && IsN1SExt)) &&
5230 !isExtendedBUILD_VECTOR(N0, DAG, false) &&
5231 !isExtendedBUILD_VECTOR(N1, DAG, false)) {
5232 SDValue ZextOperand;
5233 if (IsN0ZExt)
5234 ZextOperand = N0.getOperand(0);
5235 else
5236 ZextOperand = N1.getOperand(0);
5237 if (DAG.SignBitIsZero(ZextOperand)) {
5238 SDValue NewSext =
5239 DAG.getSExtOrTrunc(ZextOperand, DL, N0.getValueType());
5240 if (IsN0ZExt)
5241 N0 = NewSext;
5242 else
5243 N1 = NewSext;
5244 return AArch64ISD::SMULL;
5245 }
5246 }
5247
5248 // Select UMULL if we can replace the other operand with an extend.
5249 if (IsN0ZExt || IsN1ZExt) {
5250 EVT VT = N0.getValueType();
5251 APInt Mask = APInt::getHighBitsSet(VT.getScalarSizeInBits(),
5252 VT.getScalarSizeInBits() / 2);
5253 if (DAG.MaskedValueIsZero(IsN0ZExt ? N1 : N0, Mask))
5254 return AArch64ISD::UMULL;
5255 }
5256
5257 if (!IsN1SExt && !IsN1ZExt)
5258 return 0;
5259
5260 // Look for (s/zext A + s/zext B) * (s/zext C). We want to turn these
5261 // into (s/zext A * s/zext C) + (s/zext B * s/zext C)
5262 if (IsN1SExt && isAddSubSExt(N0, DAG)) {
5263 IsMLA = true;
5264 return AArch64ISD::SMULL;
5265 }
5266 if (IsN1ZExt && isAddSubZExt(N0, DAG)) {
5267 IsMLA = true;
5268 return AArch64ISD::UMULL;
5269 }
5270 if (IsN0ZExt && isAddSubZExt(N1, DAG)) {
5271 std::swap(N0, N1);
5272 IsMLA = true;
5273 return AArch64ISD::UMULL;
5274 }
5275 return 0;
5276 }
5277
5278 SDValue AArch64TargetLowering::LowerMUL(SDValue Op, SelectionDAG &DAG) const {
5279 EVT VT = Op.getValueType();
5280
5281 bool OverrideNEON = !Subtarget->isNeonAvailable();
5282 if (VT.isScalableVector() || useSVEForFixedLengthVectorVT(VT, OverrideNEON))
5283 return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5284
5285 // Multiplications are only custom-lowered for 128-bit and 64-bit vectors so
5286 // that VMULL can be detected. Otherwise v2i64 multiplications are not legal.
5287 assert((VT.is128BitVector() || VT.is64BitVector()) && VT.isInteger() &&
5288 "unexpected type for custom-lowering ISD::MUL");
5289 SDValue N0 = Op.getOperand(0);
5290 SDValue N1 = Op.getOperand(1);
5291 bool isMLA = false;
5292 EVT OVT = VT;
5293 if (VT.is64BitVector()) {
5294 if (N0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5295 isNullConstant(N0.getOperand(1)) &&
5296 N1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
5297 isNullConstant(N1.getOperand(1))) {
5298 N0 = N0.getOperand(0);
5299 N1 = N1.getOperand(0);
5300 VT = N0.getValueType();
5301 } else {
5302 if (VT == MVT::v1i64) {
5303 if (Subtarget->hasSVE())
5304 return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5305 // Fall through to expand this. It is not legal.
5306 return SDValue();
5307 } else
5308 // Other vector multiplications are legal.
5309 return Op;
5310 }
5311 }
5312
5313 SDLoc DL(Op);
5314 unsigned NewOpc = selectUmullSmull(N0, N1, DAG, DL, isMLA);
5315
5316 if (!NewOpc) {
5317 if (VT.getVectorElementType() == MVT::i64) {
5318 // If SVE is available then i64 vector multiplications can also be made
5319 // legal.
5320 if (Subtarget->hasSVE())
5321 return LowerToPredicatedOp(Op, DAG, AArch64ISD::MUL_PRED);
5322 // Fall through to expand this. It is not legal.
5323 return SDValue();
5324 } else
5325 // Other vector multiplications are legal.
5326 return Op;
5327 }
5328
5329 // Legalize to an S/UMULL instruction
5330 SDValue Op0;
5331 SDValue Op1 = skipExtensionForVectorMULL(N1, DAG);
5332 if (!isMLA) {
5333 Op0 = skipExtensionForVectorMULL(N0, DAG);
5334 assert(Op0.getValueType().is64BitVector() &&
5335 Op1.getValueType().is64BitVector() &&
5336 "unexpected types for extended operands to VMULL");
5337 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, OVT,
5338 DAG.getNode(NewOpc, DL, VT, Op0, Op1),
5339 DAG.getConstant(0, DL, MVT::i64));
5340 }
5341 // Optimizing (zext A + zext B) * C, to (S/UMULL A, C) + (S/UMULL B, C) during
5342 // isel lowering to take advantage of no-stall back to back s/umul + s/umla.
5343 // This is true for CPUs with accumulate forwarding such as Cortex-A53/A57
5344 SDValue N00 = skipExtensionForVectorMULL(N0.getOperand(0), DAG);
5345 SDValue N01 = skipExtensionForVectorMULL(N0.getOperand(1), DAG);
5346 EVT Op1VT = Op1.getValueType();
5347 return DAG.getNode(
5348 ISD::EXTRACT_SUBVECTOR, DL, OVT,
5349 DAG.getNode(N0.getOpcode(), DL, VT,
5350 DAG.getNode(NewOpc, DL, VT,
5351 DAG.getNode(ISD::BITCAST, DL, Op1VT, N00), Op1),
5352 DAG.getNode(NewOpc, DL, VT,
5353 DAG.getNode(ISD::BITCAST, DL, Op1VT, N01), Op1)),
5354 DAG.getConstant(0, DL, MVT::i64));
5355 }
5356
5357 static inline SDValue getPTrue(SelectionDAG &DAG, SDLoc DL, EVT VT,
5358 int Pattern) {
5359 if (VT == MVT::nxv1i1 && Pattern == AArch64SVEPredPattern::all)
5360 return DAG.getConstant(1, DL, MVT::nxv1i1);
5361 return DAG.getNode(AArch64ISD::PTRUE, DL, VT,
5362 DAG.getTargetConstant(Pattern, DL, MVT::i32));
5363 }
5364
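// Attempt to fold an incrementing-while intrinsic (e.g. whilelo) with constant
// operands into a PTRUE with a fixed predicate pattern, when the number of
// active lanes is known and fits within the minimum SVE vector length.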
5365 static SDValue optimizeIncrementingWhile(SDValue Op, SelectionDAG &DAG,
5366 bool IsSigned, bool IsEqual) {
5367 if (!isa<ConstantSDNode>(Op.getOperand(1)) ||
5368 !isa<ConstantSDNode>(Op.getOperand(2)))
5369 return SDValue();
5370
5371 SDLoc dl(Op);
5372 APInt X = Op.getConstantOperandAPInt(1);
5373 APInt Y = Op.getConstantOperandAPInt(2);
5374 bool Overflow;
5375 APInt NumActiveElems =
5376 IsSigned ? Y.ssub_ov(X, Overflow) : Y.usub_ov(X, Overflow);
5377
5378 if (Overflow)
5379 return SDValue();
5380
5381 if (IsEqual) {
5382 APInt One(NumActiveElems.getBitWidth(), 1, IsSigned);
5383 NumActiveElems = IsSigned ? NumActiveElems.sadd_ov(One, Overflow)
5384 : NumActiveElems.uadd_ov(One, Overflow);
5385 if (Overflow)
5386 return SDValue();
5387 }
5388
5389 std::optional<unsigned> PredPattern =
5390 getSVEPredPatternFromNumElements(NumActiveElems.getZExtValue());
5391 unsigned MinSVEVectorSize = std::max(
5392 DAG.getSubtarget<AArch64Subtarget>().getMinSVEVectorSizeInBits(), 128u);
5393 unsigned ElementSize = 128 / Op.getValueType().getVectorMinNumElements();
5394 if (PredPattern != std::nullopt &&
5395 NumActiveElems.getZExtValue() <= (MinSVEVectorSize / ElementSize))
5396 return getPTrue(DAG, dl, Op.getValueType(), *PredPattern);
5397
5398 return SDValue();
5399 }
5400
5401 // Returns a safe bitcast between two scalable vector predicates, where
5402 // any newly created lanes from a widening bitcast are defined as zero.
5403 static SDValue getSVEPredicateBitCast(EVT VT, SDValue Op, SelectionDAG &DAG) {
5404 SDLoc DL(Op);
5405 EVT InVT = Op.getValueType();
5406
5407 assert(InVT.getVectorElementType() == MVT::i1 &&
5408 VT.getVectorElementType() == MVT::i1 &&
5409 "Expected a predicate-to-predicate bitcast");
5410 assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
5411 InVT.isScalableVector() &&
5412 DAG.getTargetLoweringInfo().isTypeLegal(InVT) &&
5413 "Only expect to cast between legal scalable predicate types!");
5414
5415 // Return the operand if the cast isn't changing type,
5416 // e.g. <n x 16 x i1> -> <n x 16 x i1>
5417 if (InVT == VT)
5418 return Op;
5419
5420 SDValue Reinterpret = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
5421
5422 // We only have to zero the lanes if new lanes are being defined, e.g. when
5423 // casting from <vscale x 2 x i1> to <vscale x 16 x i1>. If this is not the
5424 // case (e.g. when casting from <vscale x 16 x i1> -> <vscale x 2 x i1>) then
5425 // we can return here.
5426 if (InVT.bitsGT(VT))
5427 return Reinterpret;
5428
5429 // Check if the other lanes are already known to be zeroed by
5430 // construction.
5431 if (isZeroingInactiveLanes(Op))
5432 return Reinterpret;
5433
5434 // Zero the newly introduced lanes.
5435 SDValue Mask = DAG.getConstant(1, DL, InVT);
5436 Mask = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Mask);
5437 return DAG.getNode(ISD::AND, DL, VT, Reinterpret, Mask);
5438 }
5439
5440 SDValue AArch64TargetLowering::getRuntimePStateSM(SelectionDAG &DAG,
5441 SDValue Chain, SDLoc DL,
5442 EVT VT) const {
5443 SDValue Callee = DAG.getExternalSymbol("__arm_sme_state",
5444 getPointerTy(DAG.getDataLayout()));
5445 Type *Int64Ty = Type::getInt64Ty(*DAG.getContext());
5446 Type *RetTy = StructType::get(Int64Ty, Int64Ty);
5447 TargetLowering::CallLoweringInfo CLI(DAG);
5448 ArgListTy Args;
5449 CLI.setDebugLoc(DL).setChain(Chain).setLibCallee(
5450 CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2,
5451 RetTy, Callee, std::move(Args));
5452 std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
5453 SDValue Mask = DAG.getConstant(/*PSTATE.SM*/ 1, DL, MVT::i64);
5454 return DAG.getNode(ISD::AND, DL, MVT::i64, CallResult.first.getOperand(0),
5455 Mask);
5456 }
5457
5458 // Lower an SME LDR/STR ZA intrinsic
5459 // Case 1: If the vector number (vecnum) is an immediate in range, it gets
5460 // folded into the instruction
5461 // ldr(%tileslice, %ptr, 11) -> ldr [%tileslice, 11], [%ptr, 11]
5462 // Case 2: If the vecnum is not an immediate, then it is used to modify the base
5463 // and tile slice registers
5464 // ldr(%tileslice, %ptr, %vecnum)
5465 // ->
5466 // %svl = rdsvl
5467 // %ptr2 = %ptr + %svl * %vecnum
5468 // %tileslice2 = %tileslice + %vecnum
5469 // ldr [%tileslice2, 0], [%ptr2, 0]
5470 // Case 3: If the vecnum is an immediate out of range, then the same is done as
5471 // case 2, but the base and slice registers are modified by the greatest
5472 // multiple of 16 not exceeding the vecnum, and the remainder is folded into the
5473 // instruction. This means that successive loads and stores that are offset from
5474 // each other can share the same base and slice register updates.
5475 // ldr(%tileslice, %ptr, 22)
5476 // ldr(%tileslice, %ptr, 23)
5477 // ->
5478 // %svl = rdsvl
5479 // %ptr2 = %ptr + %svl * 16
5480 // %tileslice2 = %tileslice + 16
5481 // ldr [%tileslice2, 6], [%ptr2, 6]
5482 // ldr [%tileslice2, 7], [%ptr2, 7]
5483 // Case 4: If the vecnum is an add of an immediate, then the non-immediate
5484 // operand and the immediate can be folded into the instruction, like case 2.
5485 // ldr(%tileslice, %ptr, %vecnum + 7)
5486 // ldr(%tileslice, %ptr, %vecnum + 8)
5487 // ->
5488 // %svl = rdsvl
5489 // %ptr2 = %ptr + %svl * %vecnum
5490 // %tileslice2 = %tileslice + %vecnum
5491 // ldr [%tileslice2, 7], [%ptr2, 7]
5492 // ldr [%tileslice2, 8], [%ptr2, 8]
5493 // Case 5: The vecnum being an add of an immediate out of range is also handled,
5494 // in which case the same remainder logic as case 3 is used.
5495 SDValue LowerSMELdrStr(SDValue N, SelectionDAG &DAG, bool IsLoad) {
5496 SDLoc DL(N);
5497
5498 SDValue TileSlice = N->getOperand(2);
5499 SDValue Base = N->getOperand(3);
5500 SDValue VecNum = N->getOperand(4);
5501 int32_t ConstAddend = 0;
5502 SDValue VarAddend = VecNum;
5503
5504 // If the vnum is an add of an immediate, we can fold it into the instruction
5505 if (VecNum.getOpcode() == ISD::ADD &&
5506 isa<ConstantSDNode>(VecNum.getOperand(1))) {
5507 ConstAddend = cast<ConstantSDNode>(VecNum.getOperand(1))->getSExtValue();
5508 VarAddend = VecNum.getOperand(0);
5509 } else if (auto ImmNode = dyn_cast<ConstantSDNode>(VecNum)) {
5510 ConstAddend = ImmNode->getSExtValue();
5511 VarAddend = SDValue();
5512 }
5513
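// Split the constant addend into a 0-15 part that can be folded into the
// instruction's immediate and a multiple-of-16 remainder that is instead added
// to the base and tile slice registers; e.g. a constant vnum of 22 splits into
// a register addend of 16 and a folded immediate of 6.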
5514 int32_t ImmAddend = ConstAddend % 16;
5515 if (int32_t C = (ConstAddend - ImmAddend)) {
5516 SDValue CVal = DAG.getTargetConstant(C, DL, MVT::i32);
5517 VarAddend = VarAddend
5518 ? DAG.getNode(ISD::ADD, DL, MVT::i32, {VarAddend, CVal})
5519 : CVal;
5520 }
5521
5522 if (VarAddend) {
5523 // Get the vector length that will be multiplied by vnum
5524 auto SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
5525 DAG.getConstant(1, DL, MVT::i32));
5526
5527 // Multiply SVL and vnum then add it to the base
5528 SDValue Mul = DAG.getNode(
5529 ISD::MUL, DL, MVT::i64,
5530 {SVL, DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, VarAddend)});
5531 Base = DAG.getNode(ISD::ADD, DL, MVT::i64, {Base, Mul});
5532 // Just add vnum to the tileslice
5533 TileSlice = DAG.getNode(ISD::ADD, DL, MVT::i32, {TileSlice, VarAddend});
5534 }
5535
5536 return DAG.getNode(IsLoad ? AArch64ISD::SME_ZA_LDR : AArch64ISD::SME_ZA_STR,
5537 DL, MVT::Other,
5538 {/*Chain=*/N.getOperand(0), TileSlice, Base,
5539 DAG.getTargetConstant(ImmAddend, DL, MVT::i32)});
5540 }
5541
5542 SDValue AArch64TargetLowering::LowerINTRINSIC_VOID(SDValue Op,
5543 SelectionDAG &DAG) const {
5544 unsigned IntNo = Op.getConstantOperandVal(1);
5545 SDLoc DL(Op);
5546 switch (IntNo) {
5547 default:
5548 return SDValue(); // Don't custom lower most intrinsics.
5549 case Intrinsic::aarch64_prefetch: {
5550 SDValue Chain = Op.getOperand(0);
5551 SDValue Addr = Op.getOperand(2);
5552
5553 unsigned IsWrite = Op.getConstantOperandVal(3);
5554 unsigned Locality = Op.getConstantOperandVal(4);
5555 unsigned IsStream = Op.getConstantOperandVal(5);
5556 unsigned IsData = Op.getConstantOperandVal(6);
5557 unsigned PrfOp = (IsWrite << 4) | // Load/Store bit
5558 (!IsData << 3) | // IsDataCache bit
5559 (Locality << 1) | // Cache level bits
5560 (unsigned)IsStream; // Stream bit
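// For example, a data read prefetch of the innermost cache level with the
// temporal (keep) policy (IsWrite=0, IsData=1, Locality=0, IsStream=0)
// encodes as 0b00000, which corresponds to PLDL1KEEP.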
5561
5562 return DAG.getNode(AArch64ISD::PREFETCH, DL, MVT::Other, Chain,
5563 DAG.getTargetConstant(PrfOp, DL, MVT::i32), Addr);
5564 }
5565 case Intrinsic::aarch64_sme_str:
5566 case Intrinsic::aarch64_sme_ldr: {
5567 return LowerSMELdrStr(Op, DAG, IntNo == Intrinsic::aarch64_sme_ldr);
5568 }
5569 case Intrinsic::aarch64_sme_za_enable:
5570 return DAG.getNode(
5571 AArch64ISD::SMSTART, DL, MVT::Other,
5572 Op->getOperand(0), // Chain
5573 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
5574 DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5575 case Intrinsic::aarch64_sme_za_disable:
5576 return DAG.getNode(
5577 AArch64ISD::SMSTOP, DL, MVT::Other,
5578 Op->getOperand(0), // Chain
5579 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
5580 DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
5581 }
5582 }
5583
5584 SDValue AArch64TargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
5585 SelectionDAG &DAG) const {
5586 unsigned IntNo = Op.getConstantOperandVal(1);
5587 SDLoc DL(Op);
5588 switch (IntNo) {
5589 default:
5590 return SDValue(); // Don't custom lower most intrinsics.
5591 case Intrinsic::aarch64_mops_memset_tag: {
5592 auto Node = cast<MemIntrinsicSDNode>(Op.getNode());
5593 SDValue Chain = Node->getChain();
5594 SDValue Dst = Op.getOperand(2);
5595 SDValue Val = Op.getOperand(3);
5596 Val = DAG.getAnyExtOrTrunc(Val, DL, MVT::i64);
5597 SDValue Size = Op.getOperand(4);
5598 auto Alignment = Node->getMemOperand()->getAlign();
5599 bool IsVol = Node->isVolatile();
5600 auto DstPtrInfo = Node->getPointerInfo();
5601
5602 const auto &SDI =
5603 static_cast<const AArch64SelectionDAGInfo &>(DAG.getSelectionDAGInfo());
5604 SDValue MS =
5605 SDI.EmitMOPS(AArch64ISD::MOPS_MEMSET_TAGGING, DAG, DL, Chain, Dst, Val,
5606 Size, Alignment, IsVol, DstPtrInfo, MachinePointerInfo{});
5607
5608 // MOPS_MEMSET_TAGGING has 3 results (DstWb, SizeWb, Chain) whereas the
5609 // intrinsic has 2. So hide SizeWb using MERGE_VALUES. Otherwise
5610 // LowerOperationWrapper will complain that the number of results has
5611 // changed.
5612 return DAG.getMergeValues({MS.getValue(0), MS.getValue(2)}, DL);
5613 }
5614 }
5615 }
5616
5617 SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
5618 SelectionDAG &DAG) const {
5619 unsigned IntNo = Op.getConstantOperandVal(0);
5620 SDLoc dl(Op);
5621 switch (IntNo) {
5622 default: return SDValue(); // Don't custom lower most intrinsics.
5623 case Intrinsic::thread_pointer: {
5624 EVT PtrVT = getPointerTy(DAG.getDataLayout());
5625 return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);
5626 }
5627 case Intrinsic::aarch64_neon_abs: {
5628 EVT Ty = Op.getValueType();
5629 if (Ty == MVT::i64) {
5630 SDValue Result = DAG.getNode(ISD::BITCAST, dl, MVT::v1i64,
5631 Op.getOperand(1));
5632 Result = DAG.getNode(ISD::ABS, dl, MVT::v1i64, Result);
5633 return DAG.getNode(ISD::BITCAST, dl, MVT::i64, Result);
5634 } else if (Ty.isVector() && Ty.isInteger() && isTypeLegal(Ty)) {
5635 return DAG.getNode(ISD::ABS, dl, Ty, Op.getOperand(1));
5636 } else {
5637 report_fatal_error("Unexpected type for AArch64 NEON intrinsic");
5638 }
5639 }
5640 case Intrinsic::aarch64_neon_pmull64: {
5641 SDValue LHS = Op.getOperand(1);
5642 SDValue RHS = Op.getOperand(2);
5643
5644 std::optional<uint64_t> LHSLane =
5645 getConstantLaneNumOfExtractHalfOperand(LHS);
5646 std::optional<uint64_t> RHSLane =
5647 getConstantLaneNumOfExtractHalfOperand(RHS);
5648
5649 assert((!LHSLane || *LHSLane < 2) && "Expect lane to be None or 0 or 1");
5650 assert((!RHSLane || *RHSLane < 2) && "Expect lane to be None or 0 or 1");
5651
5652 // 'aarch64_neon_pmull64' takes i64 parameters, while pmull/pmull2
5653 // instructions execute on SIMD registers. So canonicalize i64 to v1i64,
5654 // which ISel recognizes better. For example, this generates an ldr into d*
5655 // registers rather than a GPR load followed by a fmov.
5656 auto TryVectorizeOperand = [](SDValue N, std::optional<uint64_t> NLane,
5657 std::optional<uint64_t> OtherLane,
5658 const SDLoc &dl,
5659 SelectionDAG &DAG) -> SDValue {
5660 // If the operand is a higher half itself, rewrite it to
5661 // extract_high_v2i64; this way aarch64_neon_pmull64 can
5662 // reuse the dag-combiner function with aarch64_neon_{pmull,smull,umull}.
5663 if (NLane && *NLane == 1)
5664 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
5665 N.getOperand(0), DAG.getConstant(1, dl, MVT::i64));
5666
5667 // Operand N is not a higher half but the other operand is.
5668 if (OtherLane && *OtherLane == 1) {
5669 // If this operand is a lower half, rewrite it to
5670 // extract_high_v2i64(duplane(<2 x Ty>, 0)). This saves a roundtrip to
5671 // align lanes of two operands. A roundtrip sequence (to move from lane
5672 // 1 to lane 0) is like this:
5673 // mov x8, v0.d[1]
5674 // fmov d0, x8
5675 if (NLane && *NLane == 0)
5676 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, MVT::v1i64,
5677 DAG.getNode(AArch64ISD::DUPLANE64, dl, MVT::v2i64,
5678 N.getOperand(0),
5679 DAG.getConstant(0, dl, MVT::i64)),
5680 DAG.getConstant(1, dl, MVT::i64));
5681
5682 // Otherwise just dup from main to all lanes.
5683 return DAG.getNode(AArch64ISD::DUP, dl, MVT::v1i64, N);
5684 }
5685
5686 // Neither operand is an extract of the higher half, so codegen may just use
5687 // the non-high version of the PMULL instruction. Use v1i64 to represent i64.
5688 assert(N.getValueType() == MVT::i64 &&
5689 "Intrinsic aarch64_neon_pmull64 requires i64 parameters");
5690 return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, N);
5691 };
5692
5693 LHS = TryVectorizeOperand(LHS, LHSLane, RHSLane, dl, DAG);
5694 RHS = TryVectorizeOperand(RHS, RHSLane, LHSLane, dl, DAG);
5695
5696 return DAG.getNode(AArch64ISD::PMULL, dl, Op.getValueType(), LHS, RHS);
5697 }
5698 case Intrinsic::aarch64_neon_smax:
5699 return DAG.getNode(ISD::SMAX, dl, Op.getValueType(),
5700 Op.getOperand(1), Op.getOperand(2));
5701 case Intrinsic::aarch64_neon_umax:
5702 return DAG.getNode(ISD::UMAX, dl, Op.getValueType(),
5703 Op.getOperand(1), Op.getOperand(2));
5704 case Intrinsic::aarch64_neon_smin:
5705 return DAG.getNode(ISD::SMIN, dl, Op.getValueType(),
5706 Op.getOperand(1), Op.getOperand(2));
5707 case Intrinsic::aarch64_neon_umin:
5708 return DAG.getNode(ISD::UMIN, dl, Op.getValueType(),
5709 Op.getOperand(1), Op.getOperand(2));
5710 case Intrinsic::aarch64_neon_scalar_sqxtn:
5711 case Intrinsic::aarch64_neon_scalar_sqxtun:
5712 case Intrinsic::aarch64_neon_scalar_uqxtn: {
5713 assert(Op.getValueType() == MVT::i32 || Op.getValueType() == MVT::f32);
5714 if (Op.getValueType() == MVT::i32)
5715 return DAG.getNode(ISD::BITCAST, dl, MVT::i32,
5716 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::f32,
5717 Op.getOperand(0),
5718 DAG.getNode(ISD::BITCAST, dl, MVT::f64,
5719 Op.getOperand(1))));
5720 return SDValue();
5721 }
5722 case Intrinsic::aarch64_sve_whilelo:
5723 return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
5724 /*IsEqual=*/false);
5725 case Intrinsic::aarch64_sve_whilelt:
5726 return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
5727 /*IsEqual=*/false);
5728 case Intrinsic::aarch64_sve_whilels:
5729 return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/false,
5730 /*IsEqual=*/true);
5731 case Intrinsic::aarch64_sve_whilele:
5732 return optimizeIncrementingWhile(Op, DAG, /*IsSigned=*/true,
5733 /*IsEqual=*/true);
5734 case Intrinsic::aarch64_sve_sunpkhi:
5735 return DAG.getNode(AArch64ISD::SUNPKHI, dl, Op.getValueType(),
5736 Op.getOperand(1));
5737 case Intrinsic::aarch64_sve_sunpklo:
5738 return DAG.getNode(AArch64ISD::SUNPKLO, dl, Op.getValueType(),
5739 Op.getOperand(1));
5740 case Intrinsic::aarch64_sve_uunpkhi:
5741 return DAG.getNode(AArch64ISD::UUNPKHI, dl, Op.getValueType(),
5742 Op.getOperand(1));
5743 case Intrinsic::aarch64_sve_uunpklo:
5744 return DAG.getNode(AArch64ISD::UUNPKLO, dl, Op.getValueType(),
5745 Op.getOperand(1));
5746 case Intrinsic::aarch64_sve_clasta_n:
5747 return DAG.getNode(AArch64ISD::CLASTA_N, dl, Op.getValueType(),
5748 Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5749 case Intrinsic::aarch64_sve_clastb_n:
5750 return DAG.getNode(AArch64ISD::CLASTB_N, dl, Op.getValueType(),
5751 Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5752 case Intrinsic::aarch64_sve_lasta:
5753 return DAG.getNode(AArch64ISD::LASTA, dl, Op.getValueType(),
5754 Op.getOperand(1), Op.getOperand(2));
5755 case Intrinsic::aarch64_sve_lastb:
5756 return DAG.getNode(AArch64ISD::LASTB, dl, Op.getValueType(),
5757 Op.getOperand(1), Op.getOperand(2));
5758 case Intrinsic::aarch64_sve_rev:
5759 return DAG.getNode(ISD::VECTOR_REVERSE, dl, Op.getValueType(),
5760 Op.getOperand(1));
5761 case Intrinsic::aarch64_sve_tbl:
5762 return DAG.getNode(AArch64ISD::TBL, dl, Op.getValueType(),
5763 Op.getOperand(1), Op.getOperand(2));
5764 case Intrinsic::aarch64_sve_trn1:
5765 return DAG.getNode(AArch64ISD::TRN1, dl, Op.getValueType(),
5766 Op.getOperand(1), Op.getOperand(2));
5767 case Intrinsic::aarch64_sve_trn2:
5768 return DAG.getNode(AArch64ISD::TRN2, dl, Op.getValueType(),
5769 Op.getOperand(1), Op.getOperand(2));
5770 case Intrinsic::aarch64_sve_uzp1:
5771 return DAG.getNode(AArch64ISD::UZP1, dl, Op.getValueType(),
5772 Op.getOperand(1), Op.getOperand(2));
5773 case Intrinsic::aarch64_sve_uzp2:
5774 return DAG.getNode(AArch64ISD::UZP2, dl, Op.getValueType(),
5775 Op.getOperand(1), Op.getOperand(2));
5776 case Intrinsic::aarch64_sve_zip1:
5777 return DAG.getNode(AArch64ISD::ZIP1, dl, Op.getValueType(),
5778 Op.getOperand(1), Op.getOperand(2));
5779 case Intrinsic::aarch64_sve_zip2:
5780 return DAG.getNode(AArch64ISD::ZIP2, dl, Op.getValueType(),
5781 Op.getOperand(1), Op.getOperand(2));
5782 case Intrinsic::aarch64_sve_splice:
5783 return DAG.getNode(AArch64ISD::SPLICE, dl, Op.getValueType(),
5784 Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5785 case Intrinsic::aarch64_sve_ptrue:
5786 return getPTrue(DAG, dl, Op.getValueType(), Op.getConstantOperandVal(1));
5787 case Intrinsic::aarch64_sve_clz:
5788 return DAG.getNode(AArch64ISD::CTLZ_MERGE_PASSTHRU, dl, Op.getValueType(),
5789 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5790 case Intrinsic::aarch64_sme_cntsb:
5791 return DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5792 DAG.getConstant(1, dl, MVT::i32));
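// The remaining counts below are derived from the byte count returned by
// RDSVL: halfwords = bytes >> 1, words = bytes >> 2, doublewords = bytes >> 3.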
5793 case Intrinsic::aarch64_sme_cntsh: {
5794 SDValue One = DAG.getConstant(1, dl, MVT::i32);
5795 SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(), One);
5796 return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes, One);
5797 }
5798 case Intrinsic::aarch64_sme_cntsw: {
5799 SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5800 DAG.getConstant(1, dl, MVT::i32));
5801 return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes,
5802 DAG.getConstant(2, dl, MVT::i32));
5803 }
5804 case Intrinsic::aarch64_sme_cntsd: {
5805 SDValue Bytes = DAG.getNode(AArch64ISD::RDSVL, dl, Op.getValueType(),
5806 DAG.getConstant(1, dl, MVT::i32));
5807 return DAG.getNode(ISD::SRL, dl, Op.getValueType(), Bytes,
5808 DAG.getConstant(3, dl, MVT::i32));
5809 }
5810 case Intrinsic::aarch64_sve_cnt: {
5811 SDValue Data = Op.getOperand(3);
5812 // CTPOP only supports integer operands.
5813 if (Data.getValueType().isFloatingPoint())
5814 Data = DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Data);
5815 return DAG.getNode(AArch64ISD::CTPOP_MERGE_PASSTHRU, dl, Op.getValueType(),
5816 Op.getOperand(2), Data, Op.getOperand(1));
5817 }
5818 case Intrinsic::aarch64_sve_dupq_lane:
5819 return LowerDUPQLane(Op, DAG);
5820 case Intrinsic::aarch64_sve_convert_from_svbool:
5821 if (Op.getValueType() == MVT::aarch64svcount)
5822 return DAG.getNode(ISD::BITCAST, dl, Op.getValueType(), Op.getOperand(1));
5823 return getSVEPredicateBitCast(Op.getValueType(), Op.getOperand(1), DAG);
5824 case Intrinsic::aarch64_sve_convert_to_svbool:
5825 if (Op.getOperand(1).getValueType() == MVT::aarch64svcount)
5826 return DAG.getNode(ISD::BITCAST, dl, MVT::nxv16i1, Op.getOperand(1));
5827 return getSVEPredicateBitCast(MVT::nxv16i1, Op.getOperand(1), DAG);
5828 case Intrinsic::aarch64_sve_fneg:
5829 return DAG.getNode(AArch64ISD::FNEG_MERGE_PASSTHRU, dl, Op.getValueType(),
5830 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5831 case Intrinsic::aarch64_sve_frintp:
5832 return DAG.getNode(AArch64ISD::FCEIL_MERGE_PASSTHRU, dl, Op.getValueType(),
5833 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5834 case Intrinsic::aarch64_sve_frintm:
5835 return DAG.getNode(AArch64ISD::FFLOOR_MERGE_PASSTHRU, dl, Op.getValueType(),
5836 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5837 case Intrinsic::aarch64_sve_frinti:
5838 return DAG.getNode(AArch64ISD::FNEARBYINT_MERGE_PASSTHRU, dl, Op.getValueType(),
5839 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5840 case Intrinsic::aarch64_sve_frintx:
5841 return DAG.getNode(AArch64ISD::FRINT_MERGE_PASSTHRU, dl, Op.getValueType(),
5842 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5843 case Intrinsic::aarch64_sve_frinta:
5844 return DAG.getNode(AArch64ISD::FROUND_MERGE_PASSTHRU, dl, Op.getValueType(),
5845 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5846 case Intrinsic::aarch64_sve_frintn:
5847 return DAG.getNode(AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU, dl, Op.getValueType(),
5848 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5849 case Intrinsic::aarch64_sve_frintz:
5850 return DAG.getNode(AArch64ISD::FTRUNC_MERGE_PASSTHRU, dl, Op.getValueType(),
5851 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5852 case Intrinsic::aarch64_sve_ucvtf:
5853 return DAG.getNode(AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU, dl,
5854 Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5855 Op.getOperand(1));
5856 case Intrinsic::aarch64_sve_scvtf:
5857 return DAG.getNode(AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU, dl,
5858 Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5859 Op.getOperand(1));
5860 case Intrinsic::aarch64_sve_fcvtzu:
5861 return DAG.getNode(AArch64ISD::FCVTZU_MERGE_PASSTHRU, dl,
5862 Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5863 Op.getOperand(1));
5864 case Intrinsic::aarch64_sve_fcvtzs:
5865 return DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, dl,
5866 Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5867 Op.getOperand(1));
5868 case Intrinsic::aarch64_sve_fsqrt:
5869 return DAG.getNode(AArch64ISD::FSQRT_MERGE_PASSTHRU, dl, Op.getValueType(),
5870 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5871 case Intrinsic::aarch64_sve_frecpx:
5872 return DAG.getNode(AArch64ISD::FRECPX_MERGE_PASSTHRU, dl, Op.getValueType(),
5873 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5874 case Intrinsic::aarch64_sve_frecpe_x:
5875 return DAG.getNode(AArch64ISD::FRECPE, dl, Op.getValueType(),
5876 Op.getOperand(1));
5877 case Intrinsic::aarch64_sve_frecps_x:
5878 return DAG.getNode(AArch64ISD::FRECPS, dl, Op.getValueType(),
5879 Op.getOperand(1), Op.getOperand(2));
5880 case Intrinsic::aarch64_sve_frsqrte_x:
5881 return DAG.getNode(AArch64ISD::FRSQRTE, dl, Op.getValueType(),
5882 Op.getOperand(1));
5883 case Intrinsic::aarch64_sve_frsqrts_x:
5884 return DAG.getNode(AArch64ISD::FRSQRTS, dl, Op.getValueType(),
5885 Op.getOperand(1), Op.getOperand(2));
5886 case Intrinsic::aarch64_sve_fabs:
5887 return DAG.getNode(AArch64ISD::FABS_MERGE_PASSTHRU, dl, Op.getValueType(),
5888 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5889 case Intrinsic::aarch64_sve_abs:
5890 return DAG.getNode(AArch64ISD::ABS_MERGE_PASSTHRU, dl, Op.getValueType(),
5891 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5892 case Intrinsic::aarch64_sve_neg:
5893 return DAG.getNode(AArch64ISD::NEG_MERGE_PASSTHRU, dl, Op.getValueType(),
5894 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5895 case Intrinsic::aarch64_sve_insr: {
5896 SDValue Scalar = Op.getOperand(2);
5897 EVT ScalarTy = Scalar.getValueType();
5898 if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
5899 Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
5900
5901 return DAG.getNode(AArch64ISD::INSR, dl, Op.getValueType(),
5902 Op.getOperand(1), Scalar);
5903 }
5904 case Intrinsic::aarch64_sve_rbit:
5905 return DAG.getNode(AArch64ISD::BITREVERSE_MERGE_PASSTHRU, dl,
5906 Op.getValueType(), Op.getOperand(2), Op.getOperand(3),
5907 Op.getOperand(1));
5908 case Intrinsic::aarch64_sve_revb:
5909 return DAG.getNode(AArch64ISD::BSWAP_MERGE_PASSTHRU, dl, Op.getValueType(),
5910 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5911 case Intrinsic::aarch64_sve_revh:
5912 return DAG.getNode(AArch64ISD::REVH_MERGE_PASSTHRU, dl, Op.getValueType(),
5913 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5914 case Intrinsic::aarch64_sve_revw:
5915 return DAG.getNode(AArch64ISD::REVW_MERGE_PASSTHRU, dl, Op.getValueType(),
5916 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5917 case Intrinsic::aarch64_sve_revd:
5918 return DAG.getNode(AArch64ISD::REVD_MERGE_PASSTHRU, dl, Op.getValueType(),
5919 Op.getOperand(2), Op.getOperand(3), Op.getOperand(1));
5920 case Intrinsic::aarch64_sve_sxtb:
5921 return DAG.getNode(
5922 AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5923 Op.getOperand(2), Op.getOperand(3),
5924 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
5925 Op.getOperand(1));
5926 case Intrinsic::aarch64_sve_sxth:
5927 return DAG.getNode(
5928 AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5929 Op.getOperand(2), Op.getOperand(3),
5930 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
5931 Op.getOperand(1));
5932 case Intrinsic::aarch64_sve_sxtw:
5933 return DAG.getNode(
5934 AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5935 Op.getOperand(2), Op.getOperand(3),
5936 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
5937 Op.getOperand(1));
5938 case Intrinsic::aarch64_sve_uxtb:
5939 return DAG.getNode(
5940 AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5941 Op.getOperand(2), Op.getOperand(3),
5942 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i8)),
5943 Op.getOperand(1));
5944 case Intrinsic::aarch64_sve_uxth:
5945 return DAG.getNode(
5946 AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5947 Op.getOperand(2), Op.getOperand(3),
5948 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i16)),
5949 Op.getOperand(1));
5950 case Intrinsic::aarch64_sve_uxtw:
5951 return DAG.getNode(
5952 AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU, dl, Op.getValueType(),
5953 Op.getOperand(2), Op.getOperand(3),
5954 DAG.getValueType(Op.getValueType().changeVectorElementType(MVT::i32)),
5955 Op.getOperand(1));
5956 case Intrinsic::localaddress: {
5957 const auto &MF = DAG.getMachineFunction();
5958 const auto *RegInfo = Subtarget->getRegisterInfo();
5959 unsigned Reg = RegInfo->getLocalAddressRegister(MF);
5960 return DAG.getCopyFromReg(DAG.getEntryNode(), dl, Reg,
5961 Op.getSimpleValueType());
5962 }
5963
5964 case Intrinsic::eh_recoverfp: {
5965 // FIXME: This needs to be implemented to correctly handle highly aligned
5966 // stack objects. For now we simply return the incoming FP. Refer to D53541
5967 // for more details.
5968 SDValue FnOp = Op.getOperand(1);
5969 SDValue IncomingFPOp = Op.getOperand(2);
5970 GlobalAddressSDNode *GSD = dyn_cast<GlobalAddressSDNode>(FnOp);
5971 auto *Fn = dyn_cast_or_null<Function>(GSD ? GSD->getGlobal() : nullptr);
5972 if (!Fn)
5973 report_fatal_error(
5974 "llvm.eh.recoverfp must take a function as the first argument");
5975 return IncomingFPOp;
5976 }
5977
5978 case Intrinsic::aarch64_neon_vsri:
5979 case Intrinsic::aarch64_neon_vsli:
5980 case Intrinsic::aarch64_sve_sri:
5981 case Intrinsic::aarch64_sve_sli: {
5982 EVT Ty = Op.getValueType();
5983
5984 if (!Ty.isVector())
5985 report_fatal_error("Unexpected type for aarch64_neon_vsli");
5986
5987 assert(Op.getConstantOperandVal(3) <= Ty.getScalarSizeInBits());
5988
5989 bool IsShiftRight = IntNo == Intrinsic::aarch64_neon_vsri ||
5990 IntNo == Intrinsic::aarch64_sve_sri;
5991 unsigned Opcode = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
5992 return DAG.getNode(Opcode, dl, Ty, Op.getOperand(1), Op.getOperand(2),
5993 Op.getOperand(3));
5994 }
5995
5996 case Intrinsic::aarch64_neon_srhadd:
5997 case Intrinsic::aarch64_neon_urhadd:
5998 case Intrinsic::aarch64_neon_shadd:
5999 case Intrinsic::aarch64_neon_uhadd: {
6000 bool IsSignedAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
6001 IntNo == Intrinsic::aarch64_neon_shadd);
6002 bool IsRoundingAdd = (IntNo == Intrinsic::aarch64_neon_srhadd ||
6003 IntNo == Intrinsic::aarch64_neon_urhadd);
6004 unsigned Opcode = IsSignedAdd
6005 ? (IsRoundingAdd ? ISD::AVGCEILS : ISD::AVGFLOORS)
6006 : (IsRoundingAdd ? ISD::AVGCEILU : ISD::AVGFLOORU);
6007 return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
6008 Op.getOperand(2));
6009 }
6010 case Intrinsic::aarch64_neon_saddlp:
6011 case Intrinsic::aarch64_neon_uaddlp: {
6012 unsigned Opcode = IntNo == Intrinsic::aarch64_neon_uaddlp
6013 ? AArch64ISD::UADDLP
6014 : AArch64ISD::SADDLP;
6015 return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1));
6016 }
6017 case Intrinsic::aarch64_neon_sdot:
6018 case Intrinsic::aarch64_neon_udot:
6019 case Intrinsic::aarch64_sve_sdot:
6020 case Intrinsic::aarch64_sve_udot: {
6021 unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
6022 IntNo == Intrinsic::aarch64_sve_udot)
6023 ? AArch64ISD::UDOT
6024 : AArch64ISD::SDOT;
6025 return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
6026 Op.getOperand(2), Op.getOperand(3));
6027 }
6028 case Intrinsic::get_active_lane_mask: {
6029 SDValue ID =
6030 DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, dl, MVT::i64);
6031
6032 EVT VT = Op.getValueType();
6033 if (VT.isScalableVector())
6034 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, VT, ID, Op.getOperand(1),
6035 Op.getOperand(2));
6036
6037 // We can use the SVE whilelo instruction to lower this intrinsic by
6038 // creating the appropriate sequence of scalable vector operations and
6039 // then extracting a fixed-width subvector from the scalable vector.
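// That is, the whilelo result (a scalable predicate) is sign-extended into a
// scalable integer container vector, from which the leading fixed-width part
// is extracted as the result.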
6040
6041 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
6042 EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
6043
6044 SDValue Mask = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, WhileVT, ID,
6045 Op.getOperand(1), Op.getOperand(2));
6046 SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, dl, ContainerVT, Mask);
6047 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, MaskAsInt,
6048 DAG.getVectorIdxConstant(0, dl));
6049 }
6050 case Intrinsic::aarch64_neon_uaddlv: {
6051 EVT OpVT = Op.getOperand(1).getValueType();
6052 EVT ResVT = Op.getValueType();
6053 if (ResVT == MVT::i32 && (OpVT == MVT::v8i8 || OpVT == MVT::v16i8 ||
6054 OpVT == MVT::v8i16 || OpVT == MVT::v4i16)) {
6055 // In order to avoid an insert_subvector, use v4i32 rather than v2i32.
6056 SDValue UADDLV =
6057 DAG.getNode(AArch64ISD::UADDLV, dl, MVT::v4i32, Op.getOperand(1));
6058 SDValue EXTRACT_VEC_ELT =
6059 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i32, UADDLV,
6060 DAG.getConstant(0, dl, MVT::i64));
6061 return EXTRACT_VEC_ELT;
6062 }
6063 return SDValue();
6064 }
6065 case Intrinsic::experimental_cttz_elts: {
6066 SDValue CttzOp = Op.getOperand(1);
6067 EVT VT = CttzOp.getValueType();
6068 assert(VT.getVectorElementType() == MVT::i1 && "Expected MVT::i1");
6069
6070 if (VT.isFixedLengthVector()) {
6071 // We can use SVE instructions to lower this intrinsic by first creating
6072 // an SVE predicate register mask from the fixed-width vector.
6073 EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
6074 SDValue Mask = DAG.getNode(ISD::SIGN_EXTEND, dl, NewVT, CttzOp);
6075 CttzOp = convertFixedMaskToScalableVector(Mask, DAG);
6076 }
6077
6078 SDValue NewCttzElts =
6079 DAG.getNode(AArch64ISD::CTTZ_ELTS, dl, MVT::i64, CttzOp);
6080 return DAG.getZExtOrTrunc(NewCttzElts, dl, Op.getValueType());
6081 }
6082 }
6083 }
6084
6085 bool AArch64TargetLowering::shouldExtendGSIndex(EVT VT, EVT &EltTy) const {
6086 if (VT.getVectorElementType() == MVT::i8 ||
6087 VT.getVectorElementType() == MVT::i16) {
6088 EltTy = MVT::i32;
6089 return true;
6090 }
6091 return false;
6092 }
6093
6094 bool AArch64TargetLowering::shouldRemoveExtendFromGSIndex(SDValue Extend,
6095 EVT DataVT) const {
6096 const EVT IndexVT = Extend.getOperand(0).getValueType();
6097 // SVE only supports implicit extension of 32-bit indices.
6098 if (!Subtarget->hasSVE() || IndexVT.getVectorElementType() != MVT::i32)
6099 return false;
6100
6101 // Indices cannot be smaller than the main data type.
6102 if (IndexVT.getScalarSizeInBits() < DataVT.getScalarSizeInBits())
6103 return false;
6104
6105 // Scalable vectors with "vscale * 2" or fewer elements sit within a 64-bit
6106 // element container, so the 32-bit indices would violate the previous clause.
6107 return DataVT.isFixedLengthVector() || DataVT.getVectorMinNumElements() > 2;
6108 }
6109
6110 bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
6111 EVT ExtVT = ExtVal.getValueType();
6112 if (!ExtVT.isScalableVector() && !Subtarget->useSVEForFixedLengthVectors())
6113 return false;
6114
6115 // It may be worth creating extending masked loads if there are multiple
6116 // masked loads using the same predicate. That way we'll end up creating
6117 // extending masked loads that may then get split by the legaliser. This
6118 // results in just one set of predicate unpacks at the start, instead of
6119 // multiple sets of vector unpacks after each load.
6120 if (auto *Ld = dyn_cast<MaskedLoadSDNode>(ExtVal->getOperand(0))) {
6121 if (!isLoadExtLegalOrCustom(ISD::ZEXTLOAD, ExtVT, Ld->getValueType(0))) {
6122 // Disable extending masked loads for fixed-width for now, since the code
6123 // quality doesn't look great.
6124 if (!ExtVT.isScalableVector())
6125 return false;
6126
6127 unsigned NumExtMaskedLoads = 0;
6128 for (auto *U : Ld->getMask()->uses())
6129 if (isa<MaskedLoadSDNode>(U))
6130 NumExtMaskedLoads++;
6131
6132 if (NumExtMaskedLoads <= 1)
6133 return false;
6134 }
6135 }
6136
6137 return true;
6138 }
6139
6140 unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
6141 std::map<std::tuple<bool, bool, bool>, unsigned> AddrModes = {
6142 {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ false),
6143 AArch64ISD::GLD1_MERGE_ZERO},
6144 {std::make_tuple(/*Scaled*/ false, /*Signed*/ false, /*Extend*/ true),
6145 AArch64ISD::GLD1_UXTW_MERGE_ZERO},
6146 {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ false),
6147 AArch64ISD::GLD1_MERGE_ZERO},
6148 {std::make_tuple(/*Scaled*/ false, /*Signed*/ true, /*Extend*/ true),
6149 AArch64ISD::GLD1_SXTW_MERGE_ZERO},
6150 {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ false),
6151 AArch64ISD::GLD1_SCALED_MERGE_ZERO},
6152 {std::make_tuple(/*Scaled*/ true, /*Signed*/ false, /*Extend*/ true),
6153 AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO},
6154 {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ false),
6155 AArch64ISD::GLD1_SCALED_MERGE_ZERO},
6156 {std::make_tuple(/*Scaled*/ true, /*Signed*/ true, /*Extend*/ true),
6157 AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO},
6158 };
6159 auto Key = std::make_tuple(IsScaled, IsSigned, NeedsExtend);
6160 return AddrModes.find(Key)->second;
6161 }
6162
6163 unsigned getSignExtendedGatherOpcode(unsigned Opcode) {
6164 switch (Opcode) {
6165 default:
6166 llvm_unreachable("unimplemented opcode");
6167 return Opcode;
6168 case AArch64ISD::GLD1_MERGE_ZERO:
6169 return AArch64ISD::GLD1S_MERGE_ZERO;
6170 case AArch64ISD::GLD1_IMM_MERGE_ZERO:
6171 return AArch64ISD::GLD1S_IMM_MERGE_ZERO;
6172 case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
6173 return AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
6174 case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
6175 return AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
6176 case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
6177 return AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
6178 case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
6179 return AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
6180 case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
6181 return AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
6182 }
6183 }
6184
6185 SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
6186 SelectionDAG &DAG) const {
6187 MaskedGatherSDNode *MGT = cast<MaskedGatherSDNode>(Op);
6188
6189 SDLoc DL(Op);
6190 SDValue Chain = MGT->getChain();
6191 SDValue PassThru = MGT->getPassThru();
6192 SDValue Mask = MGT->getMask();
6193 SDValue BasePtr = MGT->getBasePtr();
6194 SDValue Index = MGT->getIndex();
6195 SDValue Scale = MGT->getScale();
6196 EVT VT = Op.getValueType();
6197 EVT MemVT = MGT->getMemoryVT();
6198 ISD::LoadExtType ExtType = MGT->getExtensionType();
6199 ISD::MemIndexType IndexType = MGT->getIndexType();
6200
6201 // SVE supports only zero (and hence undef) passthrough values; everything else
6202 // must be handled manually by an explicit select on the load's output.
6203 if (!PassThru->isUndef() && !isZerosVector(PassThru.getNode())) {
6204 SDValue Ops[] = {Chain, DAG.getUNDEF(VT), Mask, BasePtr, Index, Scale};
6205 SDValue Load =
6206 DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
6207 MGT->getMemOperand(), IndexType, ExtType);
6208 SDValue Select = DAG.getSelect(DL, VT, Mask, Load, PassThru);
6209 return DAG.getMergeValues({Select, Load.getValue(1)}, DL);
6210 }
6211
6212 bool IsScaled = MGT->isIndexScaled();
6213 bool IsSigned = MGT->isIndexSigned();
6214
6215 // SVE supports only an index scaled by sizeof(MemVT.elt); any other scaling
6216 // must be applied to the indices beforehand.
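// For example, a gather of 32-bit elements whose index is scaled by 8 is
// rewritten as an unscaled gather whose indices are shifted left by 3.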
6217 uint64_t ScaleVal = Scale->getAsZExtVal();
6218 if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
6219 assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
6220 EVT IndexVT = Index.getValueType();
6221 Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
6222 DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
6223 Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
6224
6225 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
6226 return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
6227 MGT->getMemOperand(), IndexType, ExtType);
6228 }
6229
6230 // Lower fixed length gather to a scalable equivalent.
6231 if (VT.isFixedLengthVector()) {
6232 assert(Subtarget->useSVEForFixedLengthVectors() &&
6233 "Cannot lower when not using SVE for fixed vectors!");
6234
6235 // NOTE: Handle floating-point as if integer then bitcast the result.
6236 EVT DataVT = VT.changeVectorElementTypeToInteger();
6237 MemVT = MemVT.changeVectorElementTypeToInteger();
6238
6239 // Find the smallest integer fixed length vector we can use for the gather.
6240 EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
6241 if (DataVT.getVectorElementType() == MVT::i64 ||
6242 Index.getValueType().getVectorElementType() == MVT::i64 ||
6243 Mask.getValueType().getVectorElementType() == MVT::i64)
6244 PromotedVT = VT.changeVectorElementType(MVT::i64);
6245
6246 // Promote vector operands except for passthrough, which we know is either
6247 // undef or zero, and thus best constructed directly.
6248 unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
6249 Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
6250 Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
6251
6252 // A promoted result type forces the need for an extending load.
6253 if (PromotedVT != DataVT && ExtType == ISD::NON_EXTLOAD)
6254 ExtType = ISD::EXTLOAD;
6255
6256 EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
6257
6258 // Convert fixed length vector operands to scalable.
6259 MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
6260 Index = convertToScalableVector(DAG, ContainerVT, Index);
6261 Mask = convertFixedMaskToScalableVector(Mask, DAG);
6262 PassThru = PassThru->isUndef() ? DAG.getUNDEF(ContainerVT)
6263 : DAG.getConstant(0, DL, ContainerVT);
6264
6265 // Emit equivalent scalable vector gather.
6266 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
6267 SDValue Load =
6268 DAG.getMaskedGather(DAG.getVTList(ContainerVT, MVT::Other), MemVT, DL,
6269 Ops, MGT->getMemOperand(), IndexType, ExtType);
6270
6271 // Extract fixed length data then convert to the required result type.
6272 SDValue Result = convertFromScalableVector(DAG, PromotedVT, Load);
6273 Result = DAG.getNode(ISD::TRUNCATE, DL, DataVT, Result);
6274 if (VT.isFloatingPoint())
6275 Result = DAG.getNode(ISD::BITCAST, DL, VT, Result);
6276
6277 return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
6278 }
6279
6280 // Everything else is legal.
6281 return Op;
6282 }
6283
6284 SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
6285 SelectionDAG &DAG) const {
6286 MaskedScatterSDNode *MSC = cast<MaskedScatterSDNode>(Op);
6287
6288 SDLoc DL(Op);
6289 SDValue Chain = MSC->getChain();
6290 SDValue StoreVal = MSC->getValue();
6291 SDValue Mask = MSC->getMask();
6292 SDValue BasePtr = MSC->getBasePtr();
6293 SDValue Index = MSC->getIndex();
6294 SDValue Scale = MSC->getScale();
6295 EVT VT = StoreVal.getValueType();
6296 EVT MemVT = MSC->getMemoryVT();
6297 ISD::MemIndexType IndexType = MSC->getIndexType();
6298 bool Truncating = MSC->isTruncatingStore();
6299
6300 bool IsScaled = MSC->isIndexScaled();
6301 bool IsSigned = MSC->isIndexSigned();
6302
6303 // SVE supports only an index scaled by sizeof(MemVT.elt); any other scaling
6304 // must be applied to the indices beforehand.
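// For example, a scatter of 16-bit elements whose index is scaled by 4 is
// rewritten as an unscaled scatter whose indices are shifted left by 2.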
6305 uint64_t ScaleVal = Scale->getAsZExtVal();
6306 if (IsScaled && ScaleVal != MemVT.getScalarStoreSize()) {
6307 assert(isPowerOf2_64(ScaleVal) && "Expecting power-of-two types");
6308 EVT IndexVT = Index.getValueType();
6309 Index = DAG.getNode(ISD::SHL, DL, IndexVT, Index,
6310 DAG.getConstant(Log2_32(ScaleVal), DL, IndexVT));
6311 Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
6312
6313 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
6314 return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
6315 MSC->getMemOperand(), IndexType, Truncating);
6316 }
6317
6318 // Lower fixed length scatter to a scalable equivalent.
6319 if (VT.isFixedLengthVector()) {
6320 assert(Subtarget->useSVEForFixedLengthVectors() &&
6321 "Cannot lower when not using SVE for fixed vectors!");
6322
6323 // Once bitcast we treat floating-point scatters as if integer.
6324 if (VT.isFloatingPoint()) {
6325 VT = VT.changeVectorElementTypeToInteger();
6326 MemVT = MemVT.changeVectorElementTypeToInteger();
6327 StoreVal = DAG.getNode(ISD::BITCAST, DL, VT, StoreVal);
6328 }
6329
6330 // Find the smallest integer fixed length vector we can use for the scatter.
6331 EVT PromotedVT = VT.changeVectorElementType(MVT::i32);
6332 if (VT.getVectorElementType() == MVT::i64 ||
6333 Index.getValueType().getVectorElementType() == MVT::i64 ||
6334 Mask.getValueType().getVectorElementType() == MVT::i64)
6335 PromotedVT = VT.changeVectorElementType(MVT::i64);
6336
6337 // Promote vector operands.
6338 unsigned ExtOpcode = IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
6339 Index = DAG.getNode(ExtOpcode, DL, PromotedVT, Index);
6340 Mask = DAG.getNode(ISD::SIGN_EXTEND, DL, PromotedVT, Mask);
6341 StoreVal = DAG.getNode(ISD::ANY_EXTEND, DL, PromotedVT, StoreVal);
6342
6343 // A promoted value type forces the need for a truncating store.
6344 if (PromotedVT != VT)
6345 Truncating = true;
6346
6347 EVT ContainerVT = getContainerForFixedLengthVector(DAG, PromotedVT);
6348
6349 // Convert fixed length vector operands to scalable.
6350 MemVT = ContainerVT.changeVectorElementType(MemVT.getVectorElementType());
6351 Index = convertToScalableVector(DAG, ContainerVT, Index);
6352 Mask = convertFixedMaskToScalableVector(Mask, DAG);
6353 StoreVal = convertToScalableVector(DAG, ContainerVT, StoreVal);
6354
6355 // Emit equivalent scalable vector scatter.
6356 SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
6357 return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
6358 MSC->getMemOperand(), IndexType, Truncating);
6359 }
6360
6361 // Everything else is legal.
6362 return Op;
6363 }
6364
6365 SDValue AArch64TargetLowering::LowerMLOAD(SDValue Op, SelectionDAG &DAG) const {
6366 SDLoc DL(Op);
6367 MaskedLoadSDNode *LoadNode = cast<MaskedLoadSDNode>(Op);
6368 assert(LoadNode && "Expected custom lowering of a masked load node");
6369 EVT VT = Op->getValueType(0);
6370
6371 if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
6372 return LowerFixedLengthVectorMLoadToSVE(Op, DAG);
6373
6374 SDValue PassThru = LoadNode->getPassThru();
6375 SDValue Mask = LoadNode->getMask();
6376
6377 if (PassThru->isUndef() || isZerosVector(PassThru.getNode()))
6378 return Op;
6379
6380 SDValue Load = DAG.getMaskedLoad(
6381 VT, DL, LoadNode->getChain(), LoadNode->getBasePtr(),
6382 LoadNode->getOffset(), Mask, DAG.getUNDEF(VT), LoadNode->getMemoryVT(),
6383 LoadNode->getMemOperand(), LoadNode->getAddressingMode(),
6384 LoadNode->getExtensionType());
6385
6386 SDValue Result = DAG.getSelect(DL, VT, Mask, Load, PassThru);
6387
6388 return DAG.getMergeValues({Result, Load.getValue(1)}, DL);
6389 }
6390
6391 // Custom lower trunc store for v4i8 vectors, since it is promoted to v4i16.
6392 static SDValue LowerTruncateVectorStore(SDLoc DL, StoreSDNode *ST,
6393 EVT VT, EVT MemVT,
6394 SelectionDAG &DAG) {
6395 assert(VT.isVector() && "VT should be a vector type");
6396 assert(MemVT == MVT::v4i8 && VT == MVT::v4i16);
6397
6398 SDValue Value = ST->getValue();
6399
6400 // First extend the promoted v4i16 to v8i16, truncate it to v8i8, and extract
6401 // the word lane that represents the v4i8 subvector. This optimizes the store
6402 // to:
6403 //
6404 // xtn v0.8b, v0.8h
6405 // str s0, [x0]
6406
6407 SDValue Undef = DAG.getUNDEF(MVT::i16);
6408 SDValue UndefVec = DAG.getBuildVector(MVT::v4i16, DL,
6409 {Undef, Undef, Undef, Undef});
6410
6411 SDValue TruncExt = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16,
6412 Value, UndefVec);
6413 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, TruncExt);
6414
6415 Trunc = DAG.getNode(ISD::BITCAST, DL, MVT::v2i32, Trunc);
6416 SDValue ExtractTrunc = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32,
6417 Trunc, DAG.getConstant(0, DL, MVT::i64));
6418
6419 return DAG.getStore(ST->getChain(), DL, ExtractTrunc,
6420 ST->getBasePtr(), ST->getMemOperand());
6421 }
6422
6423 // Custom lowering for any store, vector or scalar, default or truncating.
6424 // Currently we custom lower fixed-length SVE vector stores, truncating stores
6425 // from v4i16 to v4i8, 256-bit non-temporal stores, volatile i128 stores, and i64x8 stores.
6426 SDValue AArch64TargetLowering::LowerSTORE(SDValue Op,
6427 SelectionDAG &DAG) const {
6428 SDLoc Dl(Op);
6429 StoreSDNode *StoreNode = cast<StoreSDNode>(Op);
6430 assert(StoreNode && "Can only custom lower store nodes");
6431
6432 SDValue Value = StoreNode->getValue();
6433
6434 EVT VT = Value.getValueType();
6435 EVT MemVT = StoreNode->getMemoryVT();
6436
6437 if (VT.isVector()) {
6438 if (useSVEForFixedLengthVectorVT(
6439 VT,
6440 /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()))
6441 return LowerFixedLengthVectorStoreToSVE(Op, DAG);
6442
6443 unsigned AS = StoreNode->getAddressSpace();
6444 Align Alignment = StoreNode->getAlign();
6445 if (Alignment < MemVT.getStoreSize() &&
6446 !allowsMisalignedMemoryAccesses(MemVT, AS, Alignment,
6447 StoreNode->getMemOperand()->getFlags(),
6448 nullptr)) {
6449 return scalarizeVectorStore(StoreNode, DAG);
6450 }
6451
6452 if (StoreNode->isTruncatingStore() && VT == MVT::v4i16 &&
6453 MemVT == MVT::v4i8) {
6454 return LowerTruncateVectorStore(Dl, StoreNode, VT, MemVT, DAG);
6455 }
6456 // 256 bit non-temporal stores can be lowered to STNP. Do this as part of
6457 // the custom lowering, as there are no un-paired non-temporal stores and
6458 // legalization will break up 256 bit inputs.
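// For example, a 256-bit non-temporal store of v8i32 is split into two v4i32
// halves that are written with a single STNP.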
6459 ElementCount EC = MemVT.getVectorElementCount();
6460 if (StoreNode->isNonTemporal() && MemVT.getSizeInBits() == 256u &&
6461 EC.isKnownEven() && DAG.getDataLayout().isLittleEndian() &&
6462 (MemVT.getScalarSizeInBits() == 8u ||
6463 MemVT.getScalarSizeInBits() == 16u ||
6464 MemVT.getScalarSizeInBits() == 32u ||
6465 MemVT.getScalarSizeInBits() == 64u)) {
6466 SDValue Lo =
6467 DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
6468 MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
6469 StoreNode->getValue(), DAG.getConstant(0, Dl, MVT::i64));
6470 SDValue Hi =
6471 DAG.getNode(ISD::EXTRACT_SUBVECTOR, Dl,
6472 MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
6473 StoreNode->getValue(),
6474 DAG.getConstant(EC.getKnownMinValue() / 2, Dl, MVT::i64));
6475 SDValue Result = DAG.getMemIntrinsicNode(
6476 AArch64ISD::STNP, Dl, DAG.getVTList(MVT::Other),
6477 {StoreNode->getChain(), Lo, Hi, StoreNode->getBasePtr()},
6478 StoreNode->getMemoryVT(), StoreNode->getMemOperand());
6479 return Result;
6480 }
6481 } else if (MemVT == MVT::i128 && StoreNode->isVolatile()) {
6482 return LowerStore128(Op, DAG);
6483 } else if (MemVT == MVT::i64x8) {
6484 SDValue Value = StoreNode->getValue();
6485 assert(Value->getValueType(0) == MVT::i64x8);
6486 SDValue Chain = StoreNode->getChain();
6487 SDValue Base = StoreNode->getBasePtr();
6488 EVT PtrVT = Base.getValueType();
6489 for (unsigned i = 0; i < 8; i++) {
6490 SDValue Part = DAG.getNode(AArch64ISD::LS64_EXTRACT, Dl, MVT::i64,
6491 Value, DAG.getConstant(i, Dl, MVT::i32));
6492 SDValue Ptr = DAG.getNode(ISD::ADD, Dl, PtrVT, Base,
6493 DAG.getConstant(i * 8, Dl, PtrVT));
6494 Chain = DAG.getStore(Chain, Dl, Part, Ptr, StoreNode->getPointerInfo(),
6495 StoreNode->getOriginalAlign());
6496 }
6497 return Chain;
6498 }
6499
6500 return SDValue();
6501 }
6502
6503 /// Lower atomic or volatile 128-bit stores to a single STP or STILP instruction.
6504 SDValue AArch64TargetLowering::LowerStore128(SDValue Op,
6505 SelectionDAG &DAG) const {
6506 MemSDNode *StoreNode = cast<MemSDNode>(Op);
6507 assert(StoreNode->getMemoryVT() == MVT::i128);
6508 assert(StoreNode->isVolatile() || StoreNode->isAtomic());
6509
6510 bool IsStoreRelease =
6511 StoreNode->getMergedOrdering() == AtomicOrdering::Release;
6512 if (StoreNode->isAtomic())
6513 assert((Subtarget->hasFeature(AArch64::FeatureLSE2) &&
6514 Subtarget->hasFeature(AArch64::FeatureRCPC3) && IsStoreRelease) ||
6515 StoreNode->getMergedOrdering() == AtomicOrdering::Unordered ||
6516 StoreNode->getMergedOrdering() == AtomicOrdering::Monotonic);
6517
6518 SDValue Value = (StoreNode->getOpcode() == ISD::STORE ||
6519 StoreNode->getOpcode() == ISD::ATOMIC_STORE)
6520 ? StoreNode->getOperand(1)
6521 : StoreNode->getOperand(2);
6522 SDLoc DL(Op);
6523 auto StoreValue = DAG.SplitScalar(Value, DL, MVT::i64, MVT::i64);
6524 unsigned Opcode = IsStoreRelease ? AArch64ISD::STILP : AArch64ISD::STP;
6525 if (DAG.getDataLayout().isBigEndian())
6526 std::swap(StoreValue.first, StoreValue.second);
6527 SDValue Result = DAG.getMemIntrinsicNode(
6528 Opcode, DL, DAG.getVTList(MVT::Other),
6529 {StoreNode->getChain(), StoreValue.first, StoreValue.second,
6530 StoreNode->getBasePtr()},
6531 StoreNode->getMemoryVT(), StoreNode->getMemOperand());
6532 return Result;
6533 }
6534
6535 SDValue AArch64TargetLowering::LowerLOAD(SDValue Op,
6536 SelectionDAG &DAG) const {
6537 SDLoc DL(Op);
6538 LoadSDNode *LoadNode = cast<LoadSDNode>(Op);
6539 assert(LoadNode && "Expected custom lowering of a load node");
6540
6541 if (LoadNode->getMemoryVT() == MVT::i64x8) {
6542 SmallVector<SDValue, 8> Ops;
6543 SDValue Base = LoadNode->getBasePtr();
6544 SDValue Chain = LoadNode->getChain();
6545 EVT PtrVT = Base.getValueType();
6546 for (unsigned i = 0; i < 8; i++) {
6547 SDValue Ptr = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
6548 DAG.getConstant(i * 8, DL, PtrVT));
6549 SDValue Part = DAG.getLoad(MVT::i64, DL, Chain, Ptr,
6550 LoadNode->getPointerInfo(),
6551 LoadNode->getOriginalAlign());
6552 Ops.push_back(Part);
6553 Chain = SDValue(Part.getNode(), 1);
6554 }
6555 SDValue Loaded = DAG.getNode(AArch64ISD::LS64_BUILD, DL, MVT::i64x8, Ops);
6556 return DAG.getMergeValues({Loaded, Chain}, DL);
6557 }
6558
6559 // Custom lowering for extending v4i8 vector loads.
6560 EVT VT = Op->getValueType(0);
6561 assert((VT == MVT::v4i16 || VT == MVT::v4i32) && "Expected v4i16 or v4i32");
6562
6563 if (LoadNode->getMemoryVT() != MVT::v4i8)
6564 return SDValue();
6565
6566 // Avoid generating unaligned loads.
6567 if (Subtarget->requiresStrictAlign() && LoadNode->getAlign() < Align(4))
6568 return SDValue();
6569
6570 unsigned ExtType;
6571 if (LoadNode->getExtensionType() == ISD::SEXTLOAD)
6572 ExtType = ISD::SIGN_EXTEND;
6573 else if (LoadNode->getExtensionType() == ISD::ZEXTLOAD ||
6574 LoadNode->getExtensionType() == ISD::EXTLOAD)
6575 ExtType = ISD::ZERO_EXTEND;
6576 else
6577 return SDValue();
6578
6579 SDValue Load = DAG.getLoad(MVT::f32, DL, LoadNode->getChain(),
6580 LoadNode->getBasePtr(), MachinePointerInfo());
6581 SDValue Chain = Load.getValue(1);
6582 SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v2f32, Load);
6583 SDValue BC = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Vec);
6584 SDValue Ext = DAG.getNode(ExtType, DL, MVT::v8i16, BC);
6585 Ext = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v4i16, Ext,
6586 DAG.getConstant(0, DL, MVT::i64));
6587 if (VT == MVT::v4i32)
6588 Ext = DAG.getNode(ExtType, DL, MVT::v4i32, Ext);
6589 return DAG.getMergeValues({Ext, Chain}, DL);
6590 }
6591
6592 // Generate SUBS and CSEL for integer abs.
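// For scalars this emits a SUBS of the operand against zero followed by a
// CSEL that picks the operand when the PL (non-negative) condition holds and
// its negation otherwise.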
6593 SDValue AArch64TargetLowering::LowerABS(SDValue Op, SelectionDAG &DAG) const {
6594 MVT VT = Op.getSimpleValueType();
6595
6596 if (VT.isVector())
6597 return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABS_MERGE_PASSTHRU);
6598
6599 SDLoc DL(Op);
6600 SDValue Neg = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
6601 Op.getOperand(0));
6602 // Generate SUBS & CSEL.
6603 SDValue Cmp =
6604 DAG.getNode(AArch64ISD::SUBS, DL, DAG.getVTList(VT, MVT::i32),
6605 Op.getOperand(0), DAG.getConstant(0, DL, VT));
6606 return DAG.getNode(AArch64ISD::CSEL, DL, VT, Op.getOperand(0), Neg,
6607 DAG.getConstant(AArch64CC::PL, DL, MVT::i32),
6608 Cmp.getValue(1));
6609 }
6610
6611 static SDValue LowerBRCOND(SDValue Op, SelectionDAG &DAG) {
6612 SDValue Chain = Op.getOperand(0);
6613 SDValue Cond = Op.getOperand(1);
6614 SDValue Dest = Op.getOperand(2);
6615
6616 AArch64CC::CondCode CC;
6617 if (SDValue Cmp = emitConjunction(DAG, Cond, CC)) {
6618 SDLoc dl(Op);
6619 SDValue CCVal = DAG.getConstant(CC, dl, MVT::i32);
6620 return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
6621 Cmp);
6622 }
6623
6624 return SDValue();
6625 }
6626
6627 // Treat FSHR with constant shifts as a legal operation; otherwise it is expanded.
6628 // FSHL is converted to FSHR before deciding what to do with it.
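// For example, fshl(x, y, 3) on i64 values is rewritten as fshr(x, y, 61).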
6629 static SDValue LowerFunnelShift(SDValue Op, SelectionDAG &DAG) {
6630 SDValue Shifts = Op.getOperand(2);
6631 // Check if the shift amount is a constant
6632 // If opcode is FSHL, convert it to FSHR
6633 if (auto *ShiftNo = dyn_cast<ConstantSDNode>(Shifts)) {
6634 SDLoc DL(Op);
6635 MVT VT = Op.getSimpleValueType();
6636
6637 if (Op.getOpcode() == ISD::FSHL) {
6638 unsigned int NewShiftNo =
6639 VT.getFixedSizeInBits() - ShiftNo->getZExtValue();
6640 return DAG.getNode(
6641 ISD::FSHR, DL, VT, Op.getOperand(0), Op.getOperand(1),
6642 DAG.getConstant(NewShiftNo, DL, Shifts.getValueType()));
6643 } else if (Op.getOpcode() == ISD::FSHR) {
6644 return Op;
6645 }
6646 }
6647
6648 return SDValue();
6649 }
6650
6651 static SDValue LowerFLDEXP(SDValue Op, SelectionDAG &DAG) {
6652 SDValue X = Op.getOperand(0);
6653 EVT XScalarTy = X.getValueType();
6654 SDValue Exp = Op.getOperand(1);
6655
6656 SDLoc DL(Op);
6657 EVT XVT, ExpVT;
6658 switch (Op.getSimpleValueType().SimpleTy) {
6659 default:
6660 return SDValue();
6661 case MVT::bf16:
6662 case MVT::f16:
6663 X = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, X);
6664 [[fallthrough]];
6665 case MVT::f32:
6666 XVT = MVT::nxv4f32;
6667 ExpVT = MVT::nxv4i32;
6668 break;
6669 case MVT::f64:
6670 XVT = MVT::nxv2f64;
6671 ExpVT = MVT::nxv2i64;
6672 Exp = DAG.getNode(ISD::SIGN_EXTEND, DL, MVT::i64, Exp);
6673 break;
6674 }
6675
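// Lower the scalar fldexp via SVE FSCALE: insert the operands into lane 0 of
// scalable vectors, apply the aarch64_sve_fscale intrinsic under an all-true
// predicate, and extract lane 0 of the result.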
6676 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
6677 SDValue VX =
6678 DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, XVT, DAG.getUNDEF(XVT), X, Zero);
6679 SDValue VExp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ExpVT,
6680 DAG.getUNDEF(ExpVT), Exp, Zero);
6681 SDValue VPg = getPTrue(DAG, DL, XVT.changeVectorElementType(MVT::i1),
6682 AArch64SVEPredPattern::all);
6683 SDValue FScale =
6684 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, XVT,
6685 DAG.getConstant(Intrinsic::aarch64_sve_fscale, DL, MVT::i64),
6686 VPg, VX, VExp);
6687 SDValue Final =
6688 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, X.getValueType(), FScale, Zero);
6689 if (X.getValueType() != XScalarTy)
6690 Final = DAG.getNode(ISD::FP_ROUND, DL, XScalarTy, Final,
6691 DAG.getIntPtrConstant(1, SDLoc(Op)));
6692 return Final;
6693 }
6694
6695 SDValue AArch64TargetLowering::LowerADJUST_TRAMPOLINE(SDValue Op,
6696 SelectionDAG &DAG) const {
6697 // Note: x18 cannot be used for the Nest parameter on Windows and macOS.
6698 if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
6699 report_fatal_error(
6700 "ADJUST_TRAMPOLINE operation is only supported on Linux.");
6701
6702 return Op.getOperand(0);
6703 }
6704
6705 SDValue AArch64TargetLowering::LowerINIT_TRAMPOLINE(SDValue Op,
6706 SelectionDAG &DAG) const {
6707
6708 // Note: x18 cannot be used for the Nest parameter on Windows and macOS.
6709 if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
6710 report_fatal_error("INIT_TRAMPOLINE operation is only supported on Linux.");
6711
6712 SDValue Chain = Op.getOperand(0);
6713 SDValue Trmp = Op.getOperand(1); // trampoline
6714 SDValue FPtr = Op.getOperand(2); // nested function
6715 SDValue Nest = Op.getOperand(3); // 'nest' parameter value
6716 SDLoc dl(Op);
6717
6718 EVT PtrVT = getPointerTy(DAG.getDataLayout());
6719 Type *IntPtrTy = DAG.getDataLayout().getIntPtrType(*DAG.getContext());
6720
6721 TargetLowering::ArgListTy Args;
6722 TargetLowering::ArgListEntry Entry;
6723
6724 Entry.Ty = IntPtrTy;
6725 Entry.Node = Trmp;
6726 Args.push_back(Entry);
6727 Entry.Node = DAG.getConstant(20, dl, MVT::i64);
6728 Args.push_back(Entry);
6729
6730 Entry.Node = FPtr;
6731 Args.push_back(Entry);
6732 Entry.Node = Nest;
6733 Args.push_back(Entry);
6734
6735 // Lower to a call to __trampoline_setup(Trmp, TrampSize, FPtr, ctx_reg)
6736 TargetLowering::CallLoweringInfo CLI(DAG);
6737 CLI.setDebugLoc(dl).setChain(Chain).setLibCallee(
6738 CallingConv::C, Type::getVoidTy(*DAG.getContext()),
6739 DAG.getExternalSymbol("__trampoline_setup", PtrVT), std::move(Args));
6740
6741 std::pair<SDValue, SDValue> CallResult = LowerCallTo(CLI);
6742 return CallResult.second;
6743 }
6744
6745 SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6746 SelectionDAG &DAG) const {
6747 LLVM_DEBUG(dbgs() << "Custom lowering: ");
6748 LLVM_DEBUG(Op.dump());
6749
6750 switch (Op.getOpcode()) {
6751 default:
6752 llvm_unreachable("unimplemented operand");
6753 return SDValue();
6754 case ISD::BITCAST:
6755 return LowerBITCAST(Op, DAG);
6756 case ISD::GlobalAddress:
6757 return LowerGlobalAddress(Op, DAG);
6758 case ISD::GlobalTLSAddress:
6759 return LowerGlobalTLSAddress(Op, DAG);
6760 case ISD::PtrAuthGlobalAddress:
6761 return LowerPtrAuthGlobalAddress(Op, DAG);
6762 case ISD::ADJUST_TRAMPOLINE:
6763 return LowerADJUST_TRAMPOLINE(Op, DAG);
6764 case ISD::INIT_TRAMPOLINE:
6765 return LowerINIT_TRAMPOLINE(Op, DAG);
6766 case ISD::SETCC:
6767 case ISD::STRICT_FSETCC:
6768 case ISD::STRICT_FSETCCS:
6769 return LowerSETCC(Op, DAG);
6770 case ISD::SETCCCARRY:
6771 return LowerSETCCCARRY(Op, DAG);
6772 case ISD::BRCOND:
6773 return LowerBRCOND(Op, DAG);
6774 case ISD::BR_CC:
6775 return LowerBR_CC(Op, DAG);
6776 case ISD::SELECT:
6777 return LowerSELECT(Op, DAG);
6778 case ISD::SELECT_CC:
6779 return LowerSELECT_CC(Op, DAG);
6780 case ISD::JumpTable:
6781 return LowerJumpTable(Op, DAG);
6782 case ISD::BR_JT:
6783 return LowerBR_JT(Op, DAG);
6784 case ISD::BRIND:
6785 return LowerBRIND(Op, DAG);
6786 case ISD::ConstantPool:
6787 return LowerConstantPool(Op, DAG);
6788 case ISD::BlockAddress:
6789 return LowerBlockAddress(Op, DAG);
6790 case ISD::VASTART:
6791 return LowerVASTART(Op, DAG);
6792 case ISD::VACOPY:
6793 return LowerVACOPY(Op, DAG);
6794 case ISD::VAARG:
6795 return LowerVAARG(Op, DAG);
6796 case ISD::UADDO_CARRY:
6797 return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::ADCS, false /*unsigned*/);
6798 case ISD::USUBO_CARRY:
6799 return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::SBCS, false /*unsigned*/);
6800 case ISD::SADDO_CARRY:
6801 return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::ADCS, true /*signed*/);
6802 case ISD::SSUBO_CARRY:
6803 return lowerADDSUBO_CARRY(Op, DAG, AArch64ISD::SBCS, true /*signed*/);
6804 case ISD::SADDO:
6805 case ISD::UADDO:
6806 case ISD::SSUBO:
6807 case ISD::USUBO:
6808 case ISD::SMULO:
6809 case ISD::UMULO:
6810 return LowerXALUO(Op, DAG);
6811 case ISD::FADD:
6812 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FADD_PRED);
6813 case ISD::FSUB:
6814 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSUB_PRED);
6815 case ISD::FMUL:
6816 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED);
6817 case ISD::FMA:
6818 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED);
6819 case ISD::FDIV:
6820 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED);
6821 case ISD::FNEG:
6822 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEG_MERGE_PASSTHRU);
6823 case ISD::FCEIL:
6824 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FCEIL_MERGE_PASSTHRU);
6825 case ISD::FFLOOR:
6826 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FFLOOR_MERGE_PASSTHRU);
6827 case ISD::FNEARBYINT:
6828 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FNEARBYINT_MERGE_PASSTHRU);
6829 case ISD::FRINT:
6830 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FRINT_MERGE_PASSTHRU);
6831 case ISD::FROUND:
6832 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUND_MERGE_PASSTHRU);
6833 case ISD::FROUNDEVEN:
6834 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FROUNDEVEN_MERGE_PASSTHRU);
6835 case ISD::FTRUNC:
6836 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FTRUNC_MERGE_PASSTHRU);
6837 case ISD::FSQRT:
6838 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FSQRT_MERGE_PASSTHRU);
6839 case ISD::FABS:
6840 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FABS_MERGE_PASSTHRU);
6841 case ISD::FP_ROUND:
6842 case ISD::STRICT_FP_ROUND:
6843 return LowerFP_ROUND(Op, DAG);
6844 case ISD::FP_EXTEND:
6845 return LowerFP_EXTEND(Op, DAG);
6846 case ISD::FRAMEADDR:
6847 return LowerFRAMEADDR(Op, DAG);
6848 case ISD::SPONENTRY:
6849 return LowerSPONENTRY(Op, DAG);
6850 case ISD::RETURNADDR:
6851 return LowerRETURNADDR(Op, DAG);
6852 case ISD::ADDROFRETURNADDR:
6853 return LowerADDROFRETURNADDR(Op, DAG);
6854 case ISD::CONCAT_VECTORS:
6855 return LowerCONCAT_VECTORS(Op, DAG);
6856 case ISD::INSERT_VECTOR_ELT:
6857 return LowerINSERT_VECTOR_ELT(Op, DAG);
6858 case ISD::EXTRACT_VECTOR_ELT:
6859 return LowerEXTRACT_VECTOR_ELT(Op, DAG);
6860 case ISD::BUILD_VECTOR:
6861 return LowerBUILD_VECTOR(Op, DAG);
6862 case ISD::ZERO_EXTEND_VECTOR_INREG:
6863 return LowerZERO_EXTEND_VECTOR_INREG(Op, DAG);
6864 case ISD::VECTOR_SHUFFLE:
6865 return LowerVECTOR_SHUFFLE(Op, DAG);
6866 case ISD::SPLAT_VECTOR:
6867 return LowerSPLAT_VECTOR(Op, DAG);
6868 case ISD::EXTRACT_SUBVECTOR:
6869 return LowerEXTRACT_SUBVECTOR(Op, DAG);
6870 case ISD::INSERT_SUBVECTOR:
6871 return LowerINSERT_SUBVECTOR(Op, DAG);
6872 case ISD::SDIV:
6873 case ISD::UDIV:
6874 return LowerDIV(Op, DAG);
6875 case ISD::SMIN:
6876 case ISD::UMIN:
6877 case ISD::SMAX:
6878 case ISD::UMAX:
6879 return LowerMinMax(Op, DAG);
6880 case ISD::SRA:
6881 case ISD::SRL:
6882 case ISD::SHL:
6883 return LowerVectorSRA_SRL_SHL(Op, DAG);
6884 case ISD::SHL_PARTS:
6885 case ISD::SRL_PARTS:
6886 case ISD::SRA_PARTS:
6887 return LowerShiftParts(Op, DAG);
6888 case ISD::CTPOP:
6889 case ISD::PARITY:
6890 return LowerCTPOP_PARITY(Op, DAG);
6891 case ISD::FCOPYSIGN:
6892 return LowerFCOPYSIGN(Op, DAG);
6893 case ISD::OR:
6894 return LowerVectorOR(Op, DAG);
6895 case ISD::XOR:
6896 return LowerXOR(Op, DAG);
6897 case ISD::PREFETCH:
6898 return LowerPREFETCH(Op, DAG);
6899 case ISD::SINT_TO_FP:
6900 case ISD::UINT_TO_FP:
6901 case ISD::STRICT_SINT_TO_FP:
6902 case ISD::STRICT_UINT_TO_FP:
6903 return LowerINT_TO_FP(Op, DAG);
6904 case ISD::FP_TO_SINT:
6905 case ISD::FP_TO_UINT:
6906 case ISD::STRICT_FP_TO_SINT:
6907 case ISD::STRICT_FP_TO_UINT:
6908 return LowerFP_TO_INT(Op, DAG);
6909 case ISD::FP_TO_SINT_SAT:
6910 case ISD::FP_TO_UINT_SAT:
6911 return LowerFP_TO_INT_SAT(Op, DAG);
6912 case ISD::FSINCOS:
6913 return LowerFSINCOS(Op, DAG);
6914 case ISD::GET_ROUNDING:
6915 return LowerGET_ROUNDING(Op, DAG);
6916 case ISD::SET_ROUNDING:
6917 return LowerSET_ROUNDING(Op, DAG);
6918 case ISD::GET_FPMODE:
6919 return LowerGET_FPMODE(Op, DAG);
6920 case ISD::SET_FPMODE:
6921 return LowerSET_FPMODE(Op, DAG);
6922 case ISD::RESET_FPMODE:
6923 return LowerRESET_FPMODE(Op, DAG);
6924 case ISD::MUL:
6925 return LowerMUL(Op, DAG);
6926 case ISD::MULHS:
6927 return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHS_PRED);
6928 case ISD::MULHU:
6929 return LowerToPredicatedOp(Op, DAG, AArch64ISD::MULHU_PRED);
6930 case ISD::INTRINSIC_W_CHAIN:
6931 return LowerINTRINSIC_W_CHAIN(Op, DAG);
6932 case ISD::INTRINSIC_WO_CHAIN:
6933 return LowerINTRINSIC_WO_CHAIN(Op, DAG);
6934 case ISD::INTRINSIC_VOID:
6935 return LowerINTRINSIC_VOID(Op, DAG);
6936 case ISD::ATOMIC_STORE:
6937 if (cast<MemSDNode>(Op)->getMemoryVT() == MVT::i128) {
6938 assert(Subtarget->hasLSE2() || Subtarget->hasRCPC3());
6939 return LowerStore128(Op, DAG);
6940 }
6941 return SDValue();
6942 case ISD::STORE:
6943 return LowerSTORE(Op, DAG);
6944 case ISD::MSTORE:
6945 return LowerFixedLengthVectorMStoreToSVE(Op, DAG);
6946 case ISD::MGATHER:
6947 return LowerMGATHER(Op, DAG);
6948 case ISD::MSCATTER:
6949 return LowerMSCATTER(Op, DAG);
6950 case ISD::VECREDUCE_SEQ_FADD:
6951 return LowerVECREDUCE_SEQ_FADD(Op, DAG);
6952 case ISD::VECREDUCE_ADD:
6953 case ISD::VECREDUCE_AND:
6954 case ISD::VECREDUCE_OR:
6955 case ISD::VECREDUCE_XOR:
6956 case ISD::VECREDUCE_SMAX:
6957 case ISD::VECREDUCE_SMIN:
6958 case ISD::VECREDUCE_UMAX:
6959 case ISD::VECREDUCE_UMIN:
6960 case ISD::VECREDUCE_FADD:
6961 case ISD::VECREDUCE_FMAX:
6962 case ISD::VECREDUCE_FMIN:
6963 case ISD::VECREDUCE_FMAXIMUM:
6964 case ISD::VECREDUCE_FMINIMUM:
6965 return LowerVECREDUCE(Op, DAG);
6966 case ISD::ATOMIC_LOAD_AND:
6967 return LowerATOMIC_LOAD_AND(Op, DAG);
6968 case ISD::DYNAMIC_STACKALLOC:
6969 return LowerDYNAMIC_STACKALLOC(Op, DAG);
6970 case ISD::VSCALE:
6971 return LowerVSCALE(Op, DAG);
6972 case ISD::ANY_EXTEND:
6973 case ISD::SIGN_EXTEND:
6974 case ISD::ZERO_EXTEND:
6975 return LowerFixedLengthVectorIntExtendToSVE(Op, DAG);
6976 case ISD::SIGN_EXTEND_INREG: {
6977 // Only custom lower when ExtraVT has a legal byte based element type.
6978 EVT ExtraVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
6979 EVT ExtraEltVT = ExtraVT.getVectorElementType();
6980 if ((ExtraEltVT != MVT::i8) && (ExtraEltVT != MVT::i16) &&
6981 (ExtraEltVT != MVT::i32) && (ExtraEltVT != MVT::i64))
6982 return SDValue();
6983
6984 return LowerToPredicatedOp(Op, DAG,
6985 AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU);
6986 }
6987 case ISD::TRUNCATE:
6988 return LowerTRUNCATE(Op, DAG);
6989 case ISD::MLOAD:
6990 return LowerMLOAD(Op, DAG);
6991 case ISD::LOAD:
6992 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
6993 !Subtarget->isNeonAvailable()))
6994 return LowerFixedLengthVectorLoadToSVE(Op, DAG);
6995 return LowerLOAD(Op, DAG);
6996 case ISD::ADD:
6997 case ISD::AND:
6998 case ISD::SUB:
6999 return LowerToScalableOp(Op, DAG);
7000 case ISD::FMAXIMUM:
7001 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAX_PRED);
7002 case ISD::FMAXNUM:
7003 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMAXNM_PRED);
7004 case ISD::FMINIMUM:
7005 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMIN_PRED);
7006 case ISD::FMINNUM:
7007 return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMINNM_PRED);
7008 case ISD::VSELECT:
7009 return LowerFixedLengthVectorSelectToSVE(Op, DAG);
7010 case ISD::ABS:
7011 return LowerABS(Op, DAG);
7012 case ISD::ABDS:
7013 return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDS_PRED);
7014 case ISD::ABDU:
7015 return LowerToPredicatedOp(Op, DAG, AArch64ISD::ABDU_PRED);
7016 case ISD::AVGFLOORS:
7017 return LowerAVG(Op, DAG, AArch64ISD::HADDS_PRED);
7018 case ISD::AVGFLOORU:
7019 return LowerAVG(Op, DAG, AArch64ISD::HADDU_PRED);
7020 case ISD::AVGCEILS:
7021 return LowerAVG(Op, DAG, AArch64ISD::RHADDS_PRED);
7022 case ISD::AVGCEILU:
7023 return LowerAVG(Op, DAG, AArch64ISD::RHADDU_PRED);
7024 case ISD::BITREVERSE:
7025 return LowerBitreverse(Op, DAG);
7026 case ISD::BSWAP:
7027 return LowerToPredicatedOp(Op, DAG, AArch64ISD::BSWAP_MERGE_PASSTHRU);
7028 case ISD::CTLZ:
7029 return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTLZ_MERGE_PASSTHRU);
7030 case ISD::CTTZ:
7031 return LowerCTTZ(Op, DAG);
7032 case ISD::VECTOR_SPLICE:
7033 return LowerVECTOR_SPLICE(Op, DAG);
7034 case ISD::VECTOR_DEINTERLEAVE:
7035 return LowerVECTOR_DEINTERLEAVE(Op, DAG);
7036 case ISD::VECTOR_INTERLEAVE:
7037 return LowerVECTOR_INTERLEAVE(Op, DAG);
7038 case ISD::LRINT:
7039 case ISD::LLRINT:
7040 if (Op.getValueType().isVector())
7041 return LowerVectorXRINT(Op, DAG);
7042 [[fallthrough]];
7043 case ISD::LROUND:
7044 case ISD::LLROUND: {
7045 assert((Op.getOperand(0).getValueType() == MVT::f16 ||
7046 Op.getOperand(0).getValueType() == MVT::bf16) &&
7047 "Expected custom lowering of rounding operations only for f16");
7048 SDLoc DL(Op);
7049 SDValue Ext = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op.getOperand(0));
7050 return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), Ext);
7051 }
7052 case ISD::STRICT_LROUND:
7053 case ISD::STRICT_LLROUND:
7054 case ISD::STRICT_LRINT:
7055 case ISD::STRICT_LLRINT: {
7056 assert((Op.getOperand(1).getValueType() == MVT::f16 ||
7057 Op.getOperand(1).getValueType() == MVT::bf16) &&
7058 "Expected custom lowering of rounding operations only for f16");
7059 SDLoc DL(Op);
7060 SDValue Ext = DAG.getNode(ISD::STRICT_FP_EXTEND, DL, {MVT::f32, MVT::Other},
7061 {Op.getOperand(0), Op.getOperand(1)});
7062 return DAG.getNode(Op.getOpcode(), DL, {Op.getValueType(), MVT::Other},
7063 {Ext.getValue(1), Ext.getValue(0)});
7064 }
7065 case ISD::WRITE_REGISTER: {
7066 assert(Op.getOperand(2).getValueType() == MVT::i128 &&
7067 "WRITE_REGISTER custom lowering is only for 128-bit sysregs");
7068 SDLoc DL(Op);
7069
7070 SDValue Chain = Op.getOperand(0);
7071 SDValue SysRegName = Op.getOperand(1);
7072 std::pair<SDValue, SDValue> Pair =
7073 DAG.SplitScalar(Op.getOperand(2), DL, MVT::i64, MVT::i64);
7074
7075 // chain = MSRR(chain, sysregname, lo, hi)
7076 SDValue Result = DAG.getNode(AArch64ISD::MSRR, DL, MVT::Other, Chain,
7077 SysRegName, Pair.first, Pair.second);
7078
7079 return Result;
7080 }
7081 case ISD::FSHL:
7082 case ISD::FSHR:
7083 return LowerFunnelShift(Op, DAG);
7084 case ISD::FLDEXP:
7085 return LowerFLDEXP(Op, DAG);
7086 case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
7087 return LowerVECTOR_HISTOGRAM(Op, DAG);
7088 }
7089 }
7090
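/// Merging stores after legalization is only allowed when fixed-length
/// vectors are not being lowered via SVE.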
7091 bool AArch64TargetLowering::mergeStoresAfterLegalization(EVT VT) const {
7092 return !Subtarget->useSVEForFixedLengthVectors();
7093 }
7094
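/// Returns true when the given fixed-length vector type should be lowered
/// using SVE: either it is wider than 128 bits and fits within the configured
/// minimum SVE vector length, or OverrideNEON requests that NEON-sized
/// vectors also use SVE/streaming-SVE instructions.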
7095 bool AArch64TargetLowering::useSVEForFixedLengthVectorVT(
7096 EVT VT, bool OverrideNEON) const {
7097 if (!VT.isFixedLengthVector() || !VT.isSimple())
7098 return false;
7099
7100 // Don't use SVE for vectors we cannot scalarize if required.
7101 switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
7102 // Fixed length predicates should be promoted to i8.
7103 // NOTE: This is consistent with how NEON (and thus 64/128bit vectors) work.
7104 case MVT::i1:
7105 default:
7106 return false;
7107 case MVT::i8:
7108 case MVT::i16:
7109 case MVT::i32:
7110 case MVT::i64:
7111 case MVT::f16:
7112 case MVT::f32:
7113 case MVT::f64:
7114 break;
7115 }
7116
7117 // NEON-sized vectors can be emulated using SVE instructions.
7118 if (OverrideNEON && (VT.is128BitVector() || VT.is64BitVector()))
7119 return Subtarget->isSVEorStreamingSVEAvailable();
7120
7121 // Ensure NEON MVTs only belong to a single register class.
7122 if (VT.getFixedSizeInBits() <= 128)
7123 return false;
7124
7125 // Ensure wider than NEON code generation is enabled.
7126 if (!Subtarget->useSVEForFixedLengthVectors())
7127 return false;
7128
7129 // Don't use SVE for types that don't fit.
7130 if (VT.getFixedSizeInBits() > Subtarget->getMinSVEVectorSizeInBits())
7131 return false;
7132
7133 // TODO: Perhaps an artificial restriction, but worth having whilst getting
7134 // the base fixed length SVE support in place.
7135 if (!VT.isPow2VectorType())
7136 return false;
7137
7138 return true;
7139 }
7140
7141 //===----------------------------------------------------------------------===//
7142 // Calling Convention Implementation
7143 //===----------------------------------------------------------------------===//
7144
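/// Returns the intrinsic ID of an ISD::INTRINSIC_WO_CHAIN node, or
/// Intrinsic::not_intrinsic for anything else.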
7145 static unsigned getIntrinsicID(const SDNode *N) {
7146 unsigned Opcode = N->getOpcode();
7147 switch (Opcode) {
7148 default:
7149 return Intrinsic::not_intrinsic;
7150 case ISD::INTRINSIC_WO_CHAIN: {
7151 unsigned IID = N->getConstantOperandVal(0);
7152 if (IID < Intrinsic::num_intrinsics)
7153 return IID;
7154 return Intrinsic::not_intrinsic;
7155 }
7156 }
7157 }
7158
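/// Reassociation is not profitable when N1 is a widening multiply and N0 is
/// an add, since that pattern can be lowered to smlal/umlal.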
7159 bool AArch64TargetLowering::isReassocProfitable(SelectionDAG &DAG, SDValue N0,
7160 SDValue N1) const {
7161 if (!N0.hasOneUse())
7162 return false;
7163
7164 unsigned IID = getIntrinsicID(N1.getNode());
7165 // Avoid reassociating expressions that can be lowered to smlal/umlal.
7166 if (IID == Intrinsic::aarch64_neon_umull ||
7167 N1.getOpcode() == AArch64ISD::UMULL ||
7168 IID == Intrinsic::aarch64_neon_smull ||
7169 N1.getOpcode() == AArch64ISD::SMULL)
7170 return N0.getOpcode() != ISD::ADD;
7171
7172 return true;
7173 }
7174
7175 /// Selects the correct CCAssignFn for a given CallingConvention value.
7176 CCAssignFn *AArch64TargetLowering::CCAssignFnForCall(CallingConv::ID CC,
7177 bool IsVarArg) const {
7178 switch (CC) {
7179 default:
7180 report_fatal_error("Unsupported calling convention.");
7181 case CallingConv::GHC:
7182 return CC_AArch64_GHC;
7183 case CallingConv::PreserveNone:
7184 // The VarArg implementation makes assumptions about register
7185 // argument passing that do not hold for preserve_none, so we
7186 // instead fall back to C argument passing.
7187 // The non-vararg case is handled in the CC function itself.
7188 if (!IsVarArg)
7189 return CC_AArch64_Preserve_None;
7190 [[fallthrough]];
7191 case CallingConv::C:
7192 case CallingConv::Fast:
7193 case CallingConv::PreserveMost:
7194 case CallingConv::PreserveAll:
7195 case CallingConv::CXX_FAST_TLS:
7196 case CallingConv::Swift:
7197 case CallingConv::SwiftTail:
7198 case CallingConv::Tail:
7199 case CallingConv::GRAAL:
7200 if (Subtarget->isTargetWindows()) {
7201 if (IsVarArg) {
7202 if (Subtarget->isWindowsArm64EC())
7203 return CC_AArch64_Arm64EC_VarArg;
7204 return CC_AArch64_Win64_VarArg;
7205 }
7206 return CC_AArch64_Win64PCS;
7207 }
7208 if (!Subtarget->isTargetDarwin())
7209 return CC_AArch64_AAPCS;
7210 if (!IsVarArg)
7211 return CC_AArch64_DarwinPCS;
7212 return Subtarget->isTargetILP32() ? CC_AArch64_DarwinPCS_ILP32_VarArg
7213 : CC_AArch64_DarwinPCS_VarArg;
7214 case CallingConv::Win64:
7215 if (IsVarArg) {
7216 if (Subtarget->isWindowsArm64EC())
7217 return CC_AArch64_Arm64EC_VarArg;
7218 return CC_AArch64_Win64_VarArg;
7219 }
7220 return CC_AArch64_Win64PCS;
7221 case CallingConv::CFGuard_Check:
7222 if (Subtarget->isWindowsArm64EC())
7223 return CC_AArch64_Arm64EC_CFGuard_Check;
7224 return CC_AArch64_Win64_CFGuard_Check;
7225 case CallingConv::AArch64_VectorCall:
7226 case CallingConv::AArch64_SVE_VectorCall:
7227 case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
7228 case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:
7229 return CC_AArch64_AAPCS;
7230 case CallingConv::ARM64EC_Thunk_X64:
7231 return CC_AArch64_Arm64EC_Thunk;
7232 case CallingConv::ARM64EC_Thunk_Native:
7233 return CC_AArch64_Arm64EC_Thunk_Native;
7234 }
7235 }
7236
7237 CCAssignFn *
7238 AArch64TargetLowering::CCAssignFnForReturn(CallingConv::ID CC) const {
7239 switch (CC) {
7240 default:
7241 return RetCC_AArch64_AAPCS;
7242 case CallingConv::ARM64EC_Thunk_X64:
7243 return RetCC_AArch64_Arm64EC_Thunk;
7244 case CallingConv::CFGuard_Check:
7245 if (Subtarget->isWindowsArm64EC())
7246 return RetCC_AArch64_Arm64EC_CFGuard_Check;
7247 return RetCC_AArch64_AAPCS;
7248 }
7249 }
7250
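/// Returns true if a value of this type is passed in an FPR/SIMD register,
/// i.e. fixed-length vectors and non-scalable floating-point types.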
7251 static bool isPassedInFPR(EVT VT) {
7252 return VT.isFixedLengthVector() ||
7253 (VT.isFloatingPoint() && !VT.isScalableVector());
7254 }
7255
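/// Lowers incoming formal arguments: assigns argument locations, copies
/// register arguments out of their physregs (glued to an SMSTART for locally
/// streaming functions), loads stack-based arguments, and sets up the varargs
/// save areas, the sret return register and the TPIDR2/ZA save buffer where
/// required.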
7256 SDValue AArch64TargetLowering::LowerFormalArguments(
7257 SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
7258 const SmallVectorImpl<ISD::InputArg> &Ins, const SDLoc &DL,
7259 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
7260 MachineFunction &MF = DAG.getMachineFunction();
7261 const Function &F = MF.getFunction();
7262 MachineFrameInfo &MFI = MF.getFrameInfo();
7263 bool IsWin64 =
7264 Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg());
7265 bool StackViaX4 = CallConv == CallingConv::ARM64EC_Thunk_X64 ||
7266 (isVarArg && Subtarget->isWindowsArm64EC());
7267 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
7268
7269 SmallVector<ISD::OutputArg, 4> Outs;
7270 GetReturnInfo(CallConv, F.getReturnType(), F.getAttributes(), Outs,
7271 DAG.getTargetLoweringInfo(), MF.getDataLayout());
7272   if (any_of(Outs, [](ISD::OutputArg &Out) { return Out.VT.isScalableVector(); }))
7273 FuncInfo->setIsSVECC(true);
7274
7275 // Assign locations to all of the incoming arguments.
7276 SmallVector<CCValAssign, 16> ArgLocs;
7277 DenseMap<unsigned, SDValue> CopiedRegs;
7278 CCState CCInfo(CallConv, isVarArg, MF, ArgLocs, *DAG.getContext());
7279
7280 // At this point, Ins[].VT may already be promoted to i32. To correctly
7281 // handle passing i8 as i8 instead of i32 on stack, we pass in both i32 and
7282 // i8 to CC_AArch64_AAPCS with i32 being ValVT and i8 being LocVT.
7283 // Since AnalyzeFormalArguments uses Ins[].VT for both ValVT and LocVT, here
7284 // we use a special version of AnalyzeFormalArguments to pass in ValVT and
7285 // LocVT.
7286 unsigned NumArgs = Ins.size();
7287 Function::const_arg_iterator CurOrigArg = F.arg_begin();
7288 unsigned CurArgIdx = 0;
7289 for (unsigned i = 0; i != NumArgs; ++i) {
7290 MVT ValVT = Ins[i].VT;
7291 if (Ins[i].isOrigArg()) {
7292 std::advance(CurOrigArg, Ins[i].getOrigArgIndex() - CurArgIdx);
7293 CurArgIdx = Ins[i].getOrigArgIndex();
7294
7295 // Get type of the original argument.
7296 EVT ActualVT = getValueType(DAG.getDataLayout(), CurOrigArg->getType(),
7297 /*AllowUnknown*/ true);
7298 MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : MVT::Other;
7299 // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
7300 if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
7301 ValVT = MVT::i8;
7302 else if (ActualMVT == MVT::i16)
7303 ValVT = MVT::i16;
7304 }
7305 bool UseVarArgCC = false;
7306 if (IsWin64)
7307 UseVarArgCC = isVarArg;
7308 CCAssignFn *AssignFn = CCAssignFnForCall(CallConv, UseVarArgCC);
7309 bool Res =
7310 AssignFn(i, ValVT, ValVT, CCValAssign::Full, Ins[i].Flags, CCInfo);
7311 assert(!Res && "Call operand has unhandled type");
7312 (void)Res;
7313 }
7314
7315 SMEAttrs Attrs(MF.getFunction());
7316 bool IsLocallyStreaming =
7317 !Attrs.hasStreamingInterface() && Attrs.hasStreamingBody();
7318 assert(Chain.getOpcode() == ISD::EntryToken && "Unexpected Chain value");
7319 SDValue Glue = Chain.getValue(1);
7320
7321 SmallVector<SDValue, 16> ArgValues;
7322 unsigned ExtraArgLocs = 0;
7323 for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
7324 CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
7325
7326 if (Ins[i].Flags.isByVal()) {
7327 // Byval is used for HFAs in the PCS, but the system should work in a
7328 // non-compliant manner for larger structs.
7329 EVT PtrVT = getPointerTy(DAG.getDataLayout());
7330 int Size = Ins[i].Flags.getByValSize();
7331 unsigned NumRegs = (Size + 7) / 8;
7332
7333       // FIXME: This works on big-endian for composite byvals, which are the
7334       // common case. It should also work for fundamental types.
7335 unsigned FrameIdx =
7336 MFI.CreateFixedObject(8 * NumRegs, VA.getLocMemOffset(), false);
7337 SDValue FrameIdxN = DAG.getFrameIndex(FrameIdx, PtrVT);
7338 InVals.push_back(FrameIdxN);
7339
7340 continue;
7341 }
7342
7343 if (Ins[i].Flags.isSwiftAsync())
7344 MF.getInfo<AArch64FunctionInfo>()->setHasSwiftAsyncContext(true);
7345
7346 SDValue ArgValue;
7347 if (VA.isRegLoc()) {
7348 // Arguments stored in registers.
7349 EVT RegVT = VA.getLocVT();
7350 const TargetRegisterClass *RC;
7351
7352 if (RegVT == MVT::i32)
7353 RC = &AArch64::GPR32RegClass;
7354 else if (RegVT == MVT::i64)
7355 RC = &AArch64::GPR64RegClass;
7356 else if (RegVT == MVT::f16 || RegVT == MVT::bf16)
7357 RC = &AArch64::FPR16RegClass;
7358 else if (RegVT == MVT::f32)
7359 RC = &AArch64::FPR32RegClass;
7360 else if (RegVT == MVT::f64 || RegVT.is64BitVector())
7361 RC = &AArch64::FPR64RegClass;
7362 else if (RegVT == MVT::f128 || RegVT.is128BitVector())
7363 RC = &AArch64::FPR128RegClass;
7364 else if (RegVT.isScalableVector() &&
7365 RegVT.getVectorElementType() == MVT::i1) {
7366 FuncInfo->setIsSVECC(true);
7367 RC = &AArch64::PPRRegClass;
7368 } else if (RegVT == MVT::aarch64svcount) {
7369 FuncInfo->setIsSVECC(true);
7370 RC = &AArch64::PPRRegClass;
7371 } else if (RegVT.isScalableVector()) {
7372 FuncInfo->setIsSVECC(true);
7373 RC = &AArch64::ZPRRegClass;
7374 } else
7375 llvm_unreachable("RegVT not supported by FORMAL_ARGUMENTS Lowering");
7376
7377 // Transform the arguments in physical registers into virtual ones.
7378 Register Reg = MF.addLiveIn(VA.getLocReg(), RC);
7379
7380 if (IsLocallyStreaming) {
7381 // LocallyStreamingFunctions must insert the SMSTART in the correct
7382 // position, so we use Glue to ensure no instructions can be scheduled
7383 // between the chain of:
7384 // t0: ch,glue = EntryNode
7385 // t1: res,ch,glue = CopyFromReg
7386 // ...
7387 // tn: res,ch,glue = CopyFromReg t(n-1), ..
7388 // t(n+1): ch, glue = SMSTART t0:0, ...., tn:2
7389 // ^^^^^^
7390 // This will be the new Chain/Root node.
7391 ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT, Glue);
7392 Glue = ArgValue.getValue(2);
7393 if (isPassedInFPR(ArgValue.getValueType())) {
7394 ArgValue =
7395 DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
7396 DAG.getVTList(ArgValue.getValueType(), MVT::Glue),
7397 {ArgValue, Glue});
7398 Glue = ArgValue.getValue(1);
7399 }
7400 } else
7401 ArgValue = DAG.getCopyFromReg(Chain, DL, Reg, RegVT);
7402
7403 // If this is an 8, 16 or 32-bit value, it is really passed promoted
7404 // to 64 bits. Insert an assert[sz]ext to capture this, then
7405 // truncate to the right size.
7406 switch (VA.getLocInfo()) {
7407 default:
7408 llvm_unreachable("Unknown loc info!");
7409 case CCValAssign::Full:
7410 break;
7411 case CCValAssign::Indirect:
7412 assert(
7413 (VA.getValVT().isScalableVT() || Subtarget->isWindowsArm64EC()) &&
7414 "Indirect arguments should be scalable on most subtargets");
7415 break;
7416 case CCValAssign::BCvt:
7417 ArgValue = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), ArgValue);
7418 break;
7419 case CCValAssign::AExt:
7420 case CCValAssign::SExt:
7421 case CCValAssign::ZExt:
7422 break;
7423 case CCValAssign::AExtUpper:
7424 ArgValue = DAG.getNode(ISD::SRL, DL, RegVT, ArgValue,
7425 DAG.getConstant(32, DL, RegVT));
7426 ArgValue = DAG.getZExtOrTrunc(ArgValue, DL, VA.getValVT());
7427 break;
7428 }
7429 } else { // VA.isRegLoc()
7430 assert(VA.isMemLoc() && "CCValAssign is neither reg nor mem");
7431 unsigned ArgOffset = VA.getLocMemOffset();
7432 unsigned ArgSize = (VA.getLocInfo() == CCValAssign::Indirect
7433 ? VA.getLocVT().getSizeInBits()
7434 : VA.getValVT().getSizeInBits()) / 8;
7435
7436 uint32_t BEAlign = 0;
7437 if (!Subtarget->isLittleEndian() && ArgSize < 8 &&
7438 !Ins[i].Flags.isInConsecutiveRegs())
7439 BEAlign = 8 - ArgSize;
7440
7441 SDValue FIN;
7442 MachinePointerInfo PtrInfo;
7443 if (StackViaX4) {
7444 // In both the ARM64EC varargs convention and the thunk convention,
7445 // arguments on the stack are accessed relative to x4, not sp. In
7446 // the thunk convention, there's an additional offset of 32 bytes
7447 // to account for the shadow store.
7448 unsigned ObjOffset = ArgOffset + BEAlign;
7449 if (CallConv == CallingConv::ARM64EC_Thunk_X64)
7450 ObjOffset += 32;
7451 Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
7452 SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7453 FIN = DAG.getNode(ISD::ADD, DL, MVT::i64, Val,
7454 DAG.getConstant(ObjOffset, DL, MVT::i64));
7455 PtrInfo = MachinePointerInfo::getUnknownStack(MF);
7456 } else {
7457 int FI = MFI.CreateFixedObject(ArgSize, ArgOffset + BEAlign, true);
7458
7459 // Create load nodes to retrieve arguments from the stack.
7460 FIN = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
7461 PtrInfo = MachinePointerInfo::getFixedStack(MF, FI);
7462 }
7463
7464       // For NON_EXTLOAD, the generic code in getLoad asserts that ValVT == MemVT.
7465 ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
7466 MVT MemVT = VA.getValVT();
7467
7468 switch (VA.getLocInfo()) {
7469 default:
7470 break;
7471 case CCValAssign::Trunc:
7472 case CCValAssign::BCvt:
7473 MemVT = VA.getLocVT();
7474 break;
7475 case CCValAssign::Indirect:
7476 assert((VA.getValVT().isScalableVector() ||
7477 Subtarget->isWindowsArm64EC()) &&
7478 "Indirect arguments should be scalable on most subtargets");
7479 MemVT = VA.getLocVT();
7480 break;
7481 case CCValAssign::SExt:
7482 ExtType = ISD::SEXTLOAD;
7483 break;
7484 case CCValAssign::ZExt:
7485 ExtType = ISD::ZEXTLOAD;
7486 break;
7487 case CCValAssign::AExt:
7488 ExtType = ISD::EXTLOAD;
7489 break;
7490 }
7491
7492 ArgValue = DAG.getExtLoad(ExtType, DL, VA.getLocVT(), Chain, FIN, PtrInfo,
7493 MemVT);
7494 }
7495
7496 if (VA.getLocInfo() == CCValAssign::Indirect) {
7497 assert((VA.getValVT().isScalableVT() ||
7498 Subtarget->isWindowsArm64EC()) &&
7499 "Indirect arguments should be scalable on most subtargets");
7500
7501 uint64_t PartSize = VA.getValVT().getStoreSize().getKnownMinValue();
7502 unsigned NumParts = 1;
7503 if (Ins[i].Flags.isInConsecutiveRegs()) {
7504 while (!Ins[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
7505 ++NumParts;
7506 }
7507
7508 MVT PartLoad = VA.getValVT();
7509 SDValue Ptr = ArgValue;
7510
7511 // Ensure we generate all loads for each tuple part, whilst updating the
7512 // pointer after each load correctly using vscale.
7513 while (NumParts > 0) {
7514 ArgValue = DAG.getLoad(PartLoad, DL, Chain, Ptr, MachinePointerInfo());
7515 InVals.push_back(ArgValue);
7516 NumParts--;
7517 if (NumParts > 0) {
7518 SDValue BytesIncrement;
7519 if (PartLoad.isScalableVector()) {
7520 BytesIncrement = DAG.getVScale(
7521 DL, Ptr.getValueType(),
7522 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
7523 } else {
7524 BytesIncrement = DAG.getConstant(
7525 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
7526 Ptr.getValueType());
7527 }
7528 SDNodeFlags Flags;
7529 Flags.setNoUnsignedWrap(true);
7530 Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
7531 BytesIncrement, Flags);
7532 ExtraArgLocs++;
7533 i++;
7534 }
7535 }
7536 } else {
7537 if (Subtarget->isTargetILP32() && Ins[i].Flags.isPointer())
7538 ArgValue = DAG.getNode(ISD::AssertZext, DL, ArgValue.getValueType(),
7539 ArgValue, DAG.getValueType(MVT::i32));
7540
7541 // i1 arguments are zero-extended to i8 by the caller. Emit a
7542 // hint to reflect this.
7543 if (Ins[i].isOrigArg()) {
7544 Argument *OrigArg = F.getArg(Ins[i].getOrigArgIndex());
7545 if (OrigArg->getType()->isIntegerTy(1)) {
7546 if (!Ins[i].Flags.isZExt()) {
7547 ArgValue = DAG.getNode(AArch64ISD::ASSERT_ZEXT_BOOL, DL,
7548 ArgValue.getValueType(), ArgValue);
7549 }
7550 }
7551 }
7552
7553 InVals.push_back(ArgValue);
7554 }
7555 }
7556 assert((ArgLocs.size() + ExtraArgLocs) == Ins.size());
7557
7558 // Insert the SMSTART if this is a locally streaming function and
7559 // make sure it is Glued to the last CopyFromReg value.
7560 if (IsLocallyStreaming) {
7561 SDValue PStateSM;
7562 if (Attrs.hasStreamingCompatibleInterface()) {
7563 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
7564 Register Reg = MF.getRegInfo().createVirtualRegister(
7565 getRegClassFor(PStateSM.getValueType().getSimpleVT()));
7566 FuncInfo->setPStateSMReg(Reg);
7567 Chain = DAG.getCopyToReg(Chain, DL, Reg, PStateSM);
7568 Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
7569 AArch64SME::IfCallerIsNonStreaming, PStateSM);
7570 } else
7571 Chain = changeStreamingMode(DAG, DL, /*Enable*/ true, Chain, Glue,
7572 AArch64SME::Always);
7573
7574     // Route each value through a fresh virtual register so that the copies are
7575     // chained after the SMSTART and its chain result is used.
7576     for (unsigned I = 0; I < InVals.size(); ++I) {
7577 Register Reg = MF.getRegInfo().createVirtualRegister(
7578 getRegClassFor(InVals[I].getValueType().getSimpleVT()));
7579 Chain = DAG.getCopyToReg(Chain, DL, Reg, InVals[I]);
7580 InVals[I] = DAG.getCopyFromReg(Chain, DL, Reg,
7581 InVals[I].getValueType());
7582 }
7583 }
7584
7585 // varargs
7586 if (isVarArg) {
7587 if (!Subtarget->isTargetDarwin() || IsWin64) {
7588 // The AAPCS variadic function ABI is identical to the non-variadic
7589 // one. As a result there may be more arguments in registers and we should
7590 // save them for future reference.
7591 // Win64 variadic functions also pass arguments in registers, but all float
7592 // arguments are passed in integer registers.
7593 saveVarArgRegisters(CCInfo, DAG, DL, Chain);
7594 }
7595
7596 // This will point to the next argument passed via stack.
7597 unsigned VarArgsOffset = CCInfo.getStackSize();
7598 // We currently pass all varargs at 8-byte alignment, or 4 for ILP32
7599 VarArgsOffset = alignTo(VarArgsOffset, Subtarget->isTargetILP32() ? 4 : 8);
7600 FuncInfo->setVarArgsStackOffset(VarArgsOffset);
7601 FuncInfo->setVarArgsStackIndex(
7602 MFI.CreateFixedObject(4, VarArgsOffset, true));
7603
7604 if (MFI.hasMustTailInVarArgFunc()) {
7605 SmallVector<MVT, 2> RegParmTypes;
7606 RegParmTypes.push_back(MVT::i64);
7607 RegParmTypes.push_back(MVT::f128);
7608 // Compute the set of forwarded registers. The rest are scratch.
7609 SmallVectorImpl<ForwardedRegister> &Forwards =
7610 FuncInfo->getForwardedMustTailRegParms();
7611 CCInfo.analyzeMustTailForwardedRegisters(Forwards, RegParmTypes,
7612 CC_AArch64_AAPCS);
7613
7614 // Conservatively forward X8, since it might be used for aggregate return.
7615 if (!CCInfo.isAllocated(AArch64::X8)) {
7616 Register X8VReg = MF.addLiveIn(AArch64::X8, &AArch64::GPR64RegClass);
7617 Forwards.push_back(ForwardedRegister(X8VReg, AArch64::X8, MVT::i64));
7618 }
7619 }
7620 }
7621
7622 // On Windows, InReg pointers must be returned, so record the pointer in a
7623 // virtual register at the start of the function so it can be returned in the
7624 // epilogue.
7625 if (IsWin64 || F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64) {
7626 for (unsigned I = 0, E = Ins.size(); I != E; ++I) {
7627 if ((F.getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
7628 Ins[I].Flags.isInReg()) &&
7629 Ins[I].Flags.isSRet()) {
7630 assert(!FuncInfo->getSRetReturnReg());
7631
7632 MVT PtrTy = getPointerTy(DAG.getDataLayout());
7633 Register Reg =
7634 MF.getRegInfo().createVirtualRegister(getRegClassFor(PtrTy));
7635 FuncInfo->setSRetReturnReg(Reg);
7636
7637 SDValue Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, Reg, InVals[I]);
7638 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, Copy, Chain);
7639 break;
7640 }
7641 }
7642 }
7643
7644 unsigned StackArgSize = CCInfo.getStackSize();
7645 bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
7646 if (DoesCalleeRestoreStack(CallConv, TailCallOpt)) {
7647 // This is a non-standard ABI so by fiat I say we're allowed to make full
7648 // use of the stack area to be popped, which must be aligned to 16 bytes in
7649 // any case:
7650 StackArgSize = alignTo(StackArgSize, 16);
7651
7652 // If we're expected to restore the stack (e.g. fastcc) then we'll be adding
7653 // a multiple of 16.
7654 FuncInfo->setArgumentStackToRestore(StackArgSize);
7655
7656 // This realignment carries over to the available bytes below. Our own
7657 // callers will guarantee the space is free by giving an aligned value to
7658 // CALLSEQ_START.
7659 }
7660 // Even if we're not expected to free up the space, it's useful to know how
7661 // much is there while considering tail calls (because we can reuse it).
7662 FuncInfo->setBytesInStackArgArea(StackArgSize);
7663
7664 if (Subtarget->hasCustomCallingConv())
7665 Subtarget->getRegisterInfo()->UpdateCustomCalleeSavedRegs(MF);
7666
7667   // Create a 16-byte TPIDR2 object. The dynamic buffer
7668   // will be expanded and stored in the static object later using a pseudo node.
7669 if (SMEAttrs(MF.getFunction()).hasZAState()) {
7670 TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
7671 TPIDR2.FrameIndex = MFI.CreateStackObject(16, Align(16), false);
7672 SDValue SVL = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
7673 DAG.getConstant(1, DL, MVT::i32));
7674
7675 SDValue Buffer;
7676 if (!Subtarget->isTargetWindows() && !hasInlineStackProbe(MF)) {
7677 Buffer = DAG.getNode(AArch64ISD::ALLOCATE_ZA_BUFFER, DL,
7678 DAG.getVTList(MVT::i64, MVT::Other), {Chain, SVL});
7679 } else {
7680 SDValue Size = DAG.getNode(ISD::MUL, DL, MVT::i64, SVL, SVL);
7681 Buffer = DAG.getNode(ISD::DYNAMIC_STACKALLOC, DL,
7682 DAG.getVTList(MVT::i64, MVT::Other),
7683 {Chain, Size, DAG.getConstant(1, DL, MVT::i64)});
7684 MFI.CreateVariableSizedObject(Align(16), nullptr);
7685 }
7686 Chain = DAG.getNode(
7687 AArch64ISD::INIT_TPIDR2OBJ, DL, DAG.getVTList(MVT::Other),
7688 {/*Chain*/ Buffer.getValue(1), /*Buffer ptr*/ Buffer.getValue(0)});
7689 }
7690
7691 if (CallConv == CallingConv::PreserveNone) {
7692 for (const ISD::InputArg &I : Ins) {
7693 if (I.Flags.isSwiftSelf() || I.Flags.isSwiftError() ||
7694 I.Flags.isSwiftAsync()) {
7695 MachineFunction &MF = DAG.getMachineFunction();
7696 DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
7697 MF.getFunction(),
7698 "Swift attributes can't be used with preserve_none",
7699 DL.getDebugLoc()));
7700 break;
7701 }
7702 }
7703 }
7704
7705 return Chain;
7706 }
7707
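/// Saves the unallocated GPR (and, where available, FPR) argument registers
/// of a variadic function to a save area on the stack so that va_arg can
/// reach them, recording the save-area indices and sizes in
/// AArch64FunctionInfo.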
7708 void AArch64TargetLowering::saveVarArgRegisters(CCState &CCInfo,
7709 SelectionDAG &DAG,
7710 const SDLoc &DL,
7711 SDValue &Chain) const {
7712 MachineFunction &MF = DAG.getMachineFunction();
7713 MachineFrameInfo &MFI = MF.getFrameInfo();
7714 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
7715 auto PtrVT = getPointerTy(DAG.getDataLayout());
7716 Function &F = MF.getFunction();
7717 bool IsWin64 =
7718 Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg());
7719
7720 SmallVector<SDValue, 8> MemOps;
7721
7722 auto GPRArgRegs = AArch64::getGPRArgRegs();
7723 unsigned NumGPRArgRegs = GPRArgRegs.size();
7724 if (Subtarget->isWindowsArm64EC()) {
7725 // In the ARM64EC ABI, only x0-x3 are used to pass arguments to varargs
7726 // functions.
7727 NumGPRArgRegs = 4;
7728 }
7729 unsigned FirstVariadicGPR = CCInfo.getFirstUnallocated(GPRArgRegs);
7730
7731 unsigned GPRSaveSize = 8 * (NumGPRArgRegs - FirstVariadicGPR);
7732 int GPRIdx = 0;
7733 if (GPRSaveSize != 0) {
7734 if (IsWin64) {
7735 GPRIdx = MFI.CreateFixedObject(GPRSaveSize, -(int)GPRSaveSize, false);
7736 if (GPRSaveSize & 15)
7737 // The extra size here, if triggered, will always be 8.
7738 MFI.CreateFixedObject(16 - (GPRSaveSize & 15), -(int)alignTo(GPRSaveSize, 16), false);
7739 } else
7740 GPRIdx = MFI.CreateStackObject(GPRSaveSize, Align(8), false);
7741
7742 SDValue FIN;
7743 if (Subtarget->isWindowsArm64EC()) {
7744 // With the Arm64EC ABI, we reserve the save area as usual, but we
7745 // compute its address relative to x4. For a normal AArch64->AArch64
7746 // call, x4 == sp on entry, but calls from an entry thunk can pass in a
7747 // different address.
7748 Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
7749 SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7750 FIN = DAG.getNode(ISD::SUB, DL, MVT::i64, Val,
7751 DAG.getConstant(GPRSaveSize, DL, MVT::i64));
7752 } else {
7753 FIN = DAG.getFrameIndex(GPRIdx, PtrVT);
7754 }
7755
7756 for (unsigned i = FirstVariadicGPR; i < NumGPRArgRegs; ++i) {
7757 Register VReg = MF.addLiveIn(GPRArgRegs[i], &AArch64::GPR64RegClass);
7758 SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::i64);
7759 SDValue Store =
7760 DAG.getStore(Val.getValue(1), DL, Val, FIN,
7761 IsWin64 ? MachinePointerInfo::getFixedStack(
7762 MF, GPRIdx, (i - FirstVariadicGPR) * 8)
7763 : MachinePointerInfo::getStack(MF, i * 8));
7764 MemOps.push_back(Store);
7765 FIN =
7766 DAG.getNode(ISD::ADD, DL, PtrVT, FIN, DAG.getConstant(8, DL, PtrVT));
7767 }
7768 }
7769 FuncInfo->setVarArgsGPRIndex(GPRIdx);
7770 FuncInfo->setVarArgsGPRSize(GPRSaveSize);
7771
7772 if (Subtarget->hasFPARMv8() && !IsWin64) {
7773 auto FPRArgRegs = AArch64::getFPRArgRegs();
7774 const unsigned NumFPRArgRegs = FPRArgRegs.size();
7775 unsigned FirstVariadicFPR = CCInfo.getFirstUnallocated(FPRArgRegs);
7776
7777 unsigned FPRSaveSize = 16 * (NumFPRArgRegs - FirstVariadicFPR);
7778 int FPRIdx = 0;
7779 if (FPRSaveSize != 0) {
7780 FPRIdx = MFI.CreateStackObject(FPRSaveSize, Align(16), false);
7781
7782 SDValue FIN = DAG.getFrameIndex(FPRIdx, PtrVT);
7783
7784 for (unsigned i = FirstVariadicFPR; i < NumFPRArgRegs; ++i) {
7785 Register VReg = MF.addLiveIn(FPRArgRegs[i], &AArch64::FPR128RegClass);
7786 SDValue Val = DAG.getCopyFromReg(Chain, DL, VReg, MVT::f128);
7787
7788 SDValue Store = DAG.getStore(Val.getValue(1), DL, Val, FIN,
7789 MachinePointerInfo::getStack(MF, i * 16));
7790 MemOps.push_back(Store);
7791 FIN = DAG.getNode(ISD::ADD, DL, PtrVT, FIN,
7792 DAG.getConstant(16, DL, PtrVT));
7793 }
7794 }
7795 FuncInfo->setVarArgsFPRIndex(FPRIdx);
7796 FuncInfo->setVarArgsFPRSize(FPRSaveSize);
7797 }
7798
7799 if (!MemOps.empty()) {
7800 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
7801 }
7802 }
7803
7804 /// LowerCallResult - Lower the result values of a call into the
7805 /// appropriate copies out of appropriate physical registers.
7806 SDValue AArch64TargetLowering::LowerCallResult(
7807 SDValue Chain, SDValue InGlue, CallingConv::ID CallConv, bool isVarArg,
7808 const SmallVectorImpl<CCValAssign> &RVLocs, const SDLoc &DL,
7809 SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals, bool isThisReturn,
7810 SDValue ThisVal, bool RequiresSMChange) const {
7811 DenseMap<unsigned, SDValue> CopiedRegs;
7812 // Copy all of the result registers out of their specified physreg.
7813 for (unsigned i = 0; i != RVLocs.size(); ++i) {
7814 CCValAssign VA = RVLocs[i];
7815
7816 // Pass 'this' value directly from the argument to return value, to avoid
7817 // reg unit interference
7818 if (i == 0 && isThisReturn) {
7819 assert(!VA.needsCustom() && VA.getLocVT() == MVT::i64 &&
7820 "unexpected return calling convention register assignment");
7821 InVals.push_back(ThisVal);
7822 continue;
7823 }
7824
7825 // Avoid copying a physreg twice since RegAllocFast is incompetent and only
7826 // allows one use of a physreg per block.
7827 SDValue Val = CopiedRegs.lookup(VA.getLocReg());
7828 if (!Val) {
7829 Val =
7830 DAG.getCopyFromReg(Chain, DL, VA.getLocReg(), VA.getLocVT(), InGlue);
7831 Chain = Val.getValue(1);
7832 InGlue = Val.getValue(2);
7833 CopiedRegs[VA.getLocReg()] = Val;
7834 }
7835
7836 switch (VA.getLocInfo()) {
7837 default:
7838 llvm_unreachable("Unknown loc info!");
7839 case CCValAssign::Full:
7840 break;
7841 case CCValAssign::BCvt:
7842 Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
7843 break;
7844 case CCValAssign::AExtUpper:
7845 Val = DAG.getNode(ISD::SRL, DL, VA.getLocVT(), Val,
7846 DAG.getConstant(32, DL, VA.getLocVT()));
7847 [[fallthrough]];
7848 case CCValAssign::AExt:
7849 [[fallthrough]];
7850 case CCValAssign::ZExt:
7851 Val = DAG.getZExtOrTrunc(Val, DL, VA.getValVT());
7852 break;
7853 }
7854
7855 if (RequiresSMChange && isPassedInFPR(VA.getValVT()))
7856 Val = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL, Val.getValueType(),
7857 Val);
7858
7859 InVals.push_back(Val);
7860 }
7861
7862 return Chain;
7863 }
7864
7865 /// Return true if the calling convention is one that we can guarantee TCO for.
7866 static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) {
7867 return (CC == CallingConv::Fast && GuaranteeTailCalls) ||
7868 CC == CallingConv::Tail || CC == CallingConv::SwiftTail;
7869 }
7870
7871 /// Return true if we might ever do TCO for calls with this calling convention.
7872 static bool mayTailCallThisCC(CallingConv::ID CC) {
7873 switch (CC) {
7874 case CallingConv::C:
7875 case CallingConv::AArch64_SVE_VectorCall:
7876 case CallingConv::PreserveMost:
7877 case CallingConv::PreserveAll:
7878 case CallingConv::PreserveNone:
7879 case CallingConv::Swift:
7880 case CallingConv::SwiftTail:
7881 case CallingConv::Tail:
7882 case CallingConv::Fast:
7883 return true;
7884 default:
7885 return false;
7886 }
7887 }
7888
7889 /// Return true if the calling convention supports varargs.
7890 /// Currently only conventions that pass varargs like the C calling
7891 /// convention does are eligible.
7892 /// Calling conventions listed in this function must also be properly
7893 /// handled in AArch64Subtarget::isCallingConvWin64.
7894 static bool callConvSupportsVarArgs(CallingConv::ID CC) {
7895 switch (CC) {
7896 case CallingConv::C:
7897 case CallingConv::PreserveNone:
7898 return true;
7899 default:
7900 return false;
7901 }
7902 }
7903
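/// Assigns locations to the outgoing call operands, narrowing small integer
/// arguments back to their original width and selecting the vararg calling
/// convention where required (Win64 varargs, or non-fixed arguments).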
7904 static void analyzeCallOperands(const AArch64TargetLowering &TLI,
7905 const AArch64Subtarget *Subtarget,
7906 const TargetLowering::CallLoweringInfo &CLI,
7907 CCState &CCInfo) {
7908 const SelectionDAG &DAG = CLI.DAG;
7909 CallingConv::ID CalleeCC = CLI.CallConv;
7910 bool IsVarArg = CLI.IsVarArg;
7911 const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
7912 bool IsCalleeWin64 = Subtarget->isCallingConvWin64(CalleeCC, IsVarArg);
7913
7914 // For Arm64EC thunks, allocate 32 extra bytes at the bottom of the stack
7915 // for the shadow store.
7916 if (CalleeCC == CallingConv::ARM64EC_Thunk_X64)
7917 CCInfo.AllocateStack(32, Align(16));
7918
7919 unsigned NumArgs = Outs.size();
7920 for (unsigned i = 0; i != NumArgs; ++i) {
7921 MVT ArgVT = Outs[i].VT;
7922 ISD::ArgFlagsTy ArgFlags = Outs[i].Flags;
7923
7924 bool UseVarArgCC = false;
7925 if (IsVarArg) {
7926 // On Windows, the fixed arguments in a vararg call are passed in GPRs
7927 // too, so use the vararg CC to force them to integer registers.
7928 if (IsCalleeWin64) {
7929 UseVarArgCC = true;
7930 } else {
7931 UseVarArgCC = !Outs[i].IsFixed;
7932 }
7933 }
7934
7935 if (!UseVarArgCC) {
7936 // Get type of the original argument.
7937 EVT ActualVT =
7938 TLI.getValueType(DAG.getDataLayout(), CLI.Args[Outs[i].OrigArgIndex].Ty,
7939 /*AllowUnknown*/ true);
7940 MVT ActualMVT = ActualVT.isSimple() ? ActualVT.getSimpleVT() : ArgVT;
7941 // If ActualMVT is i1/i8/i16, we should set LocVT to i8/i8/i16.
7942 if (ActualMVT == MVT::i1 || ActualMVT == MVT::i8)
7943 ArgVT = MVT::i8;
7944 else if (ActualMVT == MVT::i16)
7945 ArgVT = MVT::i16;
7946 }
7947
7948 CCAssignFn *AssignFn = TLI.CCAssignFnForCall(CalleeCC, UseVarArgCC);
7949 bool Res = AssignFn(i, ArgVT, ArgVT, CCValAssign::Full, ArgFlags, CCInfo);
7950 assert(!Res && "Call operand has unhandled type");
7951 (void)Res;
7952 }
7953 }
7954
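/// Determines whether this call can be lowered as a tail call, checking that
/// the calling conventions and preserved registers are compatible, that no
/// SME streaming-mode or ZA change is required, and that the outgoing stack
/// arguments fit within the caller's own argument area.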
7955 bool AArch64TargetLowering::isEligibleForTailCallOptimization(
7956 const CallLoweringInfo &CLI) const {
7957 CallingConv::ID CalleeCC = CLI.CallConv;
7958 if (!mayTailCallThisCC(CalleeCC))
7959 return false;
7960
7961 SDValue Callee = CLI.Callee;
7962 bool IsVarArg = CLI.IsVarArg;
7963 const SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
7964 const SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
7965 const SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
7966 const SelectionDAG &DAG = CLI.DAG;
7967 MachineFunction &MF = DAG.getMachineFunction();
7968 const Function &CallerF = MF.getFunction();
7969 CallingConv::ID CallerCC = CallerF.getCallingConv();
7970
7971 // SME Streaming functions are not eligible for TCO as they may require
7972 // the streaming mode or ZA to be restored after returning from the call.
7973 SMEAttrs CallerAttrs(MF.getFunction());
7974 auto CalleeAttrs = CLI.CB ? SMEAttrs(*CLI.CB) : SMEAttrs(SMEAttrs::Normal);
7975 if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
7976 CallerAttrs.requiresLazySave(CalleeAttrs) ||
7977 CallerAttrs.hasStreamingBody())
7978 return false;
7979
7980 // Functions using the C or Fast calling convention that have an SVE signature
7981 // preserve more registers and should assume the SVE_VectorCall CC.
7982 // The check for matching callee-saved regs will determine whether it is
7983 // eligible for TCO.
7984 if ((CallerCC == CallingConv::C || CallerCC == CallingConv::Fast) &&
7985 MF.getInfo<AArch64FunctionInfo>()->isSVECC())
7986 CallerCC = CallingConv::AArch64_SVE_VectorCall;
7987
7988 bool CCMatch = CallerCC == CalleeCC;
7989
7990 // When using the Windows calling convention on a non-windows OS, we want
7991 // to back up and restore X18 in such functions; we can't do a tail call
7992 // from those functions.
7993 if (CallerCC == CallingConv::Win64 && !Subtarget->isTargetWindows() &&
7994 CalleeCC != CallingConv::Win64)
7995 return false;
7996
7997 // Byval parameters hand the function a pointer directly into the stack area
7998 // we want to reuse during a tail call. Working around this *is* possible (see
7999 // X86) but less efficient and uglier in LowerCall.
8000 for (Function::const_arg_iterator i = CallerF.arg_begin(),
8001 e = CallerF.arg_end();
8002 i != e; ++i) {
8003 if (i->hasByValAttr())
8004 return false;
8005
8006 // On Windows, "inreg" attributes signify non-aggregate indirect returns.
8007 // In this case, it is necessary to save/restore X0 in the callee. Tail
8008 // call opt interferes with this. So we disable tail call opt when the
8009 // caller has an argument with "inreg" attribute.
8010
8011 // FIXME: Check whether the callee also has an "inreg" argument.
8012 if (i->hasInRegAttr())
8013 return false;
8014 }
8015
8016 if (canGuaranteeTCO(CalleeCC, getTargetMachine().Options.GuaranteedTailCallOpt))
8017 return CCMatch;
8018
8019 // Externally-defined functions with weak linkage should not be
8020 // tail-called on AArch64 when the OS does not support dynamic
8021 // pre-emption of symbols, as the AAELF spec requires normal calls
8022 // to undefined weak functions to be replaced with a NOP or jump to the
8023 // next instruction. The behaviour of branch instructions in this
8024 // situation (as used for tail calls) is implementation-defined, so we
8025 // cannot rely on the linker replacing the tail call with a return.
8026 if (GlobalAddressSDNode *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
8027 const GlobalValue *GV = G->getGlobal();
8028 const Triple &TT = getTargetMachine().getTargetTriple();
8029 if (GV->hasExternalWeakLinkage() &&
8030 (!TT.isOSWindows() || TT.isOSBinFormatELF() || TT.isOSBinFormatMachO()))
8031 return false;
8032 }
8033
8034 // Now we search for cases where we can use a tail call without changing the
8035 // ABI. Sibcall is used in some places (particularly gcc) to refer to this
8036 // concept.
8037
8038 // I want anyone implementing a new calling convention to think long and hard
8039   // about this check.
8040 if (IsVarArg && !callConvSupportsVarArgs(CalleeCC))
8041 report_fatal_error("Unsupported variadic calling convention");
8042
8043 LLVMContext &C = *DAG.getContext();
8044 // Check that the call results are passed in the same way.
8045 if (!CCState::resultsCompatible(CalleeCC, CallerCC, MF, C, Ins,
8046 CCAssignFnForCall(CalleeCC, IsVarArg),
8047 CCAssignFnForCall(CallerCC, IsVarArg)))
8048 return false;
8049 // The callee has to preserve all registers the caller needs to preserve.
8050 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8051 const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
8052 if (!CCMatch) {
8053 const uint32_t *CalleePreserved = TRI->getCallPreservedMask(MF, CalleeCC);
8054 if (Subtarget->hasCustomCallingConv()) {
8055 TRI->UpdateCustomCallPreservedMask(MF, &CallerPreserved);
8056 TRI->UpdateCustomCallPreservedMask(MF, &CalleePreserved);
8057 }
8058 if (!TRI->regmaskSubsetEqual(CallerPreserved, CalleePreserved))
8059 return false;
8060 }
8061
8062 // Nothing more to check if the callee is taking no arguments
8063 if (Outs.empty())
8064 return true;
8065
8066 SmallVector<CCValAssign, 16> ArgLocs;
8067 CCState CCInfo(CalleeCC, IsVarArg, MF, ArgLocs, C);
8068
8069 analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
8070
8071 if (IsVarArg && !(CLI.CB && CLI.CB->isMustTailCall())) {
8072     // When the call is musttail, additional checks have already been done, so
8073     // this check can be skipped. Otherwise: if the caller is fastcc we can't
8074     // have any memory arguments (we'd be expected to clean up the stack
8075     // afterwards); if the caller is C we could potentially use its argument area.
8076
8077 // FIXME: for now we take the most conservative of these in both cases:
8078 // disallow all variadic memory operands.
8079 for (const CCValAssign &ArgLoc : ArgLocs)
8080 if (!ArgLoc.isRegLoc())
8081 return false;
8082 }
8083
8084 const AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8085
8086   // If any of the arguments is passed indirectly, it must be SVE, so
8087   // 'getBytesInStackArgArea' is not sufficient to determine whether we need to
8088   // allocate space on the stack. That is why we check this explicitly here:
8089   // if it is the case, the call cannot be a tail call.
8090 if (llvm::any_of(ArgLocs, [&](CCValAssign &A) {
8091 assert((A.getLocInfo() != CCValAssign::Indirect ||
8092 A.getValVT().isScalableVector() ||
8093 Subtarget->isWindowsArm64EC()) &&
8094 "Expected value to be scalable");
8095 return A.getLocInfo() == CCValAssign::Indirect;
8096 }))
8097 return false;
8098
8099 // If the stack arguments for this call do not fit into our own save area then
8100 // the call cannot be made tail.
8101 if (CCInfo.getStackSize() > FuncInfo->getBytesInStackArgArea())
8102 return false;
8103
8104 const MachineRegisterInfo &MRI = MF.getRegInfo();
8105 if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
8106 return false;
8107
8108 return true;
8109 }
8110
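/// Builds a TokenFactor combining the chain with any loads from the caller's
/// fixed stack objects that overlap ClobberedFI, so that an outgoing-argument
/// store to that slot cannot be reordered before those loads.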
8111 SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain,
8112 SelectionDAG &DAG,
8113 MachineFrameInfo &MFI,
8114 int ClobberedFI) const {
8115 SmallVector<SDValue, 8> ArgChains;
8116 int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
8117 int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;
8118
8119 // Include the original chain at the beginning of the list. When this is
8120 // used by target LowerCall hooks, this helps legalize find the
8121 // CALLSEQ_BEGIN node.
8122 ArgChains.push_back(Chain);
8123
8124 // Add a chain value for each stack argument corresponding
8125 for (SDNode *U : DAG.getEntryNode().getNode()->uses())
8126 if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
8127 if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
8128 if (FI->getIndex() < 0) {
8129 int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
8130 int64_t InLastByte = InFirstByte;
8131 InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;
8132
8133 if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
8134 (FirstByte <= InFirstByte && InFirstByte <= LastByte))
8135 ArgChains.push_back(SDValue(L, 1));
8136 }
8137
8138 // Build a tokenfactor for all the chains.
8139 return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
8140 }
8141
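/// Returns true if the callee is responsible for popping its own stack
/// arguments: fastcc when GuaranteedTailCallOpt is enabled, plus the tail and
/// swifttail conventions.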
8142 bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
8143 bool TailCallOpt) const {
8144 return (CallCC == CallingConv::Fast && TailCallOpt) ||
8145 CallCC == CallingConv::Tail || CallCC == CallingConv::SwiftTail;
8146 }
8147
8148 // Check if the value is zero-extended from i1 to i8
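// i.e. bits 1-7 of the low byte are known to be zero (0xFE is their mask).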
8149 static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
8150 unsigned SizeInBits = Arg.getValueType().getSizeInBits();
8151 if (SizeInBits < 8)
8152 return false;
8153
8154   APInt RequiredZero(SizeInBits, 0xFE);
8155   KnownBits Bits = DAG.computeKnownBits(Arg, 4);
8156   bool ZExtBool = (Bits.Zero & RequiredZero) == RequiredZero;
8157 return ZExtBool;
8158 }
8159
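/// Post-isel fixups for SME: removes the fake GPR implicit-defs attached to
/// the SMSTART/SMSTOP pseudos and adds implicit VG operands where a
/// streaming-mode change or a scalable frame index makes the vector length
/// relevant.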
8160 void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
8161 SDNode *Node) const {
8162 // Live-in physreg copies that are glued to SMSTART are applied as
8163 // implicit-def's in the InstrEmitter. Here we remove them, allowing the
8164 // register allocator to pass call args in callee saved regs, without extra
8165 // copies to avoid these fake clobbers of actually-preserved GPRs.
8166 if (MI.getOpcode() == AArch64::MSRpstatesvcrImm1 ||
8167 MI.getOpcode() == AArch64::MSRpstatePseudo) {
8168 for (unsigned I = MI.getNumOperands() - 1; I > 0; --I)
8169 if (MachineOperand &MO = MI.getOperand(I);
8170 MO.isReg() && MO.isImplicit() && MO.isDef() &&
8171 (AArch64::GPR32RegClass.contains(MO.getReg()) ||
8172 AArch64::GPR64RegClass.contains(MO.getReg())))
8173 MI.removeOperand(I);
8174
8175 // The SVE vector length can change when entering/leaving streaming mode.
8176 if (MI.getOperand(0).getImm() == AArch64SVCR::SVCRSM ||
8177 MI.getOperand(0).getImm() == AArch64SVCR::SVCRSMZA) {
8178 MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/false,
8179 /*IsImplicit=*/true));
8180 MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/true,
8181 /*IsImplicit=*/true));
8182 }
8183 }
8184
8185 // Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
8186 // have nothing to do with VG, were it not that they are used to materialise a
8187 // frame-address. If they contain a frame-index to a scalable vector, this
8188 // will likely require an ADDVL instruction to materialise the address, thus
8189 // reading VG.
8190 const MachineFunction &MF = *MI.getMF();
8191 if (MF.getInfo<AArch64FunctionInfo>()->hasStreamingModeChanges() &&
8192 (MI.getOpcode() == AArch64::ADDXri ||
8193 MI.getOpcode() == AArch64::SUBXri)) {
8194 const MachineOperand &MO = MI.getOperand(1);
8195 if (MO.isFI() && MF.getFrameInfo().getStackID(MO.getIndex()) ==
8196 TargetStackID::ScalableVector)
8197 MI.addOperand(MachineOperand::CreateReg(AArch64::VG, /*IsDef=*/false,
8198 /*IsImplicit=*/true));
8199 }
8200 }
8201
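/// Emits an SMSTART or SMSTOP node (depending on Enable) that switches the
/// PSTATE.SM streaming mode, optionally predicated on the caller's current
/// streaming state, and records that this function changes streaming mode.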
8202 SDValue AArch64TargetLowering::changeStreamingMode(SelectionDAG &DAG, SDLoc DL,
8203 bool Enable, SDValue Chain,
8204 SDValue InGlue,
8205 unsigned Condition,
8206 SDValue PStateSM) const {
8207 MachineFunction &MF = DAG.getMachineFunction();
8208 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8209 FuncInfo->setHasStreamingModeChanges(true);
8210
8211 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8212 SDValue RegMask = DAG.getRegisterMask(TRI->getSMStartStopCallPreservedMask());
8213 SDValue MSROp =
8214 DAG.getTargetConstant((int32_t)AArch64SVCR::SVCRSM, DL, MVT::i32);
8215 SDValue ConditionOp = DAG.getTargetConstant(Condition, DL, MVT::i64);
8216 SmallVector<SDValue> Ops = {Chain, MSROp, ConditionOp};
8217 if (Condition != AArch64SME::Always) {
8218 assert(PStateSM && "PStateSM should be defined");
8219 Ops.push_back(PStateSM);
8220 }
8221 Ops.push_back(RegMask);
8222
8223 if (InGlue)
8224 Ops.push_back(InGlue);
8225
8226 unsigned Opcode = Enable ? AArch64ISD::SMSTART : AArch64ISD::SMSTOP;
8227 return DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
8228 }
8229
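/// Decide whether the smstart/smstop emitted around a call is unconditional,
/// or predicated on the caller's streaming mode at run time. The latter is
/// only possible for streaming-compatible callers without a streaming body.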
8230 static unsigned getSMCondition(const SMEAttrs &CallerAttrs,
8231 const SMEAttrs &CalleeAttrs) {
8232 if (!CallerAttrs.hasStreamingCompatibleInterface() ||
8233 CallerAttrs.hasStreamingBody())
8234 return AArch64SME::Always;
8235 if (CalleeAttrs.hasNonStreamingInterface())
8236 return AArch64SME::IfCallerIsStreaming;
8237 if (CalleeAttrs.hasStreamingInterface())
8238 return AArch64SME::IfCallerIsNonStreaming;
8239
8240 llvm_unreachable("Unsupported attributes");
8241 }
8242
8243 /// LowerCall - Lower a call to a callseq_start + CALL + callseq_end chain,
8244 /// and add input and output parameter nodes.
8245 SDValue
8246 AArch64TargetLowering::LowerCall(CallLoweringInfo &CLI,
8247 SmallVectorImpl<SDValue> &InVals) const {
8248 SelectionDAG &DAG = CLI.DAG;
8249 SDLoc &DL = CLI.DL;
8250 SmallVector<ISD::OutputArg, 32> &Outs = CLI.Outs;
8251 SmallVector<SDValue, 32> &OutVals = CLI.OutVals;
8252 SmallVector<ISD::InputArg, 32> &Ins = CLI.Ins;
8253 SDValue Chain = CLI.Chain;
8254 SDValue Callee = CLI.Callee;
8255 bool &IsTailCall = CLI.IsTailCall;
8256 CallingConv::ID &CallConv = CLI.CallConv;
8257 bool IsVarArg = CLI.IsVarArg;
8258
8259 MachineFunction &MF = DAG.getMachineFunction();
8260 MachineFunction::CallSiteInfo CSInfo;
8261 bool IsThisReturn = false;
8262
8263 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
8264 bool TailCallOpt = MF.getTarget().Options.GuaranteedTailCallOpt;
8265 bool IsCFICall = CLI.CB && CLI.CB->isIndirectCall() && CLI.CFIType;
8266 bool IsSibCall = false;
8267 bool GuardWithBTI = false;
8268
8269 if (CLI.CB && CLI.CB->hasFnAttr(Attribute::ReturnsTwice) &&
8270 !Subtarget->noBTIAtReturnTwice()) {
8271 GuardWithBTI = FuncInfo->branchTargetEnforcement();
8272 }
8273
8274 // Analyze operands of the call, assigning locations to each operand.
8275 SmallVector<CCValAssign, 16> ArgLocs;
8276 CCState CCInfo(CallConv, IsVarArg, MF, ArgLocs, *DAG.getContext());
8277
8278 if (IsVarArg) {
8279 unsigned NumArgs = Outs.size();
8280
8281 for (unsigned i = 0; i != NumArgs; ++i) {
8282 if (!Outs[i].IsFixed && Outs[i].VT.isScalableVector())
8283 report_fatal_error("Passing SVE types to variadic functions is "
8284 "currently not supported");
8285 }
8286 }
8287
8288 analyzeCallOperands(*this, Subtarget, CLI, CCInfo);
8289
8290 CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
8291 // Assign locations to each value returned by this call.
8292 SmallVector<CCValAssign, 16> RVLocs;
8293 CCState RetCCInfo(CallConv, IsVarArg, DAG.getMachineFunction(), RVLocs,
8294 *DAG.getContext());
8295 RetCCInfo.AnalyzeCallResult(Ins, RetCC);
8296
8297 // Check callee args/returns for SVE registers and set calling convention
8298 // accordingly.
8299 if (CallConv == CallingConv::C || CallConv == CallingConv::Fast) {
8300 auto HasSVERegLoc = [](CCValAssign &Loc) {
8301 if (!Loc.isRegLoc())
8302 return false;
8303 return AArch64::ZPRRegClass.contains(Loc.getLocReg()) ||
8304 AArch64::PPRRegClass.contains(Loc.getLocReg());
8305 };
8306 if (any_of(RVLocs, HasSVERegLoc) || any_of(ArgLocs, HasSVERegLoc))
8307 CallConv = CallingConv::AArch64_SVE_VectorCall;
8308 }
8309
8310 if (IsTailCall) {
8311 // Check if it's really possible to do a tail call.
8312 IsTailCall = isEligibleForTailCallOptimization(CLI);
8313
8314     // A sibling call is one where we're under the usual C ABI and not planning
8315     // to change that, but can still do a tail call.
8316 if (!TailCallOpt && IsTailCall && CallConv != CallingConv::Tail &&
8317 CallConv != CallingConv::SwiftTail)
8318 IsSibCall = true;
8319
8320 if (IsTailCall)
8321 ++NumTailCalls;
8322 }
8323
8324 if (!IsTailCall && CLI.CB && CLI.CB->isMustTailCall())
8325 report_fatal_error("failed to perform tail call elimination on a call "
8326 "site marked musttail");
8327
8328 // Get a count of how many bytes are to be pushed on the stack.
8329 unsigned NumBytes = CCInfo.getStackSize();
8330
8331 if (IsSibCall) {
8332 // Since we're not changing the ABI to make this a tail call, the memory
8333 // operands are already available in the caller's incoming argument space.
8334 NumBytes = 0;
8335 }
8336
8337 // FPDiff is the byte offset of the call's argument area from the callee's.
8338 // Stores to callee stack arguments will be placed in FixedStackSlots offset
8339 // by this amount for a tail call. In a sibling call it must be 0 because the
8340 // caller will deallocate the entire stack and the callee still expects its
8341 // arguments to begin at SP+0. Completely unused for non-tail calls.
8342 int FPDiff = 0;
8343
8344 if (IsTailCall && !IsSibCall) {
8345 unsigned NumReusableBytes = FuncInfo->getBytesInStackArgArea();
8346
8347 // Since callee will pop argument stack as a tail call, we must keep the
8348 // popped size 16-byte aligned.
8349 NumBytes = alignTo(NumBytes, 16);
8350
8351 // FPDiff will be negative if this tail call requires more space than we
8352 // would automatically have in our incoming argument space. Positive if we
8353 // can actually shrink the stack.
8354 FPDiff = NumReusableBytes - NumBytes;
8355
8356 // Update the required reserved area if this is the tail call requiring the
8357 // most argument stack space.
8358 if (FPDiff < 0 && FuncInfo->getTailCallReservedStack() < (unsigned)-FPDiff)
8359 FuncInfo->setTailCallReservedStack(-FPDiff);
8360
8361 // The stack pointer must be 16-byte aligned at all times it's used for a
8362 // memory operation, which in practice means at *all* times and in
8363 // particular across call boundaries. Therefore our own arguments started at
8364 // a 16-byte aligned SP and the delta applied for the tail call should
8365 // satisfy the same constraint.
8366 assert(FPDiff % 16 == 0 && "unaligned stack on tail call");
8367 }
8368
8369 // Determine whether we need any streaming mode changes.
8370 SMEAttrs CalleeAttrs, CallerAttrs(MF.getFunction());
8371 if (CLI.CB)
8372 CalleeAttrs = SMEAttrs(*CLI.CB);
8373 else if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8374 CalleeAttrs = SMEAttrs(ES->getSymbol());
8375
8376 auto DescribeCallsite =
8377 [&](OptimizationRemarkAnalysis &R) -> OptimizationRemarkAnalysis & {
8378 R << "call from '" << ore::NV("Caller", MF.getName()) << "' to '";
8379 if (auto *ES = dyn_cast<ExternalSymbolSDNode>(CLI.Callee))
8380 R << ore::NV("Callee", ES->getSymbol());
8381 else if (CLI.CB && CLI.CB->getCalledFunction())
8382 R << ore::NV("Callee", CLI.CB->getCalledFunction()->getName());
8383 else
8384 R << "unknown callee";
8385 R << "'";
8386 return R;
8387 };
8388
8389 bool RequiresLazySave = CallerAttrs.requiresLazySave(CalleeAttrs);
8390 if (RequiresLazySave) {
8391 const TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8392 MachinePointerInfo MPI =
8393 MachinePointerInfo::getStack(MF, TPIDR2.FrameIndex);
8394 SDValue TPIDR2ObjAddr = DAG.getFrameIndex(
8395 TPIDR2.FrameIndex,
8396 DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
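    // The TPIDR2 block records the number of ZA save slices in a 16-bit field
    // at offset 8. RDSVL #1 yields the streaming vector length in bytes, which
    // is the number of slices to be saved.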
8397 SDValue NumZaSaveSlicesAddr =
8398 DAG.getNode(ISD::ADD, DL, TPIDR2ObjAddr.getValueType(), TPIDR2ObjAddr,
8399 DAG.getConstant(8, DL, TPIDR2ObjAddr.getValueType()));
8400 SDValue NumZaSaveSlices = DAG.getNode(AArch64ISD::RDSVL, DL, MVT::i64,
8401 DAG.getConstant(1, DL, MVT::i32));
8402 Chain = DAG.getTruncStore(Chain, DL, NumZaSaveSlices, NumZaSaveSlicesAddr,
8403 MPI, MVT::i16);
8404 Chain = DAG.getNode(
8405 ISD::INTRINSIC_VOID, DL, MVT::Other, Chain,
8406 DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8407 TPIDR2ObjAddr);
8408 OptimizationRemarkEmitter ORE(&MF.getFunction());
8409 ORE.emit([&]() {
8410 auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
8411 CLI.CB)
8412 : OptimizationRemarkAnalysis("sme", "SMELazySaveZA",
8413 &MF.getFunction());
8414 return DescribeCallsite(R) << " sets up a lazy save for ZA";
8415 });
8416 }
8417
8418 SDValue PStateSM;
8419 bool RequiresSMChange = CallerAttrs.requiresSMChange(CalleeAttrs);
8420 if (RequiresSMChange) {
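    // Determine the caller's PSTATE.SM value for the conditional
    // smstart/smstop nodes: 1 if the caller has a streaming interface or
    // body, 0 if it has a non-streaming interface, otherwise (for
    // streaming-compatible callers) query it at run time.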
8421 if (CallerAttrs.hasStreamingInterfaceOrBody())
8422 PStateSM = DAG.getConstant(1, DL, MVT::i64);
8423 else if (CallerAttrs.hasNonStreamingInterface())
8424 PStateSM = DAG.getConstant(0, DL, MVT::i64);
8425 else
8426 PStateSM = getRuntimePStateSM(DAG, Chain, DL, MVT::i64);
8427 OptimizationRemarkEmitter ORE(&MF.getFunction());
8428 ORE.emit([&]() {
8429 auto R = CLI.CB ? OptimizationRemarkAnalysis("sme", "SMETransition",
8430 CLI.CB)
8431 : OptimizationRemarkAnalysis("sme", "SMETransition",
8432 &MF.getFunction());
8433 DescribeCallsite(R) << " requires a streaming mode transition";
8434 return R;
8435 });
8436 }
8437
8438 SDValue ZTFrameIdx;
8439 MachineFrameInfo &MFI = MF.getFrameInfo();
8440 bool ShouldPreserveZT0 = CallerAttrs.requiresPreservingZT0(CalleeAttrs);
8441
8442 // If the caller has ZT0 state which will not be preserved by the callee,
8443 // spill ZT0 before the call.
8444 if (ShouldPreserveZT0) {
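    // ZT0 is 512 bits wide, so a 64-byte spill slot is sufficient.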
8445 unsigned ZTObj = MFI.CreateSpillStackObject(64, Align(16));
8446 ZTFrameIdx = DAG.getFrameIndex(
8447 ZTObj,
8448 DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8449
8450 Chain = DAG.getNode(AArch64ISD::SAVE_ZT, DL, DAG.getVTList(MVT::Other),
8451 {Chain, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
8452 }
8453
8454   // If the caller shares ZT0 but the callee does not share ZA, we need to stop
8455   // PSTATE.ZA before the call if there is no lazy-save active.
8456 bool DisableZA = CallerAttrs.requiresDisablingZABeforeCall(CalleeAttrs);
8457 assert((!DisableZA || !RequiresLazySave) &&
8458 "Lazy-save should have PSTATE.SM=1 on entry to the function");
8459
8460 if (DisableZA)
8461 Chain = DAG.getNode(
8462 AArch64ISD::SMSTOP, DL, MVT::Other, Chain,
8463 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
8464 DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
8465
8466 // Adjust the stack pointer for the new arguments...
8467 // These operations are automatically eliminated by the prolog/epilog pass
8468 if (!IsSibCall)
8469 Chain = DAG.getCALLSEQ_START(Chain, IsTailCall ? 0 : NumBytes, 0, DL);
8470
8471 SDValue StackPtr = DAG.getCopyFromReg(Chain, DL, AArch64::SP,
8472 getPointerTy(DAG.getDataLayout()));
8473
8474 SmallVector<std::pair<unsigned, SDValue>, 8> RegsToPass;
8475 SmallSet<unsigned, 8> RegsUsed;
8476 SmallVector<SDValue, 8> MemOpChains;
8477 auto PtrVT = getPointerTy(DAG.getDataLayout());
8478
8479 if (IsVarArg && CLI.CB && CLI.CB->isMustTailCall()) {
8480 const auto &Forwards = FuncInfo->getForwardedMustTailRegParms();
8481 for (const auto &F : Forwards) {
8482 SDValue Val = DAG.getCopyFromReg(Chain, DL, F.VReg, F.VT);
8483 RegsToPass.emplace_back(F.PReg, Val);
8484 }
8485 }
8486
8487 // Walk the register/memloc assignments, inserting copies/loads.
8488 unsigned ExtraArgLocs = 0;
8489 for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
8490 CCValAssign &VA = ArgLocs[i - ExtraArgLocs];
8491 SDValue Arg = OutVals[i];
8492 ISD::ArgFlagsTy Flags = Outs[i].Flags;
8493
8494 // Promote the value if needed.
8495 switch (VA.getLocInfo()) {
8496 default:
8497 llvm_unreachable("Unknown loc info!");
8498 case CCValAssign::Full:
8499 break;
8500 case CCValAssign::SExt:
8501 Arg = DAG.getNode(ISD::SIGN_EXTEND, DL, VA.getLocVT(), Arg);
8502 break;
8503 case CCValAssign::ZExt:
8504 Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
8505 break;
8506 case CCValAssign::AExt:
8507 if (Outs[i].ArgVT == MVT::i1) {
8508 // AAPCS requires i1 to be zero-extended to 8-bits by the caller.
8509 //
8510 // Check if we actually have to do this, because the value may
8511 // already be zero-extended.
8512 //
8513 // We cannot just emit a (zext i8 (trunc (assert-zext i8)))
8514 // and rely on DAGCombiner to fold this, because the following
8515 // (anyext i32) is combined with (zext i8) in DAG.getNode:
8516 //
8517 // (ext (zext x)) -> (zext x)
8518 //
8519 // This will give us (zext i32), which we cannot remove, so
8520 // try to check this beforehand.
8521 if (!checkZExtBool(Arg, DAG)) {
8522 Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
8523 Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i8, Arg);
8524 }
8525 }
8526 Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
8527 break;
8528 case CCValAssign::AExtUpper:
8529 assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
8530 Arg = DAG.getNode(ISD::ANY_EXTEND, DL, VA.getLocVT(), Arg);
8531 Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
8532 DAG.getConstant(32, DL, VA.getLocVT()));
8533 break;
8534 case CCValAssign::BCvt:
8535 Arg = DAG.getBitcast(VA.getLocVT(), Arg);
8536 break;
8537 case CCValAssign::Trunc:
8538 Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
8539 break;
8540 case CCValAssign::FPExt:
8541 Arg = DAG.getNode(ISD::FP_EXTEND, DL, VA.getLocVT(), Arg);
8542 break;
8543 case CCValAssign::Indirect:
8544 bool isScalable = VA.getValVT().isScalableVT();
8545 assert((isScalable || Subtarget->isWindowsArm64EC()) &&
8546 "Indirect arguments should be scalable on most subtargets");
8547
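      // Materialise the argument in a stack temporary and pass the callee a
      // pointer to it. Scalable values get a scalable stack slot so the frame
      // layout can account for the runtime vector length.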
8548 uint64_t StoreSize = VA.getValVT().getStoreSize().getKnownMinValue();
8549 uint64_t PartSize = StoreSize;
8550 unsigned NumParts = 1;
8551 if (Outs[i].Flags.isInConsecutiveRegs()) {
8552 while (!Outs[i + NumParts - 1].Flags.isInConsecutiveRegsLast())
8553 ++NumParts;
8554 StoreSize *= NumParts;
8555 }
8556
8557 Type *Ty = EVT(VA.getValVT()).getTypeForEVT(*DAG.getContext());
8558 Align Alignment = DAG.getDataLayout().getPrefTypeAlign(Ty);
8559 MachineFrameInfo &MFI = MF.getFrameInfo();
8560 int FI = MFI.CreateStackObject(StoreSize, Alignment, false);
8561 if (isScalable)
8562 MFI.setStackID(FI, TargetStackID::ScalableVector);
8563
8564 MachinePointerInfo MPI = MachinePointerInfo::getFixedStack(MF, FI);
8565 SDValue Ptr = DAG.getFrameIndex(
8566 FI, DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8567 SDValue SpillSlot = Ptr;
8568
8569 // Ensure we generate all stores for each tuple part, whilst updating the
8570 // pointer after each store correctly using vscale.
8571 while (NumParts) {
8572 SDValue Store = DAG.getStore(Chain, DL, OutVals[i], Ptr, MPI);
8573 MemOpChains.push_back(Store);
8574
8575 NumParts--;
8576 if (NumParts > 0) {
8577 SDValue BytesIncrement;
8578 if (isScalable) {
8579 BytesIncrement = DAG.getVScale(
8580 DL, Ptr.getValueType(),
8581 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize));
8582 } else {
8583 BytesIncrement = DAG.getConstant(
8584 APInt(Ptr.getValueSizeInBits().getFixedValue(), PartSize), DL,
8585 Ptr.getValueType());
8586 }
8587 SDNodeFlags Flags;
8588 Flags.setNoUnsignedWrap(true);
8589
8590 MPI = MachinePointerInfo(MPI.getAddrSpace());
8591 Ptr = DAG.getNode(ISD::ADD, DL, Ptr.getValueType(), Ptr,
8592 BytesIncrement, Flags);
8593 ExtraArgLocs++;
8594 i++;
8595 }
8596 }
8597
8598 Arg = SpillSlot;
8599 break;
8600 }
8601
8602 if (VA.isRegLoc()) {
8603 if (i == 0 && Flags.isReturned() && !Flags.isSwiftSelf() &&
8604 Outs[0].VT == MVT::i64) {
8605 assert(VA.getLocVT() == MVT::i64 &&
8606 "unexpected calling convention register assignment");
8607 assert(!Ins.empty() && Ins[0].VT == MVT::i64 &&
8608 "unexpected use of 'returned'");
8609 IsThisReturn = true;
8610 }
8611 if (RegsUsed.count(VA.getLocReg())) {
8612 // If this register has already been used then we're trying to pack
8613 // parts of an [N x i32] into an X-register. The extension type will
8614 // take care of putting the two halves in the right place but we have to
8615 // combine them.
8616 SDValue &Bits =
8617 llvm::find_if(RegsToPass,
8618 [=](const std::pair<unsigned, SDValue> &Elt) {
8619 return Elt.first == VA.getLocReg();
8620 })
8621 ->second;
8622 Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
8623       // Call site info is used for the function's parameter entry-value
8624       // tracking. For now we track only the simple cases where a parameter
8625       // is transferred through a whole register.
8626 llvm::erase_if(CSInfo.ArgRegPairs,
8627 [&VA](MachineFunction::ArgRegPair ArgReg) {
8628 return ArgReg.Reg == VA.getLocReg();
8629 });
8630 } else {
8631       // Add an extra level of indirection for streaming mode changes by
8632       // using a pseudo copy node that the simple register coalescer cannot
8633       // rematerialise between an smstart/smstop and the call.
8634 if (RequiresSMChange && isPassedInFPR(Arg.getValueType()))
8635 Arg = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
8636 Arg.getValueType(), Arg);
8637 RegsToPass.emplace_back(VA.getLocReg(), Arg);
8638 RegsUsed.insert(VA.getLocReg());
8639 const TargetOptions &Options = DAG.getTarget().Options;
8640 if (Options.EmitCallSiteInfo)
8641 CSInfo.ArgRegPairs.emplace_back(VA.getLocReg(), i);
8642 }
8643 } else {
8644 assert(VA.isMemLoc());
8645
8646 SDValue DstAddr;
8647 MachinePointerInfo DstInfo;
8648
8649       // FIXME: This works on big-endian for composite byvals, which are the
8650       // common case. It should work for fundamental types too.
8651 uint32_t BEAlign = 0;
8652 unsigned OpSize;
8653 if (VA.getLocInfo() == CCValAssign::Indirect ||
8654 VA.getLocInfo() == CCValAssign::Trunc)
8655 OpSize = VA.getLocVT().getFixedSizeInBits();
8656 else
8657 OpSize = Flags.isByVal() ? Flags.getByValSize() * 8
8658 : VA.getValVT().getSizeInBits();
8659 OpSize = (OpSize + 7) / 8;
8660 if (!Subtarget->isLittleEndian() && !Flags.isByVal() &&
8661 !Flags.isInConsecutiveRegs()) {
8662 if (OpSize < 8)
8663 BEAlign = 8 - OpSize;
8664 }
8665 unsigned LocMemOffset = VA.getLocMemOffset();
8666 int32_t Offset = LocMemOffset + BEAlign;
8667 SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
8668 PtrOff = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
8669
8670 if (IsTailCall) {
8671 Offset = Offset + FPDiff;
8672 int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
8673
8674 DstAddr = DAG.getFrameIndex(FI, PtrVT);
8675 DstInfo = MachinePointerInfo::getFixedStack(MF, FI);
8676
8677 // Make sure any stack arguments overlapping with where we're storing
8678 // are loaded before this eventual operation. Otherwise they'll be
8679 // clobbered.
8680 Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
8681 } else {
8682 SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
8683
8684 DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
8685 DstInfo = MachinePointerInfo::getStack(MF, LocMemOffset);
8686 }
8687
8688 if (Outs[i].Flags.isByVal()) {
8689 SDValue SizeNode =
8690 DAG.getConstant(Outs[i].Flags.getByValSize(), DL, MVT::i64);
8691 SDValue Cpy = DAG.getMemcpy(
8692 Chain, DL, DstAddr, Arg, SizeNode,
8693 Outs[i].Flags.getNonZeroByValAlign(),
8694 /*isVol = */ false, /*AlwaysInline = */ false,
8695 /*CI=*/nullptr, std::nullopt, DstInfo, MachinePointerInfo());
8696
8697 MemOpChains.push_back(Cpy);
8698 } else {
8699 // Since we pass i1/i8/i16 as i1/i8/i16 on stack and Arg is already
8700 // promoted to a legal register type i32, we should truncate Arg back to
8701 // i1/i8/i16.
8702 if (VA.getValVT() == MVT::i1 || VA.getValVT() == MVT::i8 ||
8703 VA.getValVT() == MVT::i16)
8704 Arg = DAG.getNode(ISD::TRUNCATE, DL, VA.getValVT(), Arg);
8705
8706 SDValue Store = DAG.getStore(Chain, DL, Arg, DstAddr, DstInfo);
8707 MemOpChains.push_back(Store);
8708 }
8709 }
8710 }
8711
8712 if (IsVarArg && Subtarget->isWindowsArm64EC()) {
8713 SDValue ParamPtr = StackPtr;
8714 if (IsTailCall) {
8715 // Create a dummy object at the top of the stack that can be used to get
8716 // the SP after the epilogue
8717 int FI = MF.getFrameInfo().CreateFixedObject(1, FPDiff, true);
8718 ParamPtr = DAG.getFrameIndex(FI, PtrVT);
8719 }
8720
8721 // For vararg calls, the Arm64EC ABI requires values in x4 and x5
8722 // describing the argument list. x4 contains the address of the
8723 // first stack parameter. x5 contains the size in bytes of all parameters
8724 // passed on the stack.
8725 RegsToPass.emplace_back(AArch64::X4, ParamPtr);
8726 RegsToPass.emplace_back(AArch64::X5,
8727 DAG.getConstant(NumBytes, DL, MVT::i64));
8728 }
8729
8730 if (!MemOpChains.empty())
8731 Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOpChains);
8732
8733 SDValue InGlue;
8734 if (RequiresSMChange) {
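    // Bracket the streaming-mode change with VG_SAVE here and a matching
    // VG_RESTORE after the call; this is skipped only on Darwin targets
    // without SVE.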
8735 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
8736 Chain = DAG.getNode(AArch64ISD::VG_SAVE, DL,
8737 DAG.getVTList(MVT::Other, MVT::Glue), Chain);
8738 InGlue = Chain.getValue(1);
8739 }
8740
8741 SDValue NewChain = changeStreamingMode(
8742 DAG, DL, CalleeAttrs.hasStreamingInterface(), Chain, InGlue,
8743 getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
8744 Chain = NewChain.getValue(0);
8745 InGlue = NewChain.getValue(1);
8746 }
8747
8748 // Build a sequence of copy-to-reg nodes chained together with token chain
8749 // and flag operands which copy the outgoing args into the appropriate regs.
8750 for (auto &RegToPass : RegsToPass) {
8751 Chain = DAG.getCopyToReg(Chain, DL, RegToPass.first,
8752 RegToPass.second, InGlue);
8753 InGlue = Chain.getValue(1);
8754 }
8755
8756 // If the callee is a GlobalAddress/ExternalSymbol node (quite common, every
8757 // direct call is) turn it into a TargetGlobalAddress/TargetExternalSymbol
8758 // node so that legalize doesn't hack it.
8759 if (auto *G = dyn_cast<GlobalAddressSDNode>(Callee)) {
8760 auto GV = G->getGlobal();
8761 unsigned OpFlags =
8762 Subtarget->classifyGlobalFunctionReference(GV, getTargetMachine());
8763 if (OpFlags & AArch64II::MO_GOT) {
8764 Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
8765 Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
8766 } else {
8767 const GlobalValue *GV = G->getGlobal();
8768 Callee = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, OpFlags);
8769 }
8770 } else if (auto *S = dyn_cast<ExternalSymbolSDNode>(Callee)) {
8771 bool UseGot = (getTargetMachine().getCodeModel() == CodeModel::Large &&
8772 Subtarget->isTargetMachO()) ||
8773 MF.getFunction().getParent()->getRtLibUseGOT();
8774 const char *Sym = S->getSymbol();
8775 if (UseGot) {
8776 Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, AArch64II::MO_GOT);
8777 Callee = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, Callee);
8778 } else {
8779 Callee = DAG.getTargetExternalSymbol(Sym, PtrVT, 0);
8780 }
8781 }
8782
8783 // We don't usually want to end the call-sequence here because we would tidy
8784 // the frame up *after* the call, however in the ABI-changing tail-call case
8785 // we've carefully laid out the parameters so that when sp is reset they'll be
8786 // in the correct location.
8787 if (IsTailCall && !IsSibCall) {
8788 Chain = DAG.getCALLSEQ_END(Chain, 0, 0, InGlue, DL);
8789 InGlue = Chain.getValue(1);
8790 }
8791
8792 unsigned Opc = IsTailCall ? AArch64ISD::TC_RETURN : AArch64ISD::CALL;
8793
8794 std::vector<SDValue> Ops;
8795 Ops.push_back(Chain);
8796 Ops.push_back(Callee);
8797
8798 // Calls with operand bundle "clang.arc.attachedcall" are special. They should
8799 // be expanded to the call, directly followed by a special marker sequence and
8800 // a call to an ObjC library function. Use CALL_RVMARKER to do that.
8801 if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
8802 assert(!IsTailCall &&
8803 "tail calls cannot be marked with clang.arc.attachedcall");
8804 Opc = AArch64ISD::CALL_RVMARKER;
8805
8806 // Add a target global address for the retainRV/claimRV runtime function
8807 // just before the call target.
8808 Function *ARCFn = *objcarc::getAttachedARCFunction(CLI.CB);
8809 auto GA = DAG.getTargetGlobalAddress(ARCFn, DL, PtrVT);
8810 Ops.insert(Ops.begin() + 1, GA);
8811 } else if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
8812 Opc = AArch64ISD::CALL_ARM64EC_TO_X64;
8813 } else if (GuardWithBTI) {
8814 Opc = AArch64ISD::CALL_BTI;
8815 }
8816
8817 if (IsTailCall) {
8818 // Each tail call may have to adjust the stack by a different amount, so
8819 // this information must travel along with the operation for eventual
8820 // consumption by emitEpilogue.
8821 Ops.push_back(DAG.getTargetConstant(FPDiff, DL, MVT::i32));
8822 }
8823
8824 if (CLI.PAI) {
8825 const uint64_t Key = CLI.PAI->Key;
8826 assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
8827 "Invalid auth call key");
8828
8829 // Split the discriminator into address/integer components.
8830 SDValue AddrDisc, IntDisc;
8831 std::tie(IntDisc, AddrDisc) =
8832 extractPtrauthBlendDiscriminators(CLI.PAI->Discriminator, &DAG);
8833
8834 if (Opc == AArch64ISD::CALL_RVMARKER)
8835 Opc = AArch64ISD::AUTH_CALL_RVMARKER;
8836 else
8837 Opc = IsTailCall ? AArch64ISD::AUTH_TC_RETURN : AArch64ISD::AUTH_CALL;
8838 Ops.push_back(DAG.getTargetConstant(Key, DL, MVT::i32));
8839 Ops.push_back(IntDisc);
8840 Ops.push_back(AddrDisc);
8841 }
8842
8843 // Add argument registers to the end of the list so that they are known live
8844 // into the call.
8845 for (auto &RegToPass : RegsToPass)
8846 Ops.push_back(DAG.getRegister(RegToPass.first,
8847 RegToPass.second.getValueType()));
8848
8849 // Add a register mask operand representing the call-preserved registers.
8850 const uint32_t *Mask;
8851 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
8852 if (IsThisReturn) {
8853 // For 'this' returns, use the X0-preserving mask if applicable
8854 Mask = TRI->getThisReturnPreservedMask(MF, CallConv);
8855 if (!Mask) {
8856 IsThisReturn = false;
8857 Mask = TRI->getCallPreservedMask(MF, CallConv);
8858 }
8859 } else
8860 Mask = TRI->getCallPreservedMask(MF, CallConv);
8861
8862 if (Subtarget->hasCustomCallingConv())
8863 TRI->UpdateCustomCallPreservedMask(MF, &Mask);
8864
8865 if (TRI->isAnyArgRegReserved(MF))
8866 TRI->emitReservedArgRegCallError(MF);
8867
8868 assert(Mask && "Missing call preserved mask for calling convention");
8869 Ops.push_back(DAG.getRegisterMask(Mask));
8870
8871 if (InGlue.getNode())
8872 Ops.push_back(InGlue);
8873
8874 SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
8875
8876   // If we're doing a tail call, use a TC_RETURN here rather than an
8877 // actual call instruction.
8878 if (IsTailCall) {
8879 MF.getFrameInfo().setHasTailCall();
8880 SDValue Ret = DAG.getNode(Opc, DL, NodeTys, Ops);
8881 if (IsCFICall)
8882 Ret.getNode()->setCFIType(CLI.CFIType->getZExtValue());
8883
8884 DAG.addNoMergeSiteInfo(Ret.getNode(), CLI.NoMerge);
8885 DAG.addCallSiteInfo(Ret.getNode(), std::move(CSInfo));
8886 return Ret;
8887 }
8888
8889 // Returns a chain and a flag for retval copy to use.
8890 Chain = DAG.getNode(Opc, DL, NodeTys, Ops);
8891 if (IsCFICall)
8892 Chain.getNode()->setCFIType(CLI.CFIType->getZExtValue());
8893
8894 DAG.addNoMergeSiteInfo(Chain.getNode(), CLI.NoMerge);
8895 InGlue = Chain.getValue(1);
8896 DAG.addCallSiteInfo(Chain.getNode(), std::move(CSInfo));
8897
8898 uint64_t CalleePopBytes =
8899 DoesCalleeRestoreStack(CallConv, TailCallOpt) ? alignTo(NumBytes, 16) : 0;
8900
8901 Chain = DAG.getCALLSEQ_END(Chain, NumBytes, CalleePopBytes, InGlue, DL);
8902 InGlue = Chain.getValue(1);
8903
8904 // Handle result values, copying them out of physregs into vregs that we
8905 // return.
8906 SDValue Result = LowerCallResult(
8907 Chain, InGlue, CallConv, IsVarArg, RVLocs, DL, DAG, InVals, IsThisReturn,
8908 IsThisReturn ? OutVals[0] : SDValue(), RequiresSMChange);
8909
8910 if (!Ins.empty())
8911 InGlue = Result.getValue(Result->getNumValues() - 1);
8912
8913 if (RequiresSMChange) {
8914 assert(PStateSM && "Expected a PStateSM to be set");
8915 Result = changeStreamingMode(
8916 DAG, DL, !CalleeAttrs.hasStreamingInterface(), Result, InGlue,
8917 getSMCondition(CallerAttrs, CalleeAttrs), PStateSM);
8918
8919 if (!Subtarget->isTargetDarwin() || Subtarget->hasSVE()) {
8920 InGlue = Result.getValue(1);
8921 Result =
8922 DAG.getNode(AArch64ISD::VG_RESTORE, DL,
8923 DAG.getVTList(MVT::Other, MVT::Glue), {Result, InGlue});
8924 }
8925 }
8926
8927 if (CallerAttrs.requiresEnablingZAAfterCall(CalleeAttrs))
8928 // Unconditionally resume ZA.
8929 Result = DAG.getNode(
8930 AArch64ISD::SMSTART, DL, MVT::Other, Result,
8931 DAG.getTargetConstant((int32_t)(AArch64SVCR::SVCRZA), DL, MVT::i32),
8932 DAG.getConstant(AArch64SME::Always, DL, MVT::i64));
8933
8934 if (ShouldPreserveZT0)
8935 Result =
8936 DAG.getNode(AArch64ISD::RESTORE_ZT, DL, DAG.getVTList(MVT::Other),
8937 {Result, DAG.getConstant(0, DL, MVT::i32), ZTFrameIdx});
8938
8939 if (RequiresLazySave) {
8940 // Conditionally restore the lazy save using a pseudo node.
8941 TPIDR2Object &TPIDR2 = FuncInfo->getTPIDR2Obj();
8942 SDValue RegMask = DAG.getRegisterMask(
8943 TRI->SMEABISupportRoutinesCallPreservedMaskFromX0());
8944 SDValue RestoreRoutine = DAG.getTargetExternalSymbol(
8945 "__arm_tpidr2_restore", getPointerTy(DAG.getDataLayout()));
8946 SDValue TPIDR2_EL0 = DAG.getNode(
8947 ISD::INTRINSIC_W_CHAIN, DL, MVT::i64, Result,
8948 DAG.getConstant(Intrinsic::aarch64_sme_get_tpidr2, DL, MVT::i32));
8949
8950 // Copy the address of the TPIDR2 block into X0 before 'calling' the
8951 // RESTORE_ZA pseudo.
8952 SDValue Glue;
8953 SDValue TPIDR2Block = DAG.getFrameIndex(
8954 TPIDR2.FrameIndex,
8955 DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout()));
8956 Result = DAG.getCopyToReg(Result, DL, AArch64::X0, TPIDR2Block, Glue);
8957 Result =
8958 DAG.getNode(AArch64ISD::RESTORE_ZA, DL, MVT::Other,
8959 {Result, TPIDR2_EL0, DAG.getRegister(AArch64::X0, MVT::i64),
8960 RestoreRoutine, RegMask, Result.getValue(1)});
8961
8962 // Finally reset the TPIDR2_EL0 register to 0.
8963 Result = DAG.getNode(
8964 ISD::INTRINSIC_VOID, DL, MVT::Other, Result,
8965 DAG.getConstant(Intrinsic::aarch64_sme_set_tpidr2, DL, MVT::i32),
8966 DAG.getConstant(0, DL, MVT::i64));
8967 TPIDR2.Uses++;
8968 }
8969
8970 if (RequiresSMChange || RequiresLazySave || ShouldPreserveZT0) {
8971 for (unsigned I = 0; I < InVals.size(); ++I) {
8972 // The smstart/smstop is chained as part of the call, but when the
8973 // resulting chain is discarded (which happens when the call is not part
8974 // of a chain, e.g. a call to @llvm.cos()), we need to ensure the
8975 // smstart/smstop is chained to the result value. We can do that by doing
8976 // a vreg -> vreg copy.
8977 Register Reg = MF.getRegInfo().createVirtualRegister(
8978 getRegClassFor(InVals[I].getValueType().getSimpleVT()));
8979 SDValue X = DAG.getCopyToReg(Result, DL, Reg, InVals[I]);
8980 InVals[I] = DAG.getCopyFromReg(X, DL, Reg,
8981 InVals[I].getValueType());
8982 }
8983 }
8984
8985 if (CallConv == CallingConv::PreserveNone) {
8986 for (const ISD::OutputArg &O : Outs) {
8987 if (O.Flags.isSwiftSelf() || O.Flags.isSwiftError() ||
8988 O.Flags.isSwiftAsync()) {
8989 MachineFunction &MF = DAG.getMachineFunction();
8990 DAG.getContext()->diagnose(DiagnosticInfoUnsupported(
8991 MF.getFunction(),
8992 "Swift attributes can't be used with preserve_none",
8993 DL.getDebugLoc()));
8994 break;
8995 }
8996 }
8997 }
8998
8999 return Result;
9000 }
9001
9002 bool AArch64TargetLowering::CanLowerReturn(
9003 CallingConv::ID CallConv, MachineFunction &MF, bool isVarArg,
9004 const SmallVectorImpl<ISD::OutputArg> &Outs, LLVMContext &Context) const {
9005 CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
9006 SmallVector<CCValAssign, 16> RVLocs;
9007 CCState CCInfo(CallConv, isVarArg, MF, RVLocs, Context);
9008 return CCInfo.CheckReturn(Outs, RetCC);
9009 }
9010
9011 SDValue
9012 AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,
9013 bool isVarArg,
9014 const SmallVectorImpl<ISD::OutputArg> &Outs,
9015 const SmallVectorImpl<SDValue> &OutVals,
9016 const SDLoc &DL, SelectionDAG &DAG) const {
9017 auto &MF = DAG.getMachineFunction();
9018 auto *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
9019
9020 CCAssignFn *RetCC = CCAssignFnForReturn(CallConv);
9021 SmallVector<CCValAssign, 16> RVLocs;
9022 CCState CCInfo(CallConv, isVarArg, MF, RVLocs, *DAG.getContext());
9023 CCInfo.AnalyzeReturn(Outs, RetCC);
9024
9025 // Copy the result values into the output registers.
9026 SDValue Glue;
9027 SmallVector<std::pair<unsigned, SDValue>, 4> RetVals;
9028 SmallSet<unsigned, 4> RegsUsed;
9029 for (unsigned i = 0, realRVLocIdx = 0; i != RVLocs.size();
9030 ++i, ++realRVLocIdx) {
9031 CCValAssign &VA = RVLocs[i];
9032 assert(VA.isRegLoc() && "Can only return in registers!");
9033 SDValue Arg = OutVals[realRVLocIdx];
9034
9035 switch (VA.getLocInfo()) {
9036 default:
9037 llvm_unreachable("Unknown loc info!");
9038 case CCValAssign::Full:
9039 if (Outs[i].ArgVT == MVT::i1) {
9040 // AAPCS requires i1 to be zero-extended to i8 by the producer of the
9041 // value. This is strictly redundant on Darwin (which uses "zeroext
9042 // i1"), but will be optimised out before ISel.
9043 Arg = DAG.getNode(ISD::TRUNCATE, DL, MVT::i1, Arg);
9044 Arg = DAG.getNode(ISD::ZERO_EXTEND, DL, VA.getLocVT(), Arg);
9045 }
9046 break;
9047 case CCValAssign::BCvt:
9048 Arg = DAG.getNode(ISD::BITCAST, DL, VA.getLocVT(), Arg);
9049 break;
9050 case CCValAssign::AExt:
9051 case CCValAssign::ZExt:
9052 Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
9053 break;
9054 case CCValAssign::AExtUpper:
9055 assert(VA.getValVT() == MVT::i32 && "only expect 32 -> 64 upper bits");
9056 Arg = DAG.getZExtOrTrunc(Arg, DL, VA.getLocVT());
9057 Arg = DAG.getNode(ISD::SHL, DL, VA.getLocVT(), Arg,
9058 DAG.getConstant(32, DL, VA.getLocVT()));
9059 break;
9060 }
9061
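    // As in LowerCall, the same register may be assigned twice (e.g. when
    // packing parts of an [N x i32] into an X-register); OR the halves
    // together.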
9062 if (RegsUsed.count(VA.getLocReg())) {
9063 SDValue &Bits =
9064 llvm::find_if(RetVals, [=](const std::pair<unsigned, SDValue> &Elt) {
9065 return Elt.first == VA.getLocReg();
9066 })->second;
9067 Bits = DAG.getNode(ISD::OR, DL, Bits.getValueType(), Bits, Arg);
9068 } else {
9069 RetVals.emplace_back(VA.getLocReg(), Arg);
9070 RegsUsed.insert(VA.getLocReg());
9071 }
9072 }
9073
9074 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
9075
9076 // Emit SMSTOP before returning from a locally streaming function
9077 SMEAttrs FuncAttrs(MF.getFunction());
9078 if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface()) {
9079 if (FuncAttrs.hasStreamingCompatibleInterface()) {
9080 Register Reg = FuncInfo->getPStateSMReg();
9081 assert(Reg.isValid() && "PStateSM Register is invalid");
9082 SDValue PStateSM = DAG.getCopyFromReg(Chain, DL, Reg, MVT::i64);
9083 Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9084 /*Glue*/ SDValue(),
9085 AArch64SME::IfCallerIsNonStreaming, PStateSM);
9086 } else
9087 Chain = changeStreamingMode(DAG, DL, /*Enable*/ false, Chain,
9088 /*Glue*/ SDValue(), AArch64SME::Always);
9089 Glue = Chain.getValue(1);
9090 }
9091
9092 SmallVector<SDValue, 4> RetOps(1, Chain);
9093 for (auto &RetVal : RetVals) {
9094 if (FuncAttrs.hasStreamingBody() && !FuncAttrs.hasStreamingInterface() &&
9095 isPassedInFPR(RetVal.second.getValueType()))
9096 RetVal.second = DAG.getNode(AArch64ISD::COALESCER_BARRIER, DL,
9097 RetVal.second.getValueType(), RetVal.second);
9098 Chain = DAG.getCopyToReg(Chain, DL, RetVal.first, RetVal.second, Glue);
9099 Glue = Chain.getValue(1);
9100 RetOps.push_back(
9101 DAG.getRegister(RetVal.first, RetVal.second.getValueType()));
9102 }
9103
9104 // Windows AArch64 ABIs require that for returning structs by value we copy
9105 // the sret argument into X0 for the return.
9106 // We saved the argument into a virtual register in the entry block,
9107 // so now we copy the value out and into X0.
9108 if (unsigned SRetReg = FuncInfo->getSRetReturnReg()) {
9109 SDValue Val = DAG.getCopyFromReg(RetOps[0], DL, SRetReg,
9110 getPointerTy(MF.getDataLayout()));
9111
9112 unsigned RetValReg = AArch64::X0;
9113 if (CallConv == CallingConv::ARM64EC_Thunk_X64)
9114 RetValReg = AArch64::X8;
9115 Chain = DAG.getCopyToReg(Chain, DL, RetValReg, Val, Glue);
9116 Glue = Chain.getValue(1);
9117
9118 RetOps.push_back(
9119 DAG.getRegister(RetValReg, getPointerTy(DAG.getDataLayout())));
9120 }
9121
9122 const MCPhysReg *I = TRI->getCalleeSavedRegsViaCopy(&MF);
9123 if (I) {
9124 for (; *I; ++I) {
9125 if (AArch64::GPR64RegClass.contains(*I))
9126 RetOps.push_back(DAG.getRegister(*I, MVT::i64));
9127 else if (AArch64::FPR64RegClass.contains(*I))
9128 RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64)));
9129 else
9130 llvm_unreachable("Unexpected register class in CSRsViaCopy!");
9131 }
9132 }
9133
9134 RetOps[0] = Chain; // Update chain.
9135
9136 // Add the glue if we have it.
9137 if (Glue.getNode())
9138 RetOps.push_back(Glue);
9139
9140 if (CallConv == CallingConv::ARM64EC_Thunk_X64) {
9141 // ARM64EC entry thunks use a special return sequence: instead of a regular
9142 // "ret" instruction, they need to explicitly call the emulator.
9143 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9144 SDValue Arm64ECRetDest =
9145 DAG.getExternalSymbol("__os_arm64x_dispatch_ret", PtrVT);
9146 Arm64ECRetDest =
9147 getAddr(cast<ExternalSymbolSDNode>(Arm64ECRetDest), DAG, 0);
9148 Arm64ECRetDest = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Arm64ECRetDest,
9149 MachinePointerInfo());
9150 RetOps.insert(RetOps.begin() + 1, Arm64ECRetDest);
9151 RetOps.insert(RetOps.begin() + 2, DAG.getTargetConstant(0, DL, MVT::i32));
9152 return DAG.getNode(AArch64ISD::TC_RETURN, DL, MVT::Other, RetOps);
9153 }
9154
9155 return DAG.getNode(AArch64ISD::RET_GLUE, DL, MVT::Other, RetOps);
9156 }
9157
9158 //===----------------------------------------------------------------------===//
9159 // Other Lowering Code
9160 //===----------------------------------------------------------------------===//
9161
9162 SDValue AArch64TargetLowering::getTargetNode(GlobalAddressSDNode *N, EVT Ty,
9163 SelectionDAG &DAG,
9164 unsigned Flag) const {
9165 return DAG.getTargetGlobalAddress(N->getGlobal(), SDLoc(N), Ty,
9166 N->getOffset(), Flag);
9167 }
9168
9169 SDValue AArch64TargetLowering::getTargetNode(JumpTableSDNode *N, EVT Ty,
9170 SelectionDAG &DAG,
9171 unsigned Flag) const {
9172 return DAG.getTargetJumpTable(N->getIndex(), Ty, Flag);
9173 }
9174
9175 SDValue AArch64TargetLowering::getTargetNode(ConstantPoolSDNode *N, EVT Ty,
9176 SelectionDAG &DAG,
9177 unsigned Flag) const {
9178 return DAG.getTargetConstantPool(N->getConstVal(), Ty, N->getAlign(),
9179 N->getOffset(), Flag);
9180 }
9181
9182 SDValue AArch64TargetLowering::getTargetNode(BlockAddressSDNode *N, EVT Ty,
9183 SelectionDAG &DAG,
9184 unsigned Flag) const {
9185 return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, 0, Flag);
9186 }
9187
9188 SDValue AArch64TargetLowering::getTargetNode(ExternalSymbolSDNode *N, EVT Ty,
9189 SelectionDAG &DAG,
9190 unsigned Flag) const {
9191 return DAG.getTargetExternalSymbol(N->getSymbol(), Ty, Flag);
9192 }
9193
9194 // (loadGOT sym)
9195 template <class NodeTy>
9196 SDValue AArch64TargetLowering::getGOT(NodeTy *N, SelectionDAG &DAG,
9197 unsigned Flags) const {
9198 LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getGOT\n");
9199 SDLoc DL(N);
9200 EVT Ty = getPointerTy(DAG.getDataLayout());
9201 SDValue GotAddr = getTargetNode(N, Ty, DAG, AArch64II::MO_GOT | Flags);
9202 // FIXME: Once remat is capable of dealing with instructions with register
9203 // operands, expand this into two nodes instead of using a wrapper node.
9204 return DAG.getNode(AArch64ISD::LOADgot, DL, Ty, GotAddr);
9205 }
9206
9207 // (wrapper %highest(sym), %higher(sym), %hi(sym), %lo(sym))
9208 template <class NodeTy>
9209 SDValue AArch64TargetLowering::getAddrLarge(NodeTy *N, SelectionDAG &DAG,
9210 unsigned Flags) const {
9211 LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrLarge\n");
9212 SDLoc DL(N);
9213 EVT Ty = getPointerTy(DAG.getDataLayout());
9214 const unsigned char MO_NC = AArch64II::MO_NC;
9215 return DAG.getNode(
9216 AArch64ISD::WrapperLarge, DL, Ty,
9217 getTargetNode(N, Ty, DAG, AArch64II::MO_G3 | Flags),
9218 getTargetNode(N, Ty, DAG, AArch64II::MO_G2 | MO_NC | Flags),
9219 getTargetNode(N, Ty, DAG, AArch64II::MO_G1 | MO_NC | Flags),
9220 getTargetNode(N, Ty, DAG, AArch64II::MO_G0 | MO_NC | Flags));
9221 }
9222
9223 // (addlow (adrp %hi(sym)) %lo(sym))
9224 template <class NodeTy>
9225 SDValue AArch64TargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG,
9226 unsigned Flags) const {
9227 LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddr\n");
9228 SDLoc DL(N);
9229 EVT Ty = getPointerTy(DAG.getDataLayout());
9230 SDValue Hi = getTargetNode(N, Ty, DAG, AArch64II::MO_PAGE | Flags);
9231 SDValue Lo = getTargetNode(N, Ty, DAG,
9232 AArch64II::MO_PAGEOFF | AArch64II::MO_NC | Flags);
9233 SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, Ty, Hi);
9234 return DAG.getNode(AArch64ISD::ADDlow, DL, Ty, ADRP, Lo);
9235 }
9236
9237 // (adr sym)
9238 template <class NodeTy>
9239 SDValue AArch64TargetLowering::getAddrTiny(NodeTy *N, SelectionDAG &DAG,
9240 unsigned Flags) const {
9241 LLVM_DEBUG(dbgs() << "AArch64TargetLowering::getAddrTiny\n");
9242 SDLoc DL(N);
9243 EVT Ty = getPointerTy(DAG.getDataLayout());
9244 SDValue Sym = getTargetNode(N, Ty, DAG, Flags);
9245 return DAG.getNode(AArch64ISD::ADR, DL, Ty, Sym);
9246 }
9247
9248 SDValue AArch64TargetLowering::LowerGlobalAddress(SDValue Op,
9249 SelectionDAG &DAG) const {
9250 GlobalAddressSDNode *GN = cast<GlobalAddressSDNode>(Op);
9251 const GlobalValue *GV = GN->getGlobal();
9252 unsigned OpFlags = Subtarget->ClassifyGlobalReference(GV, getTargetMachine());
9253
9254 if (OpFlags != AArch64II::MO_NO_FLAG)
9255 assert(cast<GlobalAddressSDNode>(Op)->getOffset() == 0 &&
9256 "unexpected offset in global node");
9257
9258 // This also catches the large code model case for Darwin, and tiny code
9259 // model with got relocations.
9260 if ((OpFlags & AArch64II::MO_GOT) != 0) {
9261 return getGOT(GN, DAG, OpFlags);
9262 }
9263
9264 SDValue Result;
9265 if (getTargetMachine().getCodeModel() == CodeModel::Large &&
9266 !getTargetMachine().isPositionIndependent()) {
9267 Result = getAddrLarge(GN, DAG, OpFlags);
9268 } else if (getTargetMachine().getCodeModel() == CodeModel::Tiny) {
9269 Result = getAddrTiny(GN, DAG, OpFlags);
9270 } else {
9271 Result = getAddr(GN, DAG, OpFlags);
9272 }
9273 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9274 SDLoc DL(GN);
9275 if (OpFlags & (AArch64II::MO_DLLIMPORT | AArch64II::MO_COFFSTUB))
9276 Result = DAG.getLoad(PtrVT, DL, DAG.getEntryNode(), Result,
9277 MachinePointerInfo::getGOT(DAG.getMachineFunction()));
9278 return Result;
9279 }
9280
9281 /// Convert a TLS address reference into the correct sequence of loads
9282 /// and calls to compute the variable's address (for Darwin, currently) and
9283 /// return an SDValue containing the final node.
9284
9285 /// Darwin only has one TLS scheme which must be capable of dealing with the
9286 /// fully general situation, in the worst case. This means:
9287 /// + "extern __thread" declaration.
9288 /// + Defined in a possibly unknown dynamic library.
9289 ///
9290 /// The general system is that each __thread variable has a [3 x i64] descriptor
9291 /// which contains information used by the runtime to calculate the address. The
9292 /// only part of this the compiler needs to know about is the first xword, which
9293 /// contains a function pointer that must be called with the address of the
9294 /// entire descriptor in "x0".
9295 ///
9296 /// Since this descriptor may be in a different unit, in general even the
9297 /// descriptor must be accessed via an indirect load. The "ideal" code sequence
9298 /// is:
9299 /// adrp x0, _var@TLVPPAGE
9300 /// ldr x0, [x0, _var@TLVPPAGEOFF] ; x0 now contains address of descriptor
9301 /// ldr x1, [x0] ; x1 contains 1st entry of descriptor,
9302 /// ; the function pointer
9303 /// blr x1 ; Uses descriptor address in x0
9304 /// ; Address of _var is now in x0.
9305 ///
9306 /// If the address of _var's descriptor *is* known to the linker, then it can
9307 /// change the first "ldr" instruction to an appropriate "add x0, x0, #imm" for
9308 /// a slight efficiency gain.
9309 SDValue
9310 AArch64TargetLowering::LowerDarwinGlobalTLSAddress(SDValue Op,
9311 SelectionDAG &DAG) const {
9312 assert(Subtarget->isTargetDarwin() &&
9313 "This function expects a Darwin target");
9314
9315 SDLoc DL(Op);
9316 MVT PtrVT = getPointerTy(DAG.getDataLayout());
9317 MVT PtrMemVT = getPointerMemTy(DAG.getDataLayout());
9318 const GlobalValue *GV = cast<GlobalAddressSDNode>(Op)->getGlobal();
9319
9320 SDValue TLVPAddr =
9321 DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9322 SDValue DescAddr = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TLVPAddr);
9323
9324 // The first entry in the descriptor is a function pointer that we must call
9325 // to obtain the address of the variable.
9326 SDValue Chain = DAG.getEntryNode();
9327 SDValue FuncTLVGet = DAG.getLoad(
9328 PtrMemVT, DL, Chain, DescAddr,
9329 MachinePointerInfo::getGOT(DAG.getMachineFunction()),
9330 Align(PtrMemVT.getSizeInBits() / 8),
9331 MachineMemOperand::MOInvariant | MachineMemOperand::MODereferenceable);
9332 Chain = FuncTLVGet.getValue(1);
9333
9334 // Extend loaded pointer if necessary (i.e. if ILP32) to DAG pointer.
9335 FuncTLVGet = DAG.getZExtOrTrunc(FuncTLVGet, DL, PtrVT);
9336
9337 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
9338 MFI.setAdjustsStack(true);
9339
9340 // TLS calls preserve all registers except those that absolutely must be
9341 // trashed: X0 (it takes an argument), LR (it's a call) and NZCV (let's not be
9342 // silly).
9343 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
9344 const uint32_t *Mask = TRI->getTLSCallPreservedMask();
9345 if (Subtarget->hasCustomCallingConv())
9346 TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
9347
9348 // Finally, we can make the call. This is just a degenerate version of a
9349 // normal AArch64 call node: x0 takes the address of the descriptor, and
9350 // returns the address of the variable in this thread.
9351 Chain = DAG.getCopyToReg(Chain, DL, AArch64::X0, DescAddr, SDValue());
9352
9353 unsigned Opcode = AArch64ISD::CALL;
9354 SmallVector<SDValue, 8> Ops;
9355 Ops.push_back(Chain);
9356 Ops.push_back(FuncTLVGet);
9357
9358 // With ptrauth-calls, the tlv access thunk pointer is authenticated (IA, 0).
9359 if (DAG.getMachineFunction().getFunction().hasFnAttribute("ptrauth-calls")) {
9360 Opcode = AArch64ISD::AUTH_CALL;
9361 Ops.push_back(DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32));
9362 Ops.push_back(DAG.getTargetConstant(0, DL, MVT::i64)); // Integer Disc.
9363 Ops.push_back(DAG.getRegister(AArch64::NoRegister, MVT::i64)); // Addr Disc.
9364 }
9365
9366 Ops.push_back(DAG.getRegister(AArch64::X0, MVT::i64));
9367 Ops.push_back(DAG.getRegisterMask(Mask));
9368 Ops.push_back(Chain.getValue(1));
9369 Chain = DAG.getNode(Opcode, DL, DAG.getVTList(MVT::Other, MVT::Glue), Ops);
9370 return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Chain.getValue(1));
9371 }
9372
9373 /// Convert a thread-local variable reference into a sequence of instructions to
9374 /// compute the variable's address for the local exec TLS model of ELF targets.
9375 /// The sequence depends on the maximum TLS area size.
9376 SDValue AArch64TargetLowering::LowerELFTLSLocalExec(const GlobalValue *GV,
9377 SDValue ThreadBase,
9378 const SDLoc &DL,
9379 SelectionDAG &DAG) const {
9380 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9381 SDValue TPOff, Addr;
9382
9383 switch (DAG.getTarget().Options.TLSSize) {
9384 default:
9385 llvm_unreachable("Unexpected TLS size");
9386
9387 case 12: {
9388 // mrs x0, TPIDR_EL0
9389 // add x0, x0, :tprel_lo12:a
9390 SDValue Var = DAG.getTargetGlobalAddress(
9391 GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_PAGEOFF);
9392 return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
9393 Var,
9394 DAG.getTargetConstant(0, DL, MVT::i32)),
9395 0);
9396 }
9397
9398 case 24: {
9399 // mrs x0, TPIDR_EL0
9400 // add x0, x0, :tprel_hi12:a
9401 // add x0, x0, :tprel_lo12_nc:a
9402 SDValue HiVar = DAG.getTargetGlobalAddress(
9403 GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9404 SDValue LoVar = DAG.getTargetGlobalAddress(
9405 GV, DL, PtrVT, 0,
9406 AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9407 Addr = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, ThreadBase,
9408 HiVar,
9409 DAG.getTargetConstant(0, DL, MVT::i32)),
9410 0);
9411 return SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, Addr,
9412 LoVar,
9413 DAG.getTargetConstant(0, DL, MVT::i32)),
9414 0);
9415 }
9416
9417 case 32: {
9418 // mrs x1, TPIDR_EL0
9419 // movz x0, #:tprel_g1:a
9420 // movk x0, #:tprel_g0_nc:a
9421 // add x0, x1, x0
9422 SDValue HiVar = DAG.getTargetGlobalAddress(
9423 GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G1);
9424 SDValue LoVar = DAG.getTargetGlobalAddress(
9425 GV, DL, PtrVT, 0,
9426 AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
9427 TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
9428 DAG.getTargetConstant(16, DL, MVT::i32)),
9429 0);
9430 TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
9431 DAG.getTargetConstant(0, DL, MVT::i32)),
9432 0);
9433 return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9434 }
9435
9436 case 48: {
9437 // mrs x1, TPIDR_EL0
9438 // movz x0, #:tprel_g2:a
9439 // movk x0, #:tprel_g1_nc:a
9440 // movk x0, #:tprel_g0_nc:a
9441 // add x0, x1, x0
9442 SDValue HiVar = DAG.getTargetGlobalAddress(
9443 GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_G2);
9444 SDValue MiVar = DAG.getTargetGlobalAddress(
9445 GV, DL, PtrVT, 0,
9446 AArch64II::MO_TLS | AArch64II::MO_G1 | AArch64II::MO_NC);
9447 SDValue LoVar = DAG.getTargetGlobalAddress(
9448 GV, DL, PtrVT, 0,
9449 AArch64II::MO_TLS | AArch64II::MO_G0 | AArch64II::MO_NC);
9450 TPOff = SDValue(DAG.getMachineNode(AArch64::MOVZXi, DL, PtrVT, HiVar,
9451 DAG.getTargetConstant(32, DL, MVT::i32)),
9452 0);
9453 TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, MiVar,
9454 DAG.getTargetConstant(16, DL, MVT::i32)),
9455 0);
9456 TPOff = SDValue(DAG.getMachineNode(AArch64::MOVKXi, DL, PtrVT, TPOff, LoVar,
9457 DAG.getTargetConstant(0, DL, MVT::i32)),
9458 0);
9459 return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9460 }
9461 }
9462 }
9463
9464 /// When accessing thread-local variables under either the general-dynamic or
9465 /// local-dynamic system, we make a "TLS-descriptor" call. The variable will
9466 /// have a descriptor, accessible via a PC-relative ADRP, and whose first entry
9467 /// is a function pointer to carry out the resolution.
9468 ///
9469 /// The sequence is:
9470 /// adrp x0, :tlsdesc:var
9471 /// ldr x1, [x0, #:tlsdesc_lo12:var]
9472 /// add x0, x0, #:tlsdesc_lo12:var
9473 /// .tlsdesccall var
9474 /// blr x1
9475 /// (TPIDR_EL0 offset now in x0)
9476 ///
9477 /// The above sequence must be produced unscheduled, to enable the linker to
9478 /// optimize/relax this sequence.
9479 /// Therefore, a pseudo-instruction (TLSDESC_CALLSEQ) is used to represent the
9480 /// above sequence, and expanded really late in the compilation flow, to ensure
9481 /// the sequence is produced as per above.
9482 SDValue AArch64TargetLowering::LowerELFTLSDescCallSeq(SDValue SymAddr,
9483 const SDLoc &DL,
9484 SelectionDAG &DAG) const {
9485 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9486
9487 SDValue Chain = DAG.getEntryNode();
9488 SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
9489
9490 Chain =
9491 DAG.getNode(AArch64ISD::TLSDESC_CALLSEQ, DL, NodeTys, {Chain, SymAddr});
9492 SDValue Glue = Chain.getValue(1);
9493
9494 return DAG.getCopyFromReg(Chain, DL, AArch64::X0, PtrVT, Glue);
9495 }
9496
9497 SDValue
9498 AArch64TargetLowering::LowerELFGlobalTLSAddress(SDValue Op,
9499 SelectionDAG &DAG) const {
9500 assert(Subtarget->isTargetELF() && "This function expects an ELF target");
9501
9502 const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9503
9504 TLSModel::Model Model = getTargetMachine().getTLSModel(GA->getGlobal());
9505
9506 if (!EnableAArch64ELFLocalDynamicTLSGeneration) {
9507 if (Model == TLSModel::LocalDynamic)
9508 Model = TLSModel::GeneralDynamic;
9509 }
9510
9511 if (getTargetMachine().getCodeModel() == CodeModel::Large &&
9512 Model != TLSModel::LocalExec)
9513 report_fatal_error("ELF TLS only supported in small memory model or "
9514 "in local exec TLS model");
9515 // Different choices can be made for the maximum size of the TLS area for a
9516 // module. For the small address model, the default TLS size is 16MiB and the
9517 // maximum TLS size is 4GiB.
9518 // FIXME: add tiny and large code model support for TLS access models other
9519 // than local exec. We currently generate the same code as small for tiny,
9520 // which may be larger than needed.
9521
9522 SDValue TPOff;
9523 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9524 SDLoc DL(Op);
9525 const GlobalValue *GV = GA->getGlobal();
9526
9527 SDValue ThreadBase = DAG.getNode(AArch64ISD::THREAD_POINTER, DL, PtrVT);
9528
9529 if (Model == TLSModel::LocalExec) {
9530 return LowerELFTLSLocalExec(GV, ThreadBase, DL, DAG);
9531 } else if (Model == TLSModel::InitialExec) {
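    // Initial-exec: the variable's offset from the thread pointer is loaded
    // from a GOT entry.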
9532 TPOff = DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9533 TPOff = DAG.getNode(AArch64ISD::LOADgot, DL, PtrVT, TPOff);
9534 } else if (Model == TLSModel::LocalDynamic) {
9535 // Local-dynamic accesses proceed in two phases. A general-dynamic TLS
9536 // descriptor call against the special symbol _TLS_MODULE_BASE_ to calculate
9537 // the beginning of the module's TLS region, followed by a DTPREL offset
9538 // calculation.
9539
9540 // These accesses will need deduplicating if there's more than one.
9541 AArch64FunctionInfo *MFI =
9542 DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
9543 MFI->incNumLocalDynamicTLSAccesses();
9544
9545 // The call needs a relocation too for linker relaxation. It doesn't make
9546 // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
9547 // the address.
9548 SDValue SymAddr = DAG.getTargetExternalSymbol("_TLS_MODULE_BASE_", PtrVT,
9549 AArch64II::MO_TLS);
9550
9551 // Now we can calculate the offset from TPIDR_EL0 to this module's
9552 // thread-local area.
9553 TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
9554
9555 // Now use :dtprel_whatever: operations to calculate this variable's offset
9556 // in its thread-storage area.
9557 SDValue HiVar = DAG.getTargetGlobalAddress(
9558 GV, DL, MVT::i64, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9559 SDValue LoVar = DAG.getTargetGlobalAddress(
9560 GV, DL, MVT::i64, 0,
9561 AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9562
9563 TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, HiVar,
9564 DAG.getTargetConstant(0, DL, MVT::i32)),
9565 0);
9566 TPOff = SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TPOff, LoVar,
9567 DAG.getTargetConstant(0, DL, MVT::i32)),
9568 0);
9569 } else if (Model == TLSModel::GeneralDynamic) {
9570 // The call needs a relocation too for linker relaxation. It doesn't make
9571 // sense to call it MO_PAGE or MO_PAGEOFF though so we need another copy of
9572 // the address.
9573 SDValue SymAddr =
9574 DAG.getTargetGlobalAddress(GV, DL, PtrVT, 0, AArch64II::MO_TLS);
9575
9576 // Finally we can make a call to calculate the offset from tpidr_el0.
9577 TPOff = LowerELFTLSDescCallSeq(SymAddr, DL, DAG);
9578 } else
9579 llvm_unreachable("Unsupported ELF TLS access model");
9580
9581 return DAG.getNode(ISD::ADD, DL, PtrVT, ThreadBase, TPOff);
9582 }
9583
9584 SDValue
9585 AArch64TargetLowering::LowerWindowsGlobalTLSAddress(SDValue Op,
9586 SelectionDAG &DAG) const {
9587 assert(Subtarget->isTargetWindows() && "Windows specific TLS lowering");
9588
9589 SDValue Chain = DAG.getEntryNode();
9590 EVT PtrVT = getPointerTy(DAG.getDataLayout());
9591 SDLoc DL(Op);
9592
9593 SDValue TEB = DAG.getRegister(AArch64::X18, MVT::i64);
9594
9595 // Load the ThreadLocalStoragePointer from the TEB
9596 // A pointer to the TLS array is located at offset 0x58 from the TEB.
9597 SDValue TLSArray =
9598 DAG.getNode(ISD::ADD, DL, PtrVT, TEB, DAG.getIntPtrConstant(0x58, DL));
9599 TLSArray = DAG.getLoad(PtrVT, DL, Chain, TLSArray, MachinePointerInfo());
9600 Chain = TLSArray.getValue(1);
9601
9602 // Load the TLS index from the C runtime;
9603 // This does the same as getAddr(), but without having a GlobalAddressSDNode.
9604 // This also does the same as LOADgot, but using a generic i32 load,
9605 // while LOADgot only loads i64.
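// Illustrative sketch, assuming the usual ADRP/ADDlow folding into the load
// (register choice and exact form are decided by instruction selection):
//   adrp x8, _tls_index
//   ldr  w8, [x8, :lo12:_tls_index]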
9606 SDValue TLSIndexHi =
9607 DAG.getTargetExternalSymbol("_tls_index", PtrVT, AArch64II::MO_PAGE);
9608 SDValue TLSIndexLo = DAG.getTargetExternalSymbol(
9609 "_tls_index", PtrVT, AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9610 SDValue ADRP = DAG.getNode(AArch64ISD::ADRP, DL, PtrVT, TLSIndexHi);
9611 SDValue TLSIndex =
9612 DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, ADRP, TLSIndexLo);
9613 TLSIndex = DAG.getLoad(MVT::i32, DL, Chain, TLSIndex, MachinePointerInfo());
9614 Chain = TLSIndex.getValue(1);
9615
9616 // The pointer to the thread's TLS data area is found at offset
9617 // (TLS index * 8) into the TLSArray.
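// Illustrative sketch, assuming the zero-extend, shift and add below fold into
// a register-offset load (x9 = TLS array pointer, x8 = zero-extended index):
//   ldr x9, [x9, x8, lsl #3]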
9618 TLSIndex = DAG.getNode(ISD::ZERO_EXTEND, DL, PtrVT, TLSIndex);
9619 SDValue Slot = DAG.getNode(ISD::SHL, DL, PtrVT, TLSIndex,
9620 DAG.getConstant(3, DL, PtrVT));
9621 SDValue TLS = DAG.getLoad(PtrVT, DL, Chain,
9622 DAG.getNode(ISD::ADD, DL, PtrVT, TLSArray, Slot),
9623 MachinePointerInfo());
9624 Chain = TLS.getValue(1);
9625
9626 const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9627 const GlobalValue *GV = GA->getGlobal();
9628 SDValue TGAHi = DAG.getTargetGlobalAddress(
9629 GV, DL, PtrVT, 0, AArch64II::MO_TLS | AArch64II::MO_HI12);
9630 SDValue TGALo = DAG.getTargetGlobalAddress(
9631 GV, DL, PtrVT, 0,
9632 AArch64II::MO_TLS | AArch64II::MO_PAGEOFF | AArch64II::MO_NC);
9633
9634 // Add the offset from the start of the .tls section (section base).
9635 SDValue Addr =
9636 SDValue(DAG.getMachineNode(AArch64::ADDXri, DL, PtrVT, TLS, TGAHi,
9637 DAG.getTargetConstant(0, DL, MVT::i32)),
9638 0);
9639 Addr = DAG.getNode(AArch64ISD::ADDlow, DL, PtrVT, Addr, TGALo);
9640 return Addr;
9641 }
9642
9643 SDValue AArch64TargetLowering::LowerGlobalTLSAddress(SDValue Op,
9644 SelectionDAG &DAG) const {
9645 const GlobalAddressSDNode *GA = cast<GlobalAddressSDNode>(Op);
9646 if (DAG.getTarget().useEmulatedTLS())
9647 return LowerToTLSEmulatedModel(GA, DAG);
9648
9649 if (Subtarget->isTargetDarwin())
9650 return LowerDarwinGlobalTLSAddress(Op, DAG);
9651 if (Subtarget->isTargetELF())
9652 return LowerELFGlobalTLSAddress(Op, DAG);
9653 if (Subtarget->isTargetWindows())
9654 return LowerWindowsGlobalTLSAddress(Op, DAG);
9655
9656 llvm_unreachable("Unexpected platform trying to use TLS");
9657 }
9658
9659 //===----------------------------------------------------------------------===//
9660 // PtrAuthGlobalAddress lowering
9661 //
9662 // We have 3 lowering alternatives to choose from:
9663 // - MOVaddrPAC: similar to MOVaddr, with added PAC.
9664 // If the GV doesn't need a GOT load (i.e., is locally defined)
9665 // materialize the pointer using adrp+add+pac. See LowerMOVaddrPAC.
9666 //
9667 // - LOADgotPAC: similar to LOADgot, with added PAC.
9668 // If the GV needs a GOT load, materialize the pointer using the usual
9669 // GOT adrp+ldr, +pac. Pointers in GOT are assumed to be not signed, the GOT
9670 // section is assumed to be read-only (for example, via relro mechanism). See
9671 // LowerMOVaddrPAC.
9672 //
9673 // - LOADauthptrstatic: similar to LOADgot, but use a
9674 // special stub slot instead of a GOT slot.
9675 // Load a signed pointer for symbol 'sym' from a stub slot named
9676 // 'sym$auth_ptr$key$disc' filled by dynamic linker during relocation
9677 // resolving. This usually lowers to adrp+ldr, but also emits an entry into
9678 // .data with an @AUTH relocation. See LowerLOADauthptrstatic.
9679 //
9680 // All 3 are pseudos that are expanded late to longer sequences: this lets us
9681 // provide integrity guarantees on the to-be-signed intermediate values.
9682 //
9683 // LOADauthptrstatic is undesirable because it requires a large section filled
9684 // with often similarly-signed pointers, making it a good harvesting target.
9685 // Thus, it's only used for ptrauth references to extern_weak to avoid null
9686 // checks.
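//
// For reference, a hedged sketch of the MOVaddrPAC case for a locally defined
// global "g" signed with key IA and constant discriminator 42 (no address
// discriminator); the pseudo is expanded late and the exact registers and
// scheduling are decided then:
//   adrp  x16, g
//   add   x16, x16, :lo12:g
//   mov   x17, #42
//   pacia x16, x17
// LOADgotPAC is analogous, but replaces the adrp+add with a GOT adrp+ldr.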
9687
9688 SDValue AArch64TargetLowering::LowerPtrAuthGlobalAddressStatically(
9689 SDValue TGA, SDLoc DL, EVT VT, AArch64PACKey::ID KeyC,
9690 SDValue Discriminator, SDValue AddrDiscriminator, SelectionDAG &DAG) const {
9691 const auto *TGN = cast<GlobalAddressSDNode>(TGA.getNode());
9692 assert(TGN->getGlobal()->hasExternalWeakLinkage());
9693
9694 // Offsets and extern_weak don't mix well: ptrauth aside, you'd get the
9695 // offset alone as a pointer if the symbol wasn't available, which would
9696 // probably break null checks in users. Ptrauth complicates things further:
9697 // error out.
9698 if (TGN->getOffset() != 0)
9699 report_fatal_error(
9700 "unsupported non-zero offset in weak ptrauth global reference");
9701
9702 if (!isNullConstant(AddrDiscriminator))
9703 report_fatal_error("unsupported weak addr-div ptrauth global");
9704
9705 SDValue Key = DAG.getTargetConstant(KeyC, DL, MVT::i32);
9706 return SDValue(DAG.getMachineNode(AArch64::LOADauthptrstatic, DL, MVT::i64,
9707 {TGA, Key, Discriminator}),
9708 0);
9709 }
9710
9711 SDValue
9712 AArch64TargetLowering::LowerPtrAuthGlobalAddress(SDValue Op,
9713 SelectionDAG &DAG) const {
9714 SDValue Ptr = Op.getOperand(0);
9715 uint64_t KeyC = Op.getConstantOperandVal(1);
9716 SDValue AddrDiscriminator = Op.getOperand(2);
9717 uint64_t DiscriminatorC = Op.getConstantOperandVal(3);
9718 EVT VT = Op.getValueType();
9719 SDLoc DL(Op);
9720
9721 if (KeyC > AArch64PACKey::LAST)
9722 report_fatal_error("key in ptrauth global out of range [0, " +
9723 Twine((int)AArch64PACKey::LAST) + "]");
9724
9725 // Blend only works if the integer discriminator is 16-bit wide.
9726 if (!isUInt<16>(DiscriminatorC))
9727 report_fatal_error(
9728 "constant discriminator in ptrauth global out of range [0, 0xffff]");
9729
9730 // Choosing between 3 lowering alternatives is target-specific.
9731 if (!Subtarget->isTargetELF() && !Subtarget->isTargetMachO())
9732 report_fatal_error("ptrauth global lowering only supported on MachO/ELF");
9733
9734 int64_t PtrOffsetC = 0;
9735 if (Ptr.getOpcode() == ISD::ADD) {
9736 PtrOffsetC = Ptr.getConstantOperandVal(1);
9737 Ptr = Ptr.getOperand(0);
9738 }
9739 const auto *PtrN = cast<GlobalAddressSDNode>(Ptr.getNode());
9740 const GlobalValue *PtrGV = PtrN->getGlobal();
9741
9742 // Classify the reference to determine whether it needs a GOT load.
9743 const unsigned OpFlags =
9744 Subtarget->ClassifyGlobalReference(PtrGV, getTargetMachine());
9745 const bool NeedsGOTLoad = ((OpFlags & AArch64II::MO_GOT) != 0);
9746 assert(((OpFlags & (~AArch64II::MO_GOT)) == 0) &&
9747 "unsupported non-GOT op flags on ptrauth global reference");
9748
9749 // Fold any offset into the GV; our pseudos expect it there.
9750 PtrOffsetC += PtrN->getOffset();
9751 SDValue TPtr = DAG.getTargetGlobalAddress(PtrGV, DL, VT, PtrOffsetC,
9752 /*TargetFlags=*/0);
9753 assert(PtrN->getTargetFlags() == 0 &&
9754 "unsupported target flags on ptrauth global");
9755
9756 SDValue Key = DAG.getTargetConstant(KeyC, DL, MVT::i32);
9757 SDValue Discriminator = DAG.getTargetConstant(DiscriminatorC, DL, MVT::i64);
9758 SDValue TAddrDiscriminator = !isNullConstant(AddrDiscriminator)
9759 ? AddrDiscriminator
9760 : DAG.getRegister(AArch64::XZR, MVT::i64);
9761
9762 // No GOT load needed -> MOVaddrPAC
9763 if (!NeedsGOTLoad) {
9764 assert(!PtrGV->hasExternalWeakLinkage() && "extern_weak should use GOT");
9765 return SDValue(
9766 DAG.getMachineNode(AArch64::MOVaddrPAC, DL, MVT::i64,
9767 {TPtr, Key, TAddrDiscriminator, Discriminator}),
9768 0);
9769 }
9770
9771 // GOT load -> LOADgotPAC
9772 // Note that we disallow extern_weak refs to avoid null checks later.
9773 if (!PtrGV->hasExternalWeakLinkage())
9774 return SDValue(
9775 DAG.getMachineNode(AArch64::LOADgotPAC, DL, MVT::i64,
9776 {TPtr, Key, TAddrDiscriminator, Discriminator}),
9777 0);
9778
9779 // extern_weak ref -> LOADauthptrstatic
9780 return LowerPtrAuthGlobalAddressStatically(
9781 TPtr, DL, VT, (AArch64PACKey::ID)KeyC, Discriminator, AddrDiscriminator,
9782 DAG);
9783 }
9784
9785 // Looks through \param Val to determine the bit that can be used to
9786 // check the sign of the value. It returns the unextended value and
9787 // the sign bit position.
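// For example, (sext_inreg x, i8) yields {x, 7}, (sext i32 x to i64) yields
// {x, 31}, and a plain i64 value yields {Val, 63}.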
9788 std::pair<SDValue, uint64_t> lookThroughSignExtension(SDValue Val) {
9789 if (Val.getOpcode() == ISD::SIGN_EXTEND_INREG)
9790 return {Val.getOperand(0),
9791 cast<VTSDNode>(Val.getOperand(1))->getVT().getFixedSizeInBits() -
9792 1};
9793
9794 if (Val.getOpcode() == ISD::SIGN_EXTEND)
9795 return {Val.getOperand(0),
9796 Val.getOperand(0)->getValueType(0).getFixedSizeInBits() - 1};
9797
9798 return {Val, Val.getValueSizeInBits() - 1};
9799 }
9800
9801 SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
9802 SDValue Chain = Op.getOperand(0);
9803 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(1))->get();
9804 SDValue LHS = Op.getOperand(2);
9805 SDValue RHS = Op.getOperand(3);
9806 SDValue Dest = Op.getOperand(4);
9807 SDLoc dl(Op);
9808
9809 MachineFunction &MF = DAG.getMachineFunction();
9810 // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
9811 // will not be produced, as they are conditional branch instructions that do
9812 // not set flags.
9813 bool ProduceNonFlagSettingCondBr =
9814 !MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening);
9815
9816 // Handle f128 first, since lowering it will result in comparing the return
9817 // value of a libcall against zero, which is just what the rest of LowerBR_CC
9818 // is expecting to deal with.
9819 if (LHS.getValueType() == MVT::f128) {
9820 softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
9821
9822 // If softenSetCCOperands returned a scalar, we need to compare the result
9823 // against zero to select between true and false values.
9824 if (!RHS.getNode()) {
9825 RHS = DAG.getConstant(0, dl, LHS.getValueType());
9826 CC = ISD::SETNE;
9827 }
9828 }
9829
9830 // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a branch
9831 // instruction.
9832 if (ISD::isOverflowIntrOpRes(LHS) && isOneConstant(RHS) &&
9833 (CC == ISD::SETEQ || CC == ISD::SETNE)) {
9834 // Only lower legal XALUO ops.
9835 if (!DAG.getTargetLoweringInfo().isTypeLegal(LHS->getValueType(0)))
9836 return SDValue();
9837
9838 // The actual operation with overflow check.
9839 AArch64CC::CondCode OFCC;
9840 SDValue Value, Overflow;
9841 std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, LHS.getValue(0), DAG);
9842
9843 if (CC == ISD::SETNE)
9844 OFCC = getInvertedCondCode(OFCC);
9845 SDValue CCVal = DAG.getConstant(OFCC, dl, MVT::i32);
9846
9847 return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
9848 Overflow);
9849 }
9850
9851 if (LHS.getValueType().isInteger()) {
9852 assert((LHS.getValueType() == RHS.getValueType()) &&
9853 (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
9854
9855 // If the RHS of the comparison is zero, we can potentially fold this
9856 // to a specialized branch.
9857 const ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
9858 if (RHSC && RHSC->getZExtValue() == 0 && ProduceNonFlagSettingCondBr) {
9859 if (CC == ISD::SETEQ) {
9860 // See if we can use a TBZ to fold in an AND as well.
9861 // TBZ has a smaller branch displacement than CBZ. If the offset is
9862 // out of bounds, a late MI-layer pass rewrites branches.
9863 // 403.gcc is an example that hits this case.
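// For example, a branch taken when ((x & 8) == 0) becomes, roughly:
//   tbz x0, #3, dest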
9864 if (LHS.getOpcode() == ISD::AND &&
9865 isa<ConstantSDNode>(LHS.getOperand(1)) &&
9866 isPowerOf2_64(LHS.getConstantOperandVal(1))) {
9867 SDValue Test = LHS.getOperand(0);
9868 uint64_t Mask = LHS.getConstantOperandVal(1);
9869 return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, Test,
9870 DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
9871 Dest);
9872 }
9873
9874 return DAG.getNode(AArch64ISD::CBZ, dl, MVT::Other, Chain, LHS, Dest);
9875 } else if (CC == ISD::SETNE) {
9876 // See if we can use a TBZ to fold in an AND as well.
9877 // TBZ has a smaller branch displacement than CBZ. If the offset is
9878 // out of bounds, a late MI-layer pass rewrites branches.
9879 // 403.gcc is an example that hits this case.
9880 if (LHS.getOpcode() == ISD::AND &&
9881 isa<ConstantSDNode>(LHS.getOperand(1)) &&
9882 isPowerOf2_64(LHS.getConstantOperandVal(1))) {
9883 SDValue Test = LHS.getOperand(0);
9884 uint64_t Mask = LHS.getConstantOperandVal(1);
9885 return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, Test,
9886 DAG.getConstant(Log2_64(Mask), dl, MVT::i64),
9887 Dest);
9888 }
9889
9890 return DAG.getNode(AArch64ISD::CBNZ, dl, MVT::Other, Chain, LHS, Dest);
9891 } else if (CC == ISD::SETLT && LHS.getOpcode() != ISD::AND) {
9892 // Don't combine AND since emitComparison converts the AND to an ANDS
9893 // (a.k.a. TST) and the test in the test bit and branch instruction
9894 // becomes redundant. This would also increase register pressure.
9895 uint64_t SignBitPos;
9896 std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
9897 return DAG.getNode(AArch64ISD::TBNZ, dl, MVT::Other, Chain, LHS,
9898 DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
9899 }
9900 }
9901 if (RHSC && RHSC->getSExtValue() == -1 && CC == ISD::SETGT &&
9902 LHS.getOpcode() != ISD::AND && ProduceNonFlagSettingCondBr) {
9903 // Don't combine AND since emitComparison converts the AND to an ANDS
9904 // (a.k.a. TST) and the test in the test bit and branch instruction
9905 // becomes redundant. This would also increase register pressure.
9906 uint64_t SignBitPos;
9907 std::tie(LHS, SignBitPos) = lookThroughSignExtension(LHS);
9908 return DAG.getNode(AArch64ISD::TBZ, dl, MVT::Other, Chain, LHS,
9909 DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
9910 }
9911
9912 SDValue CCVal;
9913 SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
9914 return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,
9915 Cmp);
9916 }
9917
9918 assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::bf16 ||
9919 LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
9920
9921 // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
9922 // clean. Some of them require two branches to implement.
9923 SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
9924 AArch64CC::CondCode CC1, CC2;
9925 changeFPCCToAArch64CC(CC, CC1, CC2);
9926 SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
9927 SDValue BR1 =
9928 DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CC1Val, Cmp);
9929 if (CC2 != AArch64CC::AL) {
9930 SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
9931 return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, BR1, Dest, CC2Val,
9932 Cmp);
9933 }
9934
9935 return BR1;
9936 }
9937
9938 SDValue AArch64TargetLowering::LowerFCOPYSIGN(SDValue Op,
9939 SelectionDAG &DAG) const {
9940 if (!Subtarget->isNeonAvailable() &&
9941 !Subtarget->useSVEForFixedLengthVectors())
9942 return SDValue();
9943
9944 EVT VT = Op.getValueType();
9945 EVT IntVT = VT.changeTypeToInteger();
9946 SDLoc DL(Op);
9947
9948 SDValue In1 = Op.getOperand(0);
9949 SDValue In2 = Op.getOperand(1);
9950 EVT SrcVT = In2.getValueType();
9951
9952 if (!SrcVT.bitsEq(VT))
9953 In2 = DAG.getFPExtendOrRound(In2, DL, VT);
9954
9955 if (VT.isScalableVector())
9956 IntVT =
9957 getPackedSVEVectorVT(VT.getVectorElementType().changeTypeToInteger());
9958
9959 if (VT.isFixedLengthVector() &&
9960 useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
9961 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
9962
9963 In1 = convertToScalableVector(DAG, ContainerVT, In1);
9964 In2 = convertToScalableVector(DAG, ContainerVT, In2);
9965
9966 SDValue Res = DAG.getNode(ISD::FCOPYSIGN, DL, ContainerVT, In1, In2);
9967 return convertFromScalableVector(DAG, VT, Res);
9968 }
9969
9970 auto BitCast = [this](EVT VT, SDValue Op, SelectionDAG &DAG) {
9971 if (VT.isScalableVector())
9972 return getSVESafeBitCast(VT, Op, DAG);
9973
9974 return DAG.getBitcast(VT, Op);
9975 };
9976
9977 SDValue VecVal1, VecVal2;
9978 EVT VecVT;
9979 auto SetVecVal = [&](int Idx = -1) {
9980 if (!VT.isVector()) {
9981 VecVal1 =
9982 DAG.getTargetInsertSubreg(Idx, DL, VecVT, DAG.getUNDEF(VecVT), In1);
9983 VecVal2 =
9984 DAG.getTargetInsertSubreg(Idx, DL, VecVT, DAG.getUNDEF(VecVT), In2);
9985 } else {
9986 VecVal1 = BitCast(VecVT, In1, DAG);
9987 VecVal2 = BitCast(VecVT, In2, DAG);
9988 }
9989 };
9990 if (VT.isVector()) {
9991 VecVT = IntVT;
9992 SetVecVal();
9993 } else if (VT == MVT::f64) {
9994 VecVT = MVT::v2i64;
9995 SetVecVal(AArch64::dsub);
9996 } else if (VT == MVT::f32) {
9997 VecVT = MVT::v4i32;
9998 SetVecVal(AArch64::ssub);
9999 } else if (VT == MVT::f16 || VT == MVT::bf16) {
10000 VecVT = MVT::v8i16;
10001 SetVecVal(AArch64::hsub);
10002 } else {
10003 llvm_unreachable("Invalid type for copysign!");
10004 }
10005
10006 unsigned BitWidth = In1.getScalarValueSizeInBits();
10007 SDValue SignMaskV = DAG.getConstant(~APInt::getSignMask(BitWidth), DL, VecVT);
10008
10009 // We want to materialize a mask with every bit but the high bit set, but the
10010 // AdvSIMD immediate moves cannot materialize that in a single instruction for
10011 // 64-bit elements. Instead, materialize all bits set and then negate that.
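// Illustrative sketch (instruction selection may differ): for v2i64 this is
// typically materialized as
//   movi v2.2d, #0xffffffffffffffff
//   fneg v2.2d, v2.2d
// leaving 0x7fffffffffffffff in each lane.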
10012 if (VT == MVT::f64 || VT == MVT::v2f64) {
10013 SignMaskV = DAG.getConstant(APInt::getAllOnes(BitWidth), DL, VecVT);
10014 SignMaskV = DAG.getNode(ISD::BITCAST, DL, MVT::v2f64, SignMaskV);
10015 SignMaskV = DAG.getNode(ISD::FNEG, DL, MVT::v2f64, SignMaskV);
10016 SignMaskV = DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, SignMaskV);
10017 }
10018
10019 SDValue BSP =
10020 DAG.getNode(AArch64ISD::BSP, DL, VecVT, SignMaskV, VecVal1, VecVal2);
10021 if (VT == MVT::f16 || VT == MVT::bf16)
10022 return DAG.getTargetExtractSubreg(AArch64::hsub, DL, VT, BSP);
10023 if (VT == MVT::f32)
10024 return DAG.getTargetExtractSubreg(AArch64::ssub, DL, VT, BSP);
10025 if (VT == MVT::f64)
10026 return DAG.getTargetExtractSubreg(AArch64::dsub, DL, VT, BSP);
10027
10028 return BitCast(VT, BSP, DAG);
10029 }
10030
10031 SDValue AArch64TargetLowering::LowerCTPOP_PARITY(SDValue Op,
10032 SelectionDAG &DAG) const {
10033 if (DAG.getMachineFunction().getFunction().hasFnAttribute(
10034 Attribute::NoImplicitFloat))
10035 return SDValue();
10036
10037 EVT VT = Op.getValueType();
10038 if (VT.isScalableVector() ||
10039 useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
10040 return LowerToPredicatedOp(Op, DAG, AArch64ISD::CTPOP_MERGE_PASSTHRU);
10041
10042 if (!Subtarget->isNeonAvailable())
10043 return SDValue();
10044
10045 bool IsParity = Op.getOpcode() == ISD::PARITY;
10046 SDValue Val = Op.getOperand(0);
10047 SDLoc DL(Op);
10048
10049 // For i32, the generic parity expansion using EORs is more efficient than
10050 // going through the floating-point/SIMD registers.
10051 if (VT == MVT::i32 && IsParity)
10052 return SDValue();
10053
10054 // If there is no CNT instruction available, GPR popcount can
10055 // be more efficiently lowered to the following sequence that uses
10056 // AdvSIMD registers/instructions as long as the copies to/from
10057 // the AdvSIMD registers are cheap.
10058 // FMOV D0, X0 // copy 64-bit int to vector, high bits zero'd
10059 // CNT V0.8B, V0.8B // 8xbyte pop-counts
10060 // ADDV B0, V0.8B // sum 8xbyte pop-counts
10061 // UMOV X0, V0.B[0] // copy byte result back to integer reg
10062 if (VT == MVT::i32 || VT == MVT::i64) {
10063 if (VT == MVT::i32)
10064 Val = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Val);
10065 Val = DAG.getNode(ISD::BITCAST, DL, MVT::v8i8, Val);
10066
10067 SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v8i8, Val);
10068 SDValue UaddLV = DAG.getNode(AArch64ISD::UADDLV, DL, MVT::v4i32, CtPop);
10069 UaddLV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, UaddLV,
10070 DAG.getConstant(0, DL, MVT::i64));
10071
10072 if (IsParity)
10073 UaddLV = DAG.getNode(ISD::AND, DL, MVT::i32, UaddLV,
10074 DAG.getConstant(1, DL, MVT::i32));
10075
10076 if (VT == MVT::i64)
10077 UaddLV = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, UaddLV);
10078 return UaddLV;
10079 } else if (VT == MVT::i128) {
10080 Val = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Val);
10081
10082 SDValue CtPop = DAG.getNode(ISD::CTPOP, DL, MVT::v16i8, Val);
10083 SDValue UaddLV = DAG.getNode(AArch64ISD::UADDLV, DL, MVT::v4i32, CtPop);
10084 UaddLV = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i32, UaddLV,
10085 DAG.getConstant(0, DL, MVT::i64));
10086
10087 if (IsParity)
10088 UaddLV = DAG.getNode(ISD::AND, DL, MVT::i32, UaddLV,
10089 DAG.getConstant(1, DL, MVT::i32));
10090
10091 return DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i128, UaddLV);
10092 }
10093
10094 assert(!IsParity && "ISD::PARITY of vector types not supported");
10095
10096 assert((VT == MVT::v1i64 || VT == MVT::v2i64 || VT == MVT::v2i32 ||
10097 VT == MVT::v4i32 || VT == MVT::v4i16 || VT == MVT::v8i16) &&
10098 "Unexpected type for custom ctpop lowering");
10099
10100 EVT VT8Bit = VT.is64BitVector() ? MVT::v8i8 : MVT::v16i8;
10101 Val = DAG.getBitcast(VT8Bit, Val);
10102 Val = DAG.getNode(ISD::CTPOP, DL, VT8Bit, Val);
10103
10104 if (Subtarget->hasDotProd() && VT.getScalarSizeInBits() != 16 &&
10105 VT.getVectorNumElements() >= 2) {
10106 EVT DT = VT == MVT::v2i64 ? MVT::v4i32 : VT;
10107 SDValue Zeros = DAG.getConstant(0, DL, DT);
10108 SDValue Ones = DAG.getConstant(1, DL, VT8Bit);
10109
10110 if (VT == MVT::v2i64) {
10111 Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10112 Val = DAG.getNode(AArch64ISD::UADDLP, DL, VT, Val);
10113 } else if (VT == MVT::v2i32) {
10114 Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10115 } else if (VT == MVT::v4i32) {
10116 Val = DAG.getNode(AArch64ISD::UDOT, DL, DT, Zeros, Ones, Val);
10117 } else {
10118 llvm_unreachable("Unexpected type for custom ctpop lowering");
10119 }
10120
10121 return Val;
10122 }
10123
10124 // Widen v8i8/v16i8 CTPOP result to VT by repeatedly widening pairwise adds.
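// For example, for v4i32 this emits roughly:
//   cnt    v0.16b, v0.16b
//   uaddlp v0.8h, v0.16b
//   uaddlp v0.4s, v0.8h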
10125 unsigned EltSize = 8;
10126 unsigned NumElts = VT.is64BitVector() ? 8 : 16;
10127 while (EltSize != VT.getScalarSizeInBits()) {
10128 EltSize *= 2;
10129 NumElts /= 2;
10130 MVT WidenVT = MVT::getVectorVT(MVT::getIntegerVT(EltSize), NumElts);
10131 Val = DAG.getNode(AArch64ISD::UADDLP, DL, WidenVT, Val);
10132 }
10133
10134 return Val;
10135 }
10136
10137 SDValue AArch64TargetLowering::LowerCTTZ(SDValue Op, SelectionDAG &DAG) const {
10138 EVT VT = Op.getValueType();
10139 assert(VT.isScalableVector() ||
10140 useSVEForFixedLengthVectorVT(
10141 VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()));
10142
10143 SDLoc DL(Op);
10144 SDValue RBIT = DAG.getNode(ISD::BITREVERSE, DL, VT, Op.getOperand(0));
10145 return DAG.getNode(ISD::CTLZ, DL, VT, RBIT);
10146 }
10147
10148 SDValue AArch64TargetLowering::LowerMinMax(SDValue Op,
10149 SelectionDAG &DAG) const {
10150
10151 EVT VT = Op.getValueType();
10152 SDLoc DL(Op);
10153 unsigned Opcode = Op.getOpcode();
10154 ISD::CondCode CC;
10155 switch (Opcode) {
10156 default:
10157 llvm_unreachable("Wrong instruction");
10158 case ISD::SMAX:
10159 CC = ISD::SETGT;
10160 break;
10161 case ISD::SMIN:
10162 CC = ISD::SETLT;
10163 break;
10164 case ISD::UMAX:
10165 CC = ISD::SETUGT;
10166 break;
10167 case ISD::UMIN:
10168 CC = ISD::SETULT;
10169 break;
10170 }
10171
10172 if (VT.isScalableVector() ||
10173 useSVEForFixedLengthVectorVT(
10174 VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors())) {
10175 switch (Opcode) {
10176 default:
10177 llvm_unreachable("Wrong instruction");
10178 case ISD::SMAX:
10179 return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMAX_PRED);
10180 case ISD::SMIN:
10181 return LowerToPredicatedOp(Op, DAG, AArch64ISD::SMIN_PRED);
10182 case ISD::UMAX:
10183 return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMAX_PRED);
10184 case ISD::UMIN:
10185 return LowerToPredicatedOp(Op, DAG, AArch64ISD::UMIN_PRED);
10186 }
10187 }
10188
10189 SDValue Op0 = Op.getOperand(0);
10190 SDValue Op1 = Op.getOperand(1);
10191 SDValue Cond = DAG.getSetCC(DL, VT, Op0, Op1, CC);
10192 return DAG.getSelect(DL, VT, Cond, Op0, Op1);
10193 }
10194
10195 SDValue AArch64TargetLowering::LowerBitreverse(SDValue Op,
10196 SelectionDAG &DAG) const {
10197 EVT VT = Op.getValueType();
10198
10199 if (VT.isScalableVector() ||
10200 useSVEForFixedLengthVectorVT(
10201 VT, /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors()))
10202 return LowerToPredicatedOp(Op, DAG, AArch64ISD::BITREVERSE_MERGE_PASSTHRU);
10203
10204 SDLoc DL(Op);
10205 SDValue REVB;
10206 MVT VST;
10207
10208 switch (VT.getSimpleVT().SimpleTy) {
10209 default:
10210 llvm_unreachable("Invalid type for bitreverse!");
10211
10212 case MVT::v2i32: {
10213 VST = MVT::v8i8;
10214 REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
10215
10216 break;
10217 }
10218
10219 case MVT::v4i32: {
10220 VST = MVT::v16i8;
10221 REVB = DAG.getNode(AArch64ISD::REV32, DL, VST, Op.getOperand(0));
10222
10223 break;
10224 }
10225
10226 case MVT::v1i64: {
10227 VST = MVT::v8i8;
10228 REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
10229
10230 break;
10231 }
10232
10233 case MVT::v2i64: {
10234 VST = MVT::v16i8;
10235 REVB = DAG.getNode(AArch64ISD::REV64, DL, VST, Op.getOperand(0));
10236
10237 break;
10238 }
10239 }
10240
10241 return DAG.getNode(AArch64ISD::NVCAST, DL, VT,
10242 DAG.getNode(ISD::BITREVERSE, DL, VST, REVB));
10243 }
10244
10245 // Check whether this is a continuous (or-of-xor) comparison chain.
10246 static bool
10247 isOrXorChain(SDValue N, unsigned &Num,
10248 SmallVector<std::pair<SDValue, SDValue>, 16> &WorkList) {
10249 if (Num == MaxXors)
10250 return false;
10251
10252 // Skip the one-use zext
10253 if (N->getOpcode() == ISD::ZERO_EXTEND && N->hasOneUse())
10254 N = N->getOperand(0);
10255
10256 // The leaf node must be XOR
10257 if (N->getOpcode() == ISD::XOR) {
10258 WorkList.push_back(std::make_pair(N->getOperand(0), N->getOperand(1)));
10259 Num++;
10260 return true;
10261 }
10262
10263 // All the non-leaf nodes must be OR.
10264 if (N->getOpcode() != ISD::OR || !N->hasOneUse())
10265 return false;
10266
10267 if (isOrXorChain(N->getOperand(0), Num, WorkList) &&
10268 isOrXorChain(N->getOperand(1), Num, WorkList))
10269 return true;
10270 return false;
10271 }
10272
10273 // Transform chains of ORs and XORs, which are usually produced by memcmp/bcmp expansion.
10274 static SDValue performOrXorChainCombine(SDNode *N, SelectionDAG &DAG) {
10275 SDValue LHS = N->getOperand(0);
10276 SDValue RHS = N->getOperand(1);
10277 SDLoc DL(N);
10278 EVT VT = N->getValueType(0);
10279 SmallVector<std::pair<SDValue, SDValue>, 16> WorkList;
10280
10281 // Only handle integer compares.
10282 if (N->getOpcode() != ISD::SETCC)
10283 return SDValue();
10284
10285 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
10286 // Try to express conjunction "cmp 0 (or (xor A0 A1) (xor B0 B1))" as:
10287 // sub A0, A1; ccmp B0, B1, 0, eq; cmp inv(Cond) flag
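// For example, ((a0 ^ a1) | (b0 ^ b1)) == 0 typically ends up, after later
// combines (shown here only as a hedged sketch), as:
//   cmp  x0, x1
//   ccmp x2, x3, #0, eq
//   cset w0, eq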
10288 unsigned NumXors = 0;
10289 if ((Cond == ISD::SETEQ || Cond == ISD::SETNE) && isNullConstant(RHS) &&
10290 LHS->getOpcode() == ISD::OR && LHS->hasOneUse() &&
10291 isOrXorChain(LHS, NumXors, WorkList)) {
10292 SDValue XOR0, XOR1;
10293 std::tie(XOR0, XOR1) = WorkList[0];
10294 unsigned LogicOp = (Cond == ISD::SETEQ) ? ISD::AND : ISD::OR;
10295 SDValue Cmp = DAG.getSetCC(DL, VT, XOR0, XOR1, Cond);
10296 for (unsigned I = 1; I < WorkList.size(); I++) {
10297 std::tie(XOR0, XOR1) = WorkList[I];
10298 SDValue CmpChain = DAG.getSetCC(DL, VT, XOR0, XOR1, Cond);
10299 Cmp = DAG.getNode(LogicOp, DL, VT, Cmp, CmpChain);
10300 }
10301
10302 // Exit early by inverting the condition, which helps reduce indentation.
10303 return Cmp;
10304 }
10305
10306 return SDValue();
10307 }
10308
10309 SDValue AArch64TargetLowering::LowerSETCC(SDValue Op, SelectionDAG &DAG) const {
10310
10311 if (Op.getValueType().isVector())
10312 return LowerVSETCC(Op, DAG);
10313
10314 bool IsStrict = Op->isStrictFPOpcode();
10315 bool IsSignaling = Op.getOpcode() == ISD::STRICT_FSETCCS;
10316 unsigned OpNo = IsStrict ? 1 : 0;
10317 SDValue Chain;
10318 if (IsStrict)
10319 Chain = Op.getOperand(0);
10320 SDValue LHS = Op.getOperand(OpNo + 0);
10321 SDValue RHS = Op.getOperand(OpNo + 1);
10322 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(OpNo + 2))->get();
10323 SDLoc dl(Op);
10324
10325 // We chose ZeroOrOneBooleanContents, so use zero and one.
10326 EVT VT = Op.getValueType();
10327 SDValue TVal = DAG.getConstant(1, dl, VT);
10328 SDValue FVal = DAG.getConstant(0, dl, VT);
10329
10330 // Handle f128 first, since one possible outcome is a normal integer
10331 // comparison which gets picked up by the next if statement.
10332 if (LHS.getValueType() == MVT::f128) {
10333 softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS, Chain,
10334 IsSignaling);
10335
10336 // If softenSetCCOperands returned a scalar, use it.
10337 if (!RHS.getNode()) {
10338 assert(LHS.getValueType() == Op.getValueType() &&
10339 "Unexpected setcc expansion!");
10340 return IsStrict ? DAG.getMergeValues({LHS, Chain}, dl) : LHS;
10341 }
10342 }
10343
10344 if (LHS.getValueType().isInteger()) {
10345 SDValue CCVal;
10346 SDValue Cmp = getAArch64Cmp(
10347 LHS, RHS, ISD::getSetCCInverse(CC, LHS.getValueType()), CCVal, DAG, dl);
10348
10349 // Note that we inverted the condition above, so we reverse the order of
10350 // the true and false operands here. This will allow the setcc to be
10351 // matched to a single CSINC instruction.
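// For example, (i32)(a == b) becomes roughly:
//   cmp  w0, w1
//   cset w0, eq   // an alias of csinc w0, wzr, wzr, ne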
10352 SDValue Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CCVal, Cmp);
10353 return IsStrict ? DAG.getMergeValues({Res, Chain}, dl) : Res;
10354 }
10355
10356 // Now we know we're dealing with FP values.
10357 assert(LHS.getValueType() == MVT::bf16 || LHS.getValueType() == MVT::f16 ||
10358 LHS.getValueType() == MVT::f32 || LHS.getValueType() == MVT::f64);
10359
10360 // If that fails, we'll need to perform an FCMP + CSEL sequence. Go ahead
10361 // and do the comparison.
10362 SDValue Cmp;
10363 if (IsStrict)
10364 Cmp = emitStrictFPComparison(LHS, RHS, dl, DAG, Chain, IsSignaling);
10365 else
10366 Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
10367
10368 AArch64CC::CondCode CC1, CC2;
10369 changeFPCCToAArch64CC(CC, CC1, CC2);
10370 SDValue Res;
10371 if (CC2 == AArch64CC::AL) {
10372 changeFPCCToAArch64CC(ISD::getSetCCInverse(CC, LHS.getValueType()), CC1,
10373 CC2);
10374 SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10375
10376 // Note that we inverted the condition above, so we reverse the order of
10377 // the true and false operands here. This will allow the setcc to be
10378 // matched to a single CSINC instruction.
10379 Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, FVal, TVal, CC1Val, Cmp);
10380 } else {
10381 // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't
10382 // totally clean. Some of them require two CSELs to implement. As is in
10383 // this case, we emit the first CSEL and then emit a second using the output
10384 // of the first as the RHS. We're effectively OR'ing the two CC's together.
10385
10386 // FIXME: It would be nice if we could match the two CSELs to two CSINCs.
10387 SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10388 SDValue CS1 =
10389 DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
10390
10391 SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
10392 Res = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
10393 }
10394 return IsStrict ? DAG.getMergeValues({Res, Cmp.getValue(1)}, dl) : Res;
10395 }
10396
10397 SDValue AArch64TargetLowering::LowerSETCCCARRY(SDValue Op,
10398 SelectionDAG &DAG) const {
10399
10400 SDValue LHS = Op.getOperand(0);
10401 SDValue RHS = Op.getOperand(1);
10402 EVT VT = LHS.getValueType();
10403 if (VT != MVT::i32 && VT != MVT::i64)
10404 return SDValue();
10405
10406 SDLoc DL(Op);
10407 SDValue Carry = Op.getOperand(2);
10408 // SBCS uses a carry not a borrow so the carry flag should be inverted first.
10409 SDValue InvCarry = valueToCarryFlag(Carry, DAG, true);
10410 SDValue Cmp = DAG.getNode(AArch64ISD::SBCS, DL, DAG.getVTList(VT, MVT::Glue),
10411 LHS, RHS, InvCarry);
10412
10413 EVT OpVT = Op.getValueType();
10414 SDValue TVal = DAG.getConstant(1, DL, OpVT);
10415 SDValue FVal = DAG.getConstant(0, DL, OpVT);
10416
10417 ISD::CondCode Cond = cast<CondCodeSDNode>(Op.getOperand(3))->get();
10418 ISD::CondCode CondInv = ISD::getSetCCInverse(Cond, VT);
10419 SDValue CCVal =
10420 DAG.getConstant(changeIntCCToAArch64CC(CondInv), DL, MVT::i32);
10421 // Inputs are swapped because the condition is inverted. This will allow
10422 // matching with a single CSINC instruction.
10423 return DAG.getNode(AArch64ISD::CSEL, DL, OpVT, FVal, TVal, CCVal,
10424 Cmp.getValue(1));
10425 }
10426
10427 SDValue AArch64TargetLowering::LowerSELECT_CC(ISD::CondCode CC, SDValue LHS,
10428 SDValue RHS, SDValue TVal,
10429 SDValue FVal, const SDLoc &dl,
10430 SelectionDAG &DAG) const {
10431 // Handle f128 first, because it will result in a comparison of some RTLIB
10432 // call result against zero.
10433 if (LHS.getValueType() == MVT::f128) {
10434 softenSetCCOperands(DAG, MVT::f128, LHS, RHS, CC, dl, LHS, RHS);
10435
10436 // If softenSetCCOperands returned a scalar, we need to compare the result
10437 // against zero to select between true and false values.
10438 if (!RHS.getNode()) {
10439 RHS = DAG.getConstant(0, dl, LHS.getValueType());
10440 CC = ISD::SETNE;
10441 }
10442 }
10443
10444 // Also handle f16 and bf16, for which we need to do an f32 comparison.
10445 if ((LHS.getValueType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
10446 LHS.getValueType() == MVT::bf16) {
10447 LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, LHS);
10448 RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::f32, RHS);
10449 }
10450
10451 // Next, handle integers.
10452 if (LHS.getValueType().isInteger()) {
10453 assert((LHS.getValueType() == RHS.getValueType()) &&
10454 (LHS.getValueType() == MVT::i32 || LHS.getValueType() == MVT::i64));
10455
10456 ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(FVal);
10457 ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(TVal);
10458 ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(RHS);
10459 // Check for sign pattern (SELECT_CC setgt, iN lhs, -1, 1, -1) and transform
10460 // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
10461 // supported types.
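// For example, for i32 this becomes roughly:
//   asr w8, w0, #31
//   orr w0, w8, #1   // 1 if w0 >= 0, -1 otherwise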
10462 if (CC == ISD::SETGT && RHSC && RHSC->isAllOnes() && CTVal && CFVal &&
10463 CTVal->isOne() && CFVal->isAllOnes() &&
10464 LHS.getValueType() == TVal.getValueType()) {
10465 EVT VT = LHS.getValueType();
10466 SDValue Shift =
10467 DAG.getNode(ISD::SRA, dl, VT, LHS,
10468 DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
10469 return DAG.getNode(ISD::OR, dl, VT, Shift, DAG.getConstant(1, dl, VT));
10470 }
10471
10472 // Check for SMAX(lhs, 0) and SMIN(lhs, 0) patterns.
10473 // (SELECT_CC setgt, lhs, 0, lhs, 0) -> (BIC lhs, (SRA lhs, typesize-1))
10474 // (SELECT_CC setlt, lhs, 0, lhs, 0) -> (AND lhs, (SRA lhs, typesize-1))
10475 // Both require fewer instructions than a compare and conditional select.
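// For example, smax(w0, 0) becomes roughly:
//   asr w8, w0, #31
//   bic w0, w0, w8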
10476 if ((CC == ISD::SETGT || CC == ISD::SETLT) && LHS == TVal &&
10477 RHSC && RHSC->isZero() && CFVal && CFVal->isZero() &&
10478 LHS.getValueType() == RHS.getValueType()) {
10479 EVT VT = LHS.getValueType();
10480 SDValue Shift =
10481 DAG.getNode(ISD::SRA, dl, VT, LHS,
10482 DAG.getConstant(VT.getSizeInBits() - 1, dl, VT));
10483
10484 if (CC == ISD::SETGT)
10485 Shift = DAG.getNOT(dl, Shift, VT);
10486
10487 return DAG.getNode(ISD::AND, dl, VT, LHS, Shift);
10488 }
10489
10490 unsigned Opcode = AArch64ISD::CSEL;
10491
10492 // If both the TVal and the FVal are constants, see if we can swap them in
10493 // order to form a CSINV or CSINC out of them.
10494 if (CTVal && CFVal && CTVal->isAllOnes() && CFVal->isZero()) {
10495 std::swap(TVal, FVal);
10496 std::swap(CTVal, CFVal);
10497 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10498 } else if (CTVal && CFVal && CTVal->isOne() && CFVal->isZero()) {
10499 std::swap(TVal, FVal);
10500 std::swap(CTVal, CFVal);
10501 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10502 } else if (TVal.getOpcode() == ISD::XOR) {
10503 // If TVal is a NOT we want to swap TVal and FVal so that we can match
10504 // with a CSINV rather than a CSEL.
10505 if (isAllOnesConstant(TVal.getOperand(1))) {
10506 std::swap(TVal, FVal);
10507 std::swap(CTVal, CFVal);
10508 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10509 }
10510 } else if (TVal.getOpcode() == ISD::SUB) {
10511 // If TVal is a negation (SUB from 0) we want to swap TVal and FVal so
10512 // that we can match with a CSNEG rather than a CSEL.
10513 if (isNullConstant(TVal.getOperand(0))) {
10514 std::swap(TVal, FVal);
10515 std::swap(CTVal, CFVal);
10516 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10517 }
10518 } else if (CTVal && CFVal) {
10519 const int64_t TrueVal = CTVal->getSExtValue();
10520 const int64_t FalseVal = CFVal->getSExtValue();
10521 bool Swap = false;
10522
10523 // If both TVal and FVal are constants, see if FVal is the
10524 // inverse/negation/increment of TVal and generate a CSINV/CSNEG/CSINC
10525 // instead of a CSEL in that case.
10526 if (TrueVal == ~FalseVal) {
10527 Opcode = AArch64ISD::CSINV;
10528 } else if (FalseVal > std::numeric_limits<int64_t>::min() &&
10529 TrueVal == -FalseVal) {
10530 Opcode = AArch64ISD::CSNEG;
10531 } else if (TVal.getValueType() == MVT::i32) {
10532 // If our operands are only 32-bit wide, make sure we use 32-bit
10533 // arithmetic for the check whether we can use CSINC. This ensures that
10534 // the addition in the check will wrap around properly in case there is
10535 // an overflow (which would not be the case if we do the check with
10536 // 64-bit arithmetic).
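// For example, with i32 TVal = INT32_MAX and FVal = INT32_MIN,
// TrueVal32 + 1 == FalseVal32 holds, while the sign-extended 64-bit values
// differ by far more than 1, so a 64-bit check would miss this CSINC.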
10537 const uint32_t TrueVal32 = CTVal->getZExtValue();
10538 const uint32_t FalseVal32 = CFVal->getZExtValue();
10539
10540 if ((TrueVal32 == FalseVal32 + 1) || (TrueVal32 + 1 == FalseVal32)) {
10541 Opcode = AArch64ISD::CSINC;
10542
10543 if (TrueVal32 > FalseVal32) {
10544 Swap = true;
10545 }
10546 }
10547 } else {
10548 // 64-bit check whether we can use CSINC.
10549 const uint64_t TrueVal64 = TrueVal;
10550 const uint64_t FalseVal64 = FalseVal;
10551
10552 if ((TrueVal64 == FalseVal64 + 1) || (TrueVal64 + 1 == FalseVal64)) {
10553 Opcode = AArch64ISD::CSINC;
10554
10555 if (TrueVal > FalseVal) {
10556 Swap = true;
10557 }
10558 }
10559 }
10560
10561 // Swap TVal and FVal if necessary.
10562 if (Swap) {
10563 std::swap(TVal, FVal);
10564 std::swap(CTVal, CFVal);
10565 CC = ISD::getSetCCInverse(CC, LHS.getValueType());
10566 }
10567
10568 if (Opcode != AArch64ISD::CSEL) {
10569 // Drop FVal since we can get its value by simply inverting/negating
10570 // TVal.
10571 FVal = TVal;
10572 }
10573 }
10574
10575 // Avoid materializing a constant when possible by reusing a known value in
10576 // a register. However, don't perform this optimization if the known value
10577 // is one, zero or negative one in the case of a CSEL. We can always
10578 // materialize these values using CSINC, CSEL and CSINV with wzr/xzr as the
10579 // FVal, respectively.
10580 ConstantSDNode *RHSVal = dyn_cast<ConstantSDNode>(RHS);
10581 if (Opcode == AArch64ISD::CSEL && RHSVal && !RHSVal->isOne() &&
10582 !RHSVal->isZero() && !RHSVal->isAllOnes()) {
10583 AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
10584 // Transform "a == C ? C : x" to "a == C ? a : x" and "a != C ? x : C" to
10585 // "a != C ? x : a" to avoid materializing C.
10586 if (CTVal && CTVal == RHSVal && AArch64CC == AArch64CC::EQ)
10587 TVal = LHS;
10588 else if (CFVal && CFVal == RHSVal && AArch64CC == AArch64CC::NE)
10589 FVal = LHS;
10590 } else if (Opcode == AArch64ISD::CSNEG && RHSVal && RHSVal->isOne()) {
10591 assert (CTVal && CFVal && "Expected constant operands for CSNEG.");
10592 // Use a CSINV to transform "a == C ? 1 : -1" to "a == C ? a : -1" to
10593 // avoid materializing C.
10594 AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
10595 if (CTVal == RHSVal && AArch64CC == AArch64CC::EQ) {
10596 Opcode = AArch64ISD::CSINV;
10597 TVal = LHS;
10598 FVal = DAG.getConstant(0, dl, FVal.getValueType());
10599 }
10600 }
10601
10602 SDValue CCVal;
10603 SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
10604 EVT VT = TVal.getValueType();
10605 return DAG.getNode(Opcode, dl, VT, TVal, FVal, CCVal, Cmp);
10606 }
10607
10608 // Now we know we're dealing with FP values.
10609 assert(LHS.getValueType() == MVT::f16 || LHS.getValueType() == MVT::f32 ||
10610 LHS.getValueType() == MVT::f64);
10611 assert(LHS.getValueType() == RHS.getValueType());
10612 EVT VT = TVal.getValueType();
10613 SDValue Cmp = emitComparison(LHS, RHS, CC, dl, DAG);
10614
10615 // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
10616 // clean. Some of them require two CSELs to implement.
10617 AArch64CC::CondCode CC1, CC2;
10618 changeFPCCToAArch64CC(CC, CC1, CC2);
10619
10620 if (DAG.getTarget().Options.UnsafeFPMath) {
10621 // Transform "a == 0.0 ? 0.0 : x" to "a == 0.0 ? a : x" and
10622 // "a != 0.0 ? x : 0.0" to "a != 0.0 ? x : a" to avoid materializing 0.0.
10623 ConstantFPSDNode *RHSVal = dyn_cast<ConstantFPSDNode>(RHS);
10624 if (RHSVal && RHSVal->isZero()) {
10625 ConstantFPSDNode *CFVal = dyn_cast<ConstantFPSDNode>(FVal);
10626 ConstantFPSDNode *CTVal = dyn_cast<ConstantFPSDNode>(TVal);
10627
10628 if ((CC == ISD::SETEQ || CC == ISD::SETOEQ || CC == ISD::SETUEQ) &&
10629 CTVal && CTVal->isZero() && TVal.getValueType() == LHS.getValueType())
10630 TVal = LHS;
10631 else if ((CC == ISD::SETNE || CC == ISD::SETONE || CC == ISD::SETUNE) &&
10632 CFVal && CFVal->isZero() &&
10633 FVal.getValueType() == LHS.getValueType())
10634 FVal = LHS;
10635 }
10636 }
10637
10638 // Emit first, and possibly only, CSEL.
10639 SDValue CC1Val = DAG.getConstant(CC1, dl, MVT::i32);
10640 SDValue CS1 = DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, FVal, CC1Val, Cmp);
10641
10642 // If we need a second CSEL, emit it, using the output of the first as the
10643 // RHS. We're effectively OR'ing the two CC's together.
10644 if (CC2 != AArch64CC::AL) {
10645 SDValue CC2Val = DAG.getConstant(CC2, dl, MVT::i32);
10646 return DAG.getNode(AArch64ISD::CSEL, dl, VT, TVal, CS1, CC2Val, Cmp);
10647 }
10648
10649 // Otherwise, return the output of the first CSEL.
10650 return CS1;
10651 }
10652
10653 SDValue AArch64TargetLowering::LowerVECTOR_SPLICE(SDValue Op,
10654 SelectionDAG &DAG) const {
10655 EVT Ty = Op.getValueType();
10656 auto Idx = Op.getConstantOperandAPInt(2);
10657 int64_t IdxVal = Idx.getSExtValue();
10658 assert(Ty.isScalableVector() &&
10659 "Only expect scalable vectors for custom lowering of VECTOR_SPLICE");
10660
10661 // We can use the splice instruction for certain index values where we are
10662 // able to efficiently generate the correct predicate. The index will be
10663 // inverted and used directly as the input to the ptrue instruction, i.e.
10664 // -1 -> vl1, -2 -> vl2, etc. The predicate will then be reversed to get the
10665 // splice predicate. However, we can only do this if we can guarantee that
10666 // there are enough elements in the vector, hence we check the index <= min
10667 // number of elements.
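// For example, for nxv4i32 and IdxVal == -2 this produces roughly:
//   ptrue  p0.s, vl2
//   rev    p0.s, p0.s
//   splice z0.s, p0, z0.s, z1.s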
10668 std::optional<unsigned> PredPattern;
10669 if (Ty.isScalableVector() && IdxVal < 0 &&
10670 (PredPattern = getSVEPredPatternFromNumElements(std::abs(IdxVal))) !=
10671 std::nullopt) {
10672 SDLoc DL(Op);
10673
10674 // Create a predicate where all but the last -IdxVal elements are false.
10675 EVT PredVT = Ty.changeVectorElementType(MVT::i1);
10676 SDValue Pred = getPTrue(DAG, DL, PredVT, *PredPattern);
10677 Pred = DAG.getNode(ISD::VECTOR_REVERSE, DL, PredVT, Pred);
10678
10679 // Now splice the two inputs together using the predicate.
10680 return DAG.getNode(AArch64ISD::SPLICE, DL, Ty, Pred, Op.getOperand(0),
10681 Op.getOperand(1));
10682 }
10683
10684 // We can select to an EXT instruction when indexing the first 256 bytes.
10685 unsigned BlockSize = AArch64::SVEBitsPerBlock / Ty.getVectorMinNumElements();
10686 if (IdxVal >= 0 && (IdxVal * BlockSize / 8) < 256)
10687 return Op;
10688
10689 return SDValue();
10690 }
10691
10692 SDValue AArch64TargetLowering::LowerSELECT_CC(SDValue Op,
10693 SelectionDAG &DAG) const {
10694 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(4))->get();
10695 SDValue LHS = Op.getOperand(0);
10696 SDValue RHS = Op.getOperand(1);
10697 SDValue TVal = Op.getOperand(2);
10698 SDValue FVal = Op.getOperand(3);
10699 SDLoc DL(Op);
10700 return LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
10701 }
10702
10703 SDValue AArch64TargetLowering::LowerSELECT(SDValue Op,
10704 SelectionDAG &DAG) const {
10705 SDValue CCVal = Op->getOperand(0);
10706 SDValue TVal = Op->getOperand(1);
10707 SDValue FVal = Op->getOperand(2);
10708 SDLoc DL(Op);
10709
10710 EVT Ty = Op.getValueType();
10711 if (Ty == MVT::aarch64svcount) {
10712 TVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, TVal);
10713 FVal = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i1, FVal);
10714 SDValue Sel =
10715 DAG.getNode(ISD::SELECT, DL, MVT::nxv16i1, CCVal, TVal, FVal);
10716 return DAG.getNode(ISD::BITCAST, DL, Ty, Sel);
10717 }
10718
10719 if (Ty.isScalableVector()) {
10720 MVT PredVT = MVT::getVectorVT(MVT::i1, Ty.getVectorElementCount());
10721 SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, CCVal);
10722 return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
10723 }
10724
10725 if (useSVEForFixedLengthVectorVT(Ty, !Subtarget->isNeonAvailable())) {
10726 // FIXME: Ideally this would be the same as above using i1 types, however
10727 // for the moment we can't deal with fixed i1 vector types properly, so
10728 // instead extend the predicate to a result type sized integer vector.
10729 MVT SplatValVT = MVT::getIntegerVT(Ty.getScalarSizeInBits());
10730 MVT PredVT = MVT::getVectorVT(SplatValVT, Ty.getVectorElementCount());
10731 SDValue SplatVal = DAG.getSExtOrTrunc(CCVal, DL, SplatValVT);
10732 SDValue SplatPred = DAG.getNode(ISD::SPLAT_VECTOR, DL, PredVT, SplatVal);
10733 return DAG.getNode(ISD::VSELECT, DL, Ty, SplatPred, TVal, FVal);
10734 }
10735
10736 // Optimize {s|u}{add|sub|mul}.with.overflow feeding into a select
10737 // instruction.
10738 if (ISD::isOverflowIntrOpRes(CCVal)) {
10739 // Only lower legal XALUO ops.
10740 if (!DAG.getTargetLoweringInfo().isTypeLegal(CCVal->getValueType(0)))
10741 return SDValue();
10742
10743 AArch64CC::CondCode OFCC;
10744 SDValue Value, Overflow;
10745 std::tie(Value, Overflow) = getAArch64XALUOOp(OFCC, CCVal.getValue(0), DAG);
10746 SDValue CCVal = DAG.getConstant(OFCC, DL, MVT::i32);
10747
10748 return DAG.getNode(AArch64ISD::CSEL, DL, Op.getValueType(), TVal, FVal,
10749 CCVal, Overflow);
10750 }
10751
10752 // Lower it the same way as we would lower a SELECT_CC node.
10753 ISD::CondCode CC;
10754 SDValue LHS, RHS;
10755 if (CCVal.getOpcode() == ISD::SETCC) {
10756 LHS = CCVal.getOperand(0);
10757 RHS = CCVal.getOperand(1);
10758 CC = cast<CondCodeSDNode>(CCVal.getOperand(2))->get();
10759 } else {
10760 LHS = CCVal;
10761 RHS = DAG.getConstant(0, DL, CCVal.getValueType());
10762 CC = ISD::SETNE;
10763 }
10764
10765 // If we are lowering an f16 or bf16 and we do not have full fp16 support,
10766 // convert to f32 in order to use FCSELSrrr.
10767 if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
10768 TVal = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
10769 DAG.getUNDEF(MVT::f32), TVal);
10770 FVal = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
10771 DAG.getUNDEF(MVT::f32), FVal);
10772 }
10773
10774 SDValue Res = LowerSELECT_CC(CC, LHS, RHS, TVal, FVal, DL, DAG);
10775
10776 if ((Ty == MVT::f16 || Ty == MVT::bf16) && !Subtarget->hasFullFP16()) {
10777 return DAG.getTargetExtractSubreg(AArch64::hsub, DL, Ty, Res);
10778 }
10779
10780 return Res;
10781 }
10782
10783 SDValue AArch64TargetLowering::LowerJumpTable(SDValue Op,
10784 SelectionDAG &DAG) const {
10785 // Jump table entries are emitted as PC-relative offsets. No additional
10786 // tweaking is necessary here. Just get the address of the jump table.
10787 JumpTableSDNode *JT = cast<JumpTableSDNode>(Op);
10788
10789 CodeModel::Model CM = getTargetMachine().getCodeModel();
10790 if (CM == CodeModel::Large && !getTargetMachine().isPositionIndependent() &&
10791 !Subtarget->isTargetMachO())
10792 return getAddrLarge(JT, DAG);
10793 if (CM == CodeModel::Tiny)
10794 return getAddrTiny(JT, DAG);
10795 return getAddr(JT, DAG);
10796 }
10797
10798 SDValue AArch64TargetLowering::LowerBR_JT(SDValue Op,
10799 SelectionDAG &DAG) const {
10800 // Jump table entries are emitted as PC-relative offsets. No additional
10801 // tweaking is necessary here. Just get the address of the jump table.
10802 SDLoc DL(Op);
10803 SDValue JT = Op.getOperand(1);
10804 SDValue Entry = Op.getOperand(2);
10805 int JTI = cast<JumpTableSDNode>(JT.getNode())->getIndex();
10806
10807 auto *AFI = DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
10808 AFI->setJumpTableEntryInfo(JTI, 4, nullptr);
10809
10810 // With aarch64-jump-table-hardening, we only expand the jump table dispatch
10811 // sequence later, to guarantee the integrity of the intermediate values.
10812 if (DAG.getMachineFunction().getFunction().hasFnAttribute(
10813 "aarch64-jump-table-hardening")) {
10814 CodeModel::Model CM = getTargetMachine().getCodeModel();
10815 if (Subtarget->isTargetMachO()) {
10816 if (CM != CodeModel::Small && CM != CodeModel::Large)
10817 report_fatal_error("Unsupported code-model for hardened jump-table");
10818 } else {
10819 // Note that COFF support would likely also need JUMP_TABLE_DEBUG_INFO.
10820 assert(Subtarget->isTargetELF() &&
10821 "jump table hardening only supported on MachO/ELF");
10822 if (CM != CodeModel::Small)
10823 report_fatal_error("Unsupported code-model for hardened jump-table");
10824 }
10825
10826 SDValue X16Copy = DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::X16,
10827 Entry, SDValue());
10828 SDNode *B = DAG.getMachineNode(AArch64::BR_JumpTable, DL, MVT::Other,
10829 DAG.getTargetJumpTable(JTI, MVT::i32),
10830 X16Copy.getValue(0), X16Copy.getValue(1));
10831 return SDValue(B, 0);
10832 }
10833
10834 SDNode *Dest =
10835 DAG.getMachineNode(AArch64::JumpTableDest32, DL, MVT::i64, MVT::i64, JT,
10836 Entry, DAG.getTargetJumpTable(JTI, MVT::i32));
10837 SDValue JTInfo = DAG.getJumpTableDebugInfo(JTI, Op.getOperand(0), DL);
10838 return DAG.getNode(ISD::BRIND, DL, MVT::Other, JTInfo, SDValue(Dest, 0));
10839 }
10840
10841 SDValue AArch64TargetLowering::LowerBRIND(SDValue Op, SelectionDAG &DAG) const {
10842 SDValue Chain = Op.getOperand(0);
10843 SDValue Dest = Op.getOperand(1);
10844
10845 // BR_JT is lowered to BRIND, but the latter lowering is specific to indirectbr.
10846 // Skip over the jump-table BRINDs, where the destination is JumpTableDest32.
10847 if (Dest->isMachineOpcode() &&
10848 Dest->getMachineOpcode() == AArch64::JumpTableDest32)
10849 return SDValue();
10850
10851 const MachineFunction &MF = DAG.getMachineFunction();
10852 std::optional<uint16_t> BADisc =
10853 Subtarget->getPtrAuthBlockAddressDiscriminatorIfEnabled(MF.getFunction());
10854 if (!BADisc)
10855 return SDValue();
10856
10857 SDLoc DL(Op);
10858
10859 SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
10860 SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
10861 SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
10862
10863 SDNode *BrA = DAG.getMachineNode(AArch64::BRA, DL, MVT::Other,
10864 {Dest, Key, Disc, AddrDisc, Chain});
10865 return SDValue(BrA, 0);
10866 }
10867
10868 SDValue AArch64TargetLowering::LowerConstantPool(SDValue Op,
10869 SelectionDAG &DAG) const {
10870 ConstantPoolSDNode *CP = cast<ConstantPoolSDNode>(Op);
10871 CodeModel::Model CM = getTargetMachine().getCodeModel();
10872 if (CM == CodeModel::Large) {
10873 // Use the GOT for the large code model on iOS.
10874 if (Subtarget->isTargetMachO()) {
10875 return getGOT(CP, DAG);
10876 }
10877 if (!getTargetMachine().isPositionIndependent())
10878 return getAddrLarge(CP, DAG);
10879 } else if (CM == CodeModel::Tiny) {
10880 return getAddrTiny(CP, DAG);
10881 }
10882 return getAddr(CP, DAG);
10883 }
10884
10885 SDValue AArch64TargetLowering::LowerBlockAddress(SDValue Op,
10886 SelectionDAG &DAG) const {
10887 BlockAddressSDNode *BAN = cast<BlockAddressSDNode>(Op);
10888 const BlockAddress *BA = BAN->getBlockAddress();
10889
10890 if (std::optional<uint16_t> BADisc =
10891 Subtarget->getPtrAuthBlockAddressDiscriminatorIfEnabled(
10892 *BA->getFunction())) {
10893 SDLoc DL(Op);
10894
10895 // This isn't cheap, but BRIND is rare.
10896 SDValue TargetBA = DAG.getTargetBlockAddress(BA, BAN->getValueType(0));
10897
10898 SDValue Disc = DAG.getTargetConstant(*BADisc, DL, MVT::i64);
10899
10900 SDValue Key = DAG.getTargetConstant(AArch64PACKey::IA, DL, MVT::i32);
10901 SDValue AddrDisc = DAG.getRegister(AArch64::XZR, MVT::i64);
10902
10903 SDNode *MOV =
10904 DAG.getMachineNode(AArch64::MOVaddrPAC, DL, {MVT::Other, MVT::Glue},
10905 {TargetBA, Key, AddrDisc, Disc});
10906 return DAG.getCopyFromReg(SDValue(MOV, 0), DL, AArch64::X16, MVT::i64,
10907 SDValue(MOV, 1));
10908 }
10909
10910 CodeModel::Model CM = getTargetMachine().getCodeModel();
10911 if (CM == CodeModel::Large && !Subtarget->isTargetMachO()) {
10912 if (!getTargetMachine().isPositionIndependent())
10913 return getAddrLarge(BAN, DAG);
10914 } else if (CM == CodeModel::Tiny) {
10915 return getAddrTiny(BAN, DAG);
10916 }
10917 return getAddr(BAN, DAG);
10918 }
10919
10920 SDValue AArch64TargetLowering::LowerDarwin_VASTART(SDValue Op,
10921 SelectionDAG &DAG) const {
10922 AArch64FunctionInfo *FuncInfo =
10923 DAG.getMachineFunction().getInfo<AArch64FunctionInfo>();
10924
10925 SDLoc DL(Op);
10926 SDValue FR = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(),
10927 getPointerTy(DAG.getDataLayout()));
10928 FR = DAG.getZExtOrTrunc(FR, DL, getPointerMemTy(DAG.getDataLayout()));
10929 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10930 return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
10931 MachinePointerInfo(SV));
10932 }
10933
10934 SDValue AArch64TargetLowering::LowerWin64_VASTART(SDValue Op,
10935 SelectionDAG &DAG) const {
10936 MachineFunction &MF = DAG.getMachineFunction();
10937 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
10938
10939 SDLoc DL(Op);
10940 SDValue FR;
10941 if (Subtarget->isWindowsArm64EC()) {
10942 // With the Arm64EC ABI, we compute the address of the varargs save area
10943 // relative to x4. For a normal AArch64->AArch64 call, x4 == sp on entry,
10944 // but calls from an entry thunk can pass in a different address.
10945 Register VReg = MF.addLiveIn(AArch64::X4, &AArch64::GPR64RegClass);
10946 SDValue Val = DAG.getCopyFromReg(DAG.getEntryNode(), DL, VReg, MVT::i64);
10947 uint64_t StackOffset;
10948 if (FuncInfo->getVarArgsGPRSize() > 0)
10949 StackOffset = -(uint64_t)FuncInfo->getVarArgsGPRSize();
10950 else
10951 StackOffset = FuncInfo->getVarArgsStackOffset();
10952 FR = DAG.getNode(ISD::ADD, DL, MVT::i64, Val,
10953 DAG.getConstant(StackOffset, DL, MVT::i64));
10954 } else {
10955 FR = DAG.getFrameIndex(FuncInfo->getVarArgsGPRSize() > 0
10956 ? FuncInfo->getVarArgsGPRIndex()
10957 : FuncInfo->getVarArgsStackIndex(),
10958 getPointerTy(DAG.getDataLayout()));
10959 }
10960 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10961 return DAG.getStore(Op.getOperand(0), DL, FR, Op.getOperand(1),
10962 MachinePointerInfo(SV));
10963 }
10964
10965 SDValue AArch64TargetLowering::LowerAAPCS_VASTART(SDValue Op,
10966 SelectionDAG &DAG) const {
10967 // The layout of the va_list struct is specified in the AArch64 Procedure Call
10968 // Standard, section B.3.
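// A sketch of the LP64 layout materialized below (field names and offsets
// match the stores that follow):
//   struct va_list {
//     void *__stack;   // offset 0:  next stacked argument
//     void *__gr_top;  // offset 8:  end of the GP register save area
//     void *__vr_top;  // offset 16: end of the FP/SIMD register save area
//     int   __gr_offs; // offset 24: negative size of the GP save area
//     int   __vr_offs; // offset 28: negative size of the FP/SIMD save area
//   };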
10969 MachineFunction &MF = DAG.getMachineFunction();
10970 AArch64FunctionInfo *FuncInfo = MF.getInfo<AArch64FunctionInfo>();
10971 unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
10972 auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
10973 auto PtrVT = getPointerTy(DAG.getDataLayout());
10974 SDLoc DL(Op);
10975
10976 SDValue Chain = Op.getOperand(0);
10977 SDValue VAList = Op.getOperand(1);
10978 const Value *SV = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
10979 SmallVector<SDValue, 4> MemOps;
10980
10981 // void *__stack at offset 0
10982 unsigned Offset = 0;
10983 SDValue Stack = DAG.getFrameIndex(FuncInfo->getVarArgsStackIndex(), PtrVT);
10984 Stack = DAG.getZExtOrTrunc(Stack, DL, PtrMemVT);
10985 MemOps.push_back(DAG.getStore(Chain, DL, Stack, VAList,
10986 MachinePointerInfo(SV), Align(PtrSize)));
10987
10988 // void *__gr_top at offset 8 (4 on ILP32)
10989 Offset += PtrSize;
10990 int GPRSize = FuncInfo->getVarArgsGPRSize();
10991 if (GPRSize > 0) {
10992 SDValue GRTop, GRTopAddr;
10993
10994 GRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
10995 DAG.getConstant(Offset, DL, PtrVT));
10996
10997 GRTop = DAG.getFrameIndex(FuncInfo->getVarArgsGPRIndex(), PtrVT);
10998 GRTop = DAG.getNode(ISD::ADD, DL, PtrVT, GRTop,
10999 DAG.getConstant(GPRSize, DL, PtrVT));
11000 GRTop = DAG.getZExtOrTrunc(GRTop, DL, PtrMemVT);
11001
11002 MemOps.push_back(DAG.getStore(Chain, DL, GRTop, GRTopAddr,
11003 MachinePointerInfo(SV, Offset),
11004 Align(PtrSize)));
11005 }
11006
11007 // void *__vr_top at offset 16 (8 on ILP32)
11008 Offset += PtrSize;
11009 int FPRSize = FuncInfo->getVarArgsFPRSize();
11010 if (FPRSize > 0) {
11011 SDValue VRTop, VRTopAddr;
11012 VRTopAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11013 DAG.getConstant(Offset, DL, PtrVT));
11014
11015 VRTop = DAG.getFrameIndex(FuncInfo->getVarArgsFPRIndex(), PtrVT);
11016 VRTop = DAG.getNode(ISD::ADD, DL, PtrVT, VRTop,
11017 DAG.getConstant(FPRSize, DL, PtrVT));
11018 VRTop = DAG.getZExtOrTrunc(VRTop, DL, PtrMemVT);
11019
11020 MemOps.push_back(DAG.getStore(Chain, DL, VRTop, VRTopAddr,
11021 MachinePointerInfo(SV, Offset),
11022 Align(PtrSize)));
11023 }
11024
11025 // int __gr_offs at offset 24 (12 on ILP32)
11026 Offset += PtrSize;
11027 SDValue GROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11028 DAG.getConstant(Offset, DL, PtrVT));
11029 MemOps.push_back(
11030 DAG.getStore(Chain, DL, DAG.getConstant(-GPRSize, DL, MVT::i32),
11031 GROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
11032
11033 // int __vr_offs at offset 28 (16 on ILP32)
11034 Offset += 4;
11035 SDValue VROffsAddr = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11036 DAG.getConstant(Offset, DL, PtrVT));
11037 MemOps.push_back(
11038 DAG.getStore(Chain, DL, DAG.getConstant(-FPRSize, DL, MVT::i32),
11039 VROffsAddr, MachinePointerInfo(SV, Offset), Align(4)));
11040
11041 return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, MemOps);
11042 }
11043
11044 SDValue AArch64TargetLowering::LowerVASTART(SDValue Op,
11045 SelectionDAG &DAG) const {
11046 MachineFunction &MF = DAG.getMachineFunction();
11047 Function &F = MF.getFunction();
11048
11049 if (Subtarget->isCallingConvWin64(F.getCallingConv(), F.isVarArg()))
11050 return LowerWin64_VASTART(Op, DAG);
11051 else if (Subtarget->isTargetDarwin())
11052 return LowerDarwin_VASTART(Op, DAG);
11053 else
11054 return LowerAAPCS_VASTART(Op, DAG);
11055 }
11056
11057 SDValue AArch64TargetLowering::LowerVACOPY(SDValue Op,
11058 SelectionDAG &DAG) const {
11059 // The AAPCS va_list has three pointers and two ints (32 bytes on LP64, 20 on
11060 // ILP32); Darwin and Windows use a single pointer.
11061 SDLoc DL(Op);
11062 unsigned PtrSize = Subtarget->isTargetILP32() ? 4 : 8;
11063 unsigned VaListSize =
11064 (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
11065 ? PtrSize
11066 : Subtarget->isTargetILP32() ? 20 : 32;
11067 const Value *DestSV = cast<SrcValueSDNode>(Op.getOperand(3))->getValue();
11068 const Value *SrcSV = cast<SrcValueSDNode>(Op.getOperand(4))->getValue();
11069
11070 return DAG.getMemcpy(Op.getOperand(0), DL, Op.getOperand(1), Op.getOperand(2),
11071 DAG.getConstant(VaListSize, DL, MVT::i32),
11072 Align(PtrSize), false, false, /*CI=*/nullptr,
11073 std::nullopt, MachinePointerInfo(DestSV),
11074 MachinePointerInfo(SrcSV));
11075 }
11076
11077 SDValue AArch64TargetLowering::LowerVAARG(SDValue Op, SelectionDAG &DAG) const {
11078 assert(Subtarget->isTargetDarwin() &&
11079 "automatic va_arg instruction only works on Darwin");
11080
11081 const Value *V = cast<SrcValueSDNode>(Op.getOperand(2))->getValue();
11082 EVT VT = Op.getValueType();
11083 SDLoc DL(Op);
11084 SDValue Chain = Op.getOperand(0);
11085 SDValue Addr = Op.getOperand(1);
11086 MaybeAlign Align(Op.getConstantOperandVal(3));
11087 unsigned MinSlotSize = Subtarget->isTargetILP32() ? 4 : 8;
11088 auto PtrVT = getPointerTy(DAG.getDataLayout());
11089 auto PtrMemVT = getPointerMemTy(DAG.getDataLayout());
11090 SDValue VAList =
11091 DAG.getLoad(PtrMemVT, DL, Chain, Addr, MachinePointerInfo(V));
11092 Chain = VAList.getValue(1);
11093 VAList = DAG.getZExtOrTrunc(VAList, DL, PtrVT);
11094
11095 if (VT.isScalableVector())
11096 report_fatal_error("Passing SVE types to variadic functions is "
11097 "currently not supported");
11098
11099 if (Align && *Align > MinSlotSize) {
11100 VAList = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11101 DAG.getConstant(Align->value() - 1, DL, PtrVT));
11102 VAList = DAG.getNode(ISD::AND, DL, PtrVT, VAList,
11103 DAG.getConstant(-(int64_t)Align->value(), DL, PtrVT));
11104 }
11105
11106 Type *ArgTy = VT.getTypeForEVT(*DAG.getContext());
11107 unsigned ArgSize = DAG.getDataLayout().getTypeAllocSize(ArgTy);
11108
11109 // Scalar integer and FP values smaller than 64 bits are implicitly extended
11110 // up to 64 bits. At the very least, we have to increase the striding of the
11111 // vaargs list to match this, and for FP values we need to introduce
11112 // FP_ROUND nodes as well.
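// For example, va_arg(ap, float) on Darwin consumes a full 8-byte slot: the
// value is loaded as an f64 and narrowed back to f32 through the NeedFPTrunc
// path below.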
11113 if (VT.isInteger() && !VT.isVector())
11114 ArgSize = std::max(ArgSize, MinSlotSize);
11115 bool NeedFPTrunc = false;
11116 if (VT.isFloatingPoint() && !VT.isVector() && VT != MVT::f64) {
11117 ArgSize = 8;
11118 NeedFPTrunc = true;
11119 }
11120
11121 // Increment the pointer, VAList, to the next vaarg
11122 SDValue VANext = DAG.getNode(ISD::ADD, DL, PtrVT, VAList,
11123 DAG.getConstant(ArgSize, DL, PtrVT));
11124 VANext = DAG.getZExtOrTrunc(VANext, DL, PtrMemVT);
11125
11126 // Store the incremented VAList to the legalized pointer
11127 SDValue APStore =
11128 DAG.getStore(Chain, DL, VANext, Addr, MachinePointerInfo(V));
11129
11130 // Load the actual argument out of the pointer VAList
11131 if (NeedFPTrunc) {
11132 // Load the value as an f64.
11133 SDValue WideFP =
11134 DAG.getLoad(MVT::f64, DL, APStore, VAList, MachinePointerInfo());
11135 // Round the value down to an f32.
11136 SDValue NarrowFP =
11137 DAG.getNode(ISD::FP_ROUND, DL, VT, WideFP.getValue(0),
11138 DAG.getIntPtrConstant(1, DL, /*isTarget=*/true));
11139 SDValue Ops[] = { NarrowFP, WideFP.getValue(1) };
11140 // Merge the rounded value with the chain output of the load.
11141 return DAG.getMergeValues(Ops, DL);
11142 }
11143
11144 return DAG.getLoad(VT, DL, APStore, VAList, MachinePointerInfo());
11145 }
11146
11147 SDValue AArch64TargetLowering::LowerFRAMEADDR(SDValue Op,
11148 SelectionDAG &DAG) const {
11149 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
11150 MFI.setFrameAddressIsTaken(true);
11151
11152 EVT VT = Op.getValueType();
11153 SDLoc DL(Op);
11154 unsigned Depth = Op.getConstantOperandVal(0);
11155 SDValue FrameAddr =
11156 DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, MVT::i64);
11157 while (Depth--)
11158 FrameAddr = DAG.getLoad(VT, DL, DAG.getEntryNode(), FrameAddr,
11159 MachinePointerInfo());
11160
11161 if (Subtarget->isTargetILP32())
11162 FrameAddr = DAG.getNode(ISD::AssertZext, DL, MVT::i64, FrameAddr,
11163 DAG.getValueType(VT));
11164
11165 return FrameAddr;
11166 }
11167
11168 SDValue AArch64TargetLowering::LowerSPONENTRY(SDValue Op,
11169 SelectionDAG &DAG) const {
11170 MachineFrameInfo &MFI = DAG.getMachineFunction().getFrameInfo();
11171
11172 EVT VT = getPointerTy(DAG.getDataLayout());
11173 SDLoc DL(Op);
11174 int FI = MFI.CreateFixedObject(4, 0, false);
11175 return DAG.getFrameIndex(FI, VT);
11176 }
11177
11178 #define GET_REGISTER_MATCHER
11179 #include "AArch64GenAsmMatcher.inc"
11180
11181 // FIXME? Maybe this could be a TableGen attribute on some registers and
11182 // this table could be generated automatically from RegInfo.
11183 Register AArch64TargetLowering::
11184 getRegisterByName(const char* RegName, LLT VT, const MachineFunction &MF) const {
11185 Register Reg = MatchRegisterName(RegName);
11186 if (AArch64::X1 <= Reg && Reg <= AArch64::X28) {
11187 const AArch64RegisterInfo *MRI = Subtarget->getRegisterInfo();
11188 unsigned DwarfRegNum = MRI->getDwarfRegNum(Reg, false);
11189 if (!Subtarget->isXRegisterReserved(DwarfRegNum) &&
11190 !MRI->isReservedReg(MF, Reg))
11191 Reg = 0;
11192 }
11193 if (Reg)
11194 return Reg;
11195 report_fatal_error(Twine("Invalid register name \""
11196 + StringRef(RegName) + "\"."));
11197 }
11198
11199 SDValue AArch64TargetLowering::LowerADDROFRETURNADDR(SDValue Op,
11200 SelectionDAG &DAG) const {
11201 DAG.getMachineFunction().getFrameInfo().setFrameAddressIsTaken(true);
11202
11203 EVT VT = Op.getValueType();
11204 SDLoc DL(Op);
11205
11206 SDValue FrameAddr =
11207 DAG.getCopyFromReg(DAG.getEntryNode(), DL, AArch64::FP, VT);
11208 SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
11209
11210 return DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset);
11211 }
11212
11213 SDValue AArch64TargetLowering::LowerRETURNADDR(SDValue Op,
11214 SelectionDAG &DAG) const {
11215 MachineFunction &MF = DAG.getMachineFunction();
11216 MachineFrameInfo &MFI = MF.getFrameInfo();
11217 MFI.setReturnAddressIsTaken(true);
11218
11219 EVT VT = Op.getValueType();
11220 SDLoc DL(Op);
11221 unsigned Depth = Op.getConstantOperandVal(0);
11222 SDValue ReturnAddress;
11223 if (Depth) {
11224 SDValue FrameAddr = LowerFRAMEADDR(Op, DAG);
11225 SDValue Offset = DAG.getConstant(8, DL, getPointerTy(DAG.getDataLayout()));
11226 ReturnAddress = DAG.getLoad(
11227 VT, DL, DAG.getEntryNode(),
11228 DAG.getNode(ISD::ADD, DL, VT, FrameAddr, Offset), MachinePointerInfo());
11229 } else {
11230 // Return LR, which contains the return address. Mark it an implicit
11231 // live-in.
11232 Register Reg = MF.addLiveIn(AArch64::LR, &AArch64::GPR64RegClass);
11233 ReturnAddress = DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, VT);
11234 }
11235
11236 // The XPACLRI instruction assembles to a hint-space instruction before
11237 // Armv8.3-A, therefore it can be safely used on any pre-Armv8.3-A
11238 // architecture. On Armv8.3-A and onwards, XPACI is available, so use that
11239 // instead.
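// Either way, the effect is to strip any pointer authentication code from the
// return address before handing it back to the caller.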
11240 SDNode *St;
11241 if (Subtarget->hasPAuth()) {
11242 St = DAG.getMachineNode(AArch64::XPACI, DL, VT, ReturnAddress);
11243 } else {
11244 // XPACLRI operates on LR therefore we must move the operand accordingly.
11245 SDValue Chain =
11246 DAG.getCopyToReg(DAG.getEntryNode(), DL, AArch64::LR, ReturnAddress);
11247 St = DAG.getMachineNode(AArch64::XPACLRI, DL, VT, Chain);
11248 }
11249 return SDValue(St, 0);
11250 }
11251
11252 /// LowerShiftParts - Lower SHL_PARTS/SRA_PARTS/SRL_PARTS, which return two
11253 /// i64 values and take a 2 x i64 value to shift plus a shift amount.
11254 SDValue AArch64TargetLowering::LowerShiftParts(SDValue Op,
11255 SelectionDAG &DAG) const {
11256 SDValue Lo, Hi;
11257 expandShiftParts(Op.getNode(), Lo, Hi, DAG);
11258 return DAG.getMergeValues({Lo, Hi}, SDLoc(Op));
11259 }
11260
11261 bool AArch64TargetLowering::isOffsetFoldingLegal(
11262 const GlobalAddressSDNode *GA) const {
11263 // Offsets are folded in the DAG combine rather than here so that we can
11264 // intelligently choose an offset based on the uses.
11265 return false;
11266 }
11267
11268 bool AArch64TargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
11269 bool OptForSize) const {
11270 bool IsLegal = false;
11271 // We can materialize #0.0 as fmov $Rd, XZR for 64-bit, 32-bit cases, and
11272 // 16-bit case when target has full fp16 support.
11273 // We encode bf16 bit patterns as if they were fp16. This results in very
11274 // strange looking assembly but should populate the register with appropriate
11275 // values. Let's say we wanted to encode 0xR3FC0 which is 1.5 in BF16. We will
11276 // end up encoding this as the imm8 0x7f. This imm8 will be expanded to the
11277 // FP16 1.9375 which shares the same bit pattern as BF16 1.5.
11278 // FIXME: We should be able to handle f128 as well with a clever lowering.
11279 const APInt ImmInt = Imm.bitcastToAPInt();
11280 if (VT == MVT::f64)
11281 IsLegal = AArch64_AM::getFP64Imm(ImmInt) != -1 || Imm.isPosZero();
11282 else if (VT == MVT::f32)
11283 IsLegal = AArch64_AM::getFP32Imm(ImmInt) != -1 || Imm.isPosZero();
11284 else if (VT == MVT::f16 || VT == MVT::bf16)
11285 IsLegal =
11286 (Subtarget->hasFullFP16() && AArch64_AM::getFP16Imm(ImmInt) != -1) ||
11287 Imm.isPosZero();
11288
11289 // If we cannot materialize the value in the fmov immediate field, check if
11290 // the value can be encoded as the immediate operand of a logical instruction.
11291 // The immediate value will be created with either MOVZ, MOVN, or ORR.
11292 // TODO: fmov h0, w0 is also legal, however we don't have an isel pattern to
11293 // generate that fmov.
11294 if (!IsLegal && (VT == MVT::f64 || VT == MVT::f32)) {
11295 // The cost is actually exactly the same for mov+fmov vs. adrp+ldr;
11296 // however the mov+fmov sequence is always better because of the reduced
11297 // cache pressure. The timings are still the same if you consider
11298 // movw+movk+fmov vs. adrp+ldr (it's one instruction longer, but the
11299 // movw+movk is fused). So we limit to at most 2 instructions.
11300 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
11301 AArch64_IMM::expandMOVImm(ImmInt.getZExtValue(), VT.getSizeInBits(), Insn);
11302 unsigned Limit = (OptForSize ? 1 : (Subtarget->hasFuseLiterals() ? 5 : 2));
11303 IsLegal = Insn.size() <= Limit;
11304 }
11305
11306 LLVM_DEBUG(dbgs() << (IsLegal ? "Legal " : "Illegal ") << VT
11307 << " imm value: "; Imm.dump(););
11308 return IsLegal;
11309 }
11310
11311 //===----------------------------------------------------------------------===//
11312 // AArch64 Optimization Hooks
11313 //===----------------------------------------------------------------------===//
11314
11315 static SDValue getEstimate(const AArch64Subtarget *ST, unsigned Opcode,
11316 SDValue Operand, SelectionDAG &DAG,
11317 int &ExtraSteps) {
11318 EVT VT = Operand.getValueType();
11319 if ((ST->hasNEON() &&
11320 (VT == MVT::f64 || VT == MVT::v1f64 || VT == MVT::v2f64 ||
11321 VT == MVT::f32 || VT == MVT::v1f32 || VT == MVT::v2f32 ||
11322 VT == MVT::v4f32)) ||
11323 (ST->hasSVE() &&
11324 (VT == MVT::nxv8f16 || VT == MVT::nxv4f32 || VT == MVT::nxv2f64))) {
11325 if (ExtraSteps == TargetLoweringBase::ReciprocalEstimate::Unspecified) {
11326 // For the reciprocal estimates, convergence is quadratic, so the number
11327 // of digits is doubled after each iteration. In ARMv8, the accuracy of
11328 // the initial estimate is 2^-8. Thus the number of extra steps to refine
11329 // the result for float (23 mantissa bits) is 2 and for double (52
11330 // mantissa bits) is 3.
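// Worked through the formula below: semanticsPrecision(f32) = 24 gives
// ceil(log2(24)) - ceil(log2(8)) = 5 - 3 = 2 steps, and
// semanticsPrecision(f64) = 53 gives 6 - 3 = 3 steps.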
11331 constexpr unsigned AccurateBits = 8;
11332 unsigned DesiredBits =
11333 APFloat::semanticsPrecision(DAG.EVTToAPFloatSemantics(VT));
11334 ExtraSteps = DesiredBits <= AccurateBits
11335 ? 0
11336 : Log2_64_Ceil(DesiredBits) - Log2_64_Ceil(AccurateBits);
11337 }
11338
11339 return DAG.getNode(Opcode, SDLoc(Operand), VT, Operand);
11340 }
11341
11342 return SDValue();
11343 }
11344
11345 SDValue
11346 AArch64TargetLowering::getSqrtInputTest(SDValue Op, SelectionDAG &DAG,
11347 const DenormalMode &Mode) const {
11348 SDLoc DL(Op);
11349 EVT VT = Op.getValueType();
11350 EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
11351 SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
11352 return DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
11353 }
11354
11355 SDValue
11356 AArch64TargetLowering::getSqrtResultForDenormInput(SDValue Op,
11357 SelectionDAG &DAG) const {
11358 return Op;
11359 }
11360
11361 SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
11362 SelectionDAG &DAG, int Enabled,
11363 int &ExtraSteps,
11364 bool &UseOneConst,
11365 bool Reciprocal) const {
11366 if (Enabled == ReciprocalEstimate::Enabled ||
11367 (Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
11368 if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
11369 DAG, ExtraSteps)) {
11370 SDLoc DL(Operand);
11371 EVT VT = Operand.getValueType();
11372
11373 SDNodeFlags Flags;
11374 Flags.setAllowReassociation(true);
11375
11376 // Newton reciprocal square root iteration: E * 0.5 * (3 - X * E^2)
11377 // AArch64 reciprocal square root iteration instruction: 0.5 * (3 - M * N)
11378 for (int i = ExtraSteps; i > 0; --i) {
11379 SDValue Step = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Estimate,
11380 Flags);
11381 Step = DAG.getNode(AArch64ISD::FRSQRTS, DL, VT, Operand, Step, Flags);
11382 Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
11383 }
11384 if (!Reciprocal)
11385 Estimate = DAG.getNode(ISD::FMUL, DL, VT, Operand, Estimate, Flags);
11386
11387 ExtraSteps = 0;
11388 return Estimate;
11389 }
11390
11391 return SDValue();
11392 }
11393
11394 SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand,
11395 SelectionDAG &DAG, int Enabled,
11396 int &ExtraSteps) const {
11397 if (Enabled == ReciprocalEstimate::Enabled)
11398 if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRECPE, Operand,
11399 DAG, ExtraSteps)) {
11400 SDLoc DL(Operand);
11401 EVT VT = Operand.getValueType();
11402
11403 SDNodeFlags Flags;
11404 Flags.setAllowReassociation(true);
11405
11406 // Newton reciprocal iteration: E * (2 - X * E)
11407 // AArch64 reciprocal iteration instruction: (2 - M * N)
11408 for (int i = ExtraSteps; i > 0; --i) {
11409 SDValue Step = DAG.getNode(AArch64ISD::FRECPS, DL, VT, Operand,
11410 Estimate, Flags);
11411 Estimate = DAG.getNode(ISD::FMUL, DL, VT, Estimate, Step, Flags);
11412 }
11413
11414 ExtraSteps = 0;
11415 return Estimate;
11416 }
11417
11418 return SDValue();
11419 }
11420
11421 //===----------------------------------------------------------------------===//
11422 // AArch64 Inline Assembly Support
11423 //===----------------------------------------------------------------------===//
11424
11425 // Table of Constraints
11426 // TODO: This is the current set of constraints supported by ARM for the
11427 // compiler, not all of them may make sense.
11428 //
11429 // r - A general register
11430 // w - An FP/SIMD register of some size in the range v0-v31
11431 // x - An FP/SIMD register of some size in the range v0-v15
11432 // I - Constant that can be used with an ADD instruction
11433 // J - Constant that can be used with a SUB instruction
11434 // K - Constant that can be used with a 32-bit logical instruction
11435 // L - Constant that can be used with a 64-bit logical instruction
11436 // M - Constant that can be used as a 32-bit MOV immediate
11437 // N - Constant that can be used as a 64-bit MOV immediate
11438 // Q - A memory reference with base register and no offset
11439 // S - A symbolic address
11440 // Y - Floating point constant zero
11441 // Z - Integer constant zero
11442 //
11443 // Note that general register operands will be output using their 64-bit x
11444 // register name, whatever the size of the variable, unless the asm operand
11445 // is prefixed by the %w modifier. Floating-point and SIMD register operands
11446 // will be output with the v prefix unless prefixed by the %b, %h, %s, %d or
11447 // %q modifier.
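//
// Illustrative use of the %w modifier (standard GNU extended asm syntax, not
// taken from this file):
//   int res, a, b;
//   asm("add %w0, %w1, %w2" : "=r"(res) : "r"(a), "r"(b));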
11448 const char *AArch64TargetLowering::LowerXConstraint(EVT ConstraintVT) const {
11449 // At this point, we have to lower this constraint to something else, so we
11450 // lower it to an "r" or "w". However, by doing this we will force the result
11451 // to be in register, while the X constraint is much more permissive.
11452 //
11453 // Although we are correct (we are free to emit anything, without
11454 // constraints), we might break use cases that would expect us to be more
11455 // efficient and emit something else.
11456 if (!Subtarget->hasFPARMv8())
11457 return "r";
11458
11459 if (ConstraintVT.isFloatingPoint())
11460 return "w";
11461
11462 if (ConstraintVT.isVector() &&
11463 (ConstraintVT.getSizeInBits() == 64 ||
11464 ConstraintVT.getSizeInBits() == 128))
11465 return "w";
11466
11467 return "r";
11468 }
11469
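// The SVE predicate constraints below select progressively narrower register
// classes: 'Upa' allows any predicate register (p0-p15), 'Upl' a low predicate
// register (p0-p7, for 3-bit encodings), and 'Uph' a high predicate register
// (p8-p15). aarch64svcount values use the corresponding PNR classes instead.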
11470 enum class PredicateConstraint { Uph, Upl, Upa };
11471
11472 static std::optional<PredicateConstraint>
11473 parsePredicateConstraint(StringRef Constraint) {
11474 return StringSwitch<std::optional<PredicateConstraint>>(Constraint)
11475 .Case("Uph", PredicateConstraint::Uph)
11476 .Case("Upl", PredicateConstraint::Upl)
11477 .Case("Upa", PredicateConstraint::Upa)
11478 .Default(std::nullopt);
11479 }
11480
11481 static const TargetRegisterClass *
11482 getPredicateRegisterClass(PredicateConstraint Constraint, EVT VT) {
11483 if (VT != MVT::aarch64svcount &&
11484 (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1))
11485 return nullptr;
11486
11487 switch (Constraint) {
11488 case PredicateConstraint::Uph:
11489 return VT == MVT::aarch64svcount ? &AArch64::PNR_p8to15RegClass
11490 : &AArch64::PPR_p8to15RegClass;
11491 case PredicateConstraint::Upl:
11492 return VT == MVT::aarch64svcount ? &AArch64::PNR_3bRegClass
11493 : &AArch64::PPR_3bRegClass;
11494 case PredicateConstraint::Upa:
11495 return VT == MVT::aarch64svcount ? &AArch64::PNRRegClass
11496 : &AArch64::PPRRegClass;
11497 }
11498
11499 llvm_unreachable("Missing PredicateConstraint!");
11500 }
11501
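// The reduced GPR constraints below select restricted 32-bit index registers:
// 'Uci' maps to w8-w11 and 'Ucj' to w12-w15, matching the MatrixIndexGPR32
// register classes returned by getReducedGprRegisterClass.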
11502 enum class ReducedGprConstraint { Uci, Ucj };
11503
11504 static std::optional<ReducedGprConstraint>
11505 parseReducedGprConstraint(StringRef Constraint) {
11506 return StringSwitch<std::optional<ReducedGprConstraint>>(Constraint)
11507 .Case("Uci", ReducedGprConstraint::Uci)
11508 .Case("Ucj", ReducedGprConstraint::Ucj)
11509 .Default(std::nullopt);
11510 }
11511
11512 static const TargetRegisterClass *
11513 getReducedGprRegisterClass(ReducedGprConstraint Constraint, EVT VT) {
11514 if (!VT.isScalarInteger() || VT.getFixedSizeInBits() > 64)
11515 return nullptr;
11516
11517 switch (Constraint) {
11518 case ReducedGprConstraint::Uci:
11519 return &AArch64::MatrixIndexGPR32_8_11RegClass;
11520 case ReducedGprConstraint::Ucj:
11521 return &AArch64::MatrixIndexGPR32_12_15RegClass;
11522 }
11523
11524 llvm_unreachable("Missing ReducedGprConstraint!");
11525 }
11526
11527 // The set of cc code supported is from
11528 // https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html#Flag-Output-Operands
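// Illustrative use of a flag output operand, assuming standard GNU extended
// asm syntax (the constraint reaches this hook in its braced "{@cceq}" form):
//   int eq;
//   asm("cmp %w1, %w2" : "=@cceq"(eq) : "r"(a), "r"(b));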
11529 static AArch64CC::CondCode parseConstraintCode(llvm::StringRef Constraint) {
11530 AArch64CC::CondCode Cond = StringSwitch<AArch64CC::CondCode>(Constraint)
11531 .Case("{@cchi}", AArch64CC::HI)
11532 .Case("{@cccs}", AArch64CC::HS)
11533 .Case("{@cclo}", AArch64CC::LO)
11534 .Case("{@ccls}", AArch64CC::LS)
11535 .Case("{@cccc}", AArch64CC::LO)
11536 .Case("{@cceq}", AArch64CC::EQ)
11537 .Case("{@ccgt}", AArch64CC::GT)
11538 .Case("{@ccge}", AArch64CC::GE)
11539 .Case("{@cclt}", AArch64CC::LT)
11540 .Case("{@ccle}", AArch64CC::LE)
11541 .Case("{@cchs}", AArch64CC::HS)
11542 .Case("{@ccne}", AArch64CC::NE)
11543 .Case("{@ccvc}", AArch64CC::VC)
11544 .Case("{@ccpl}", AArch64CC::PL)
11545 .Case("{@ccvs}", AArch64CC::VS)
11546 .Case("{@ccmi}", AArch64CC::MI)
11547 .Default(AArch64CC::Invalid);
11548 return Cond;
11549 }
11550
11551 /// Helper function to create 'CSET', which is equivalent to 'CSINC <Wd>, WZR,
11552 /// WZR, invert(<cond>)'.
11553 static SDValue getSETCC(AArch64CC::CondCode CC, SDValue NZCV, const SDLoc &DL,
11554 SelectionDAG &DAG) {
11555 return DAG.getNode(
11556 AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
11557 DAG.getConstant(0, DL, MVT::i32),
11558 DAG.getConstant(getInvertedCondCode(CC), DL, MVT::i32), NZCV);
11559 }
11560
11561 // Lower @cc flag output via getSETCC.
11562 SDValue AArch64TargetLowering::LowerAsmOutputForConstraint(
11563 SDValue &Chain, SDValue &Glue, const SDLoc &DL,
11564 const AsmOperandInfo &OpInfo, SelectionDAG &DAG) const {
11565 AArch64CC::CondCode Cond = parseConstraintCode(OpInfo.ConstraintCode);
11566 if (Cond == AArch64CC::Invalid)
11567 return SDValue();
11568 // The output variable should be a scalar integer.
11569 if (OpInfo.ConstraintVT.isVector() || !OpInfo.ConstraintVT.isInteger() ||
11570 OpInfo.ConstraintVT.getSizeInBits() < 8)
11571 report_fatal_error("Flag output operand is of invalid type");
11572
11573 // Get NZCV register. Only update chain when copyfrom is glued.
11574 if (Glue.getNode()) {
11575 Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32, Glue);
11576 Chain = Glue.getValue(1);
11577 } else
11578 Glue = DAG.getCopyFromReg(Chain, DL, AArch64::NZCV, MVT::i32);
11579 // Extract CC code.
11580 SDValue CC = getSETCC(Cond, Glue, DL, DAG);
11581
11582 SDValue Result;
11583
11584 // Truncate or ZERO_EXTEND based on value types.
11585 if (OpInfo.ConstraintVT.getSizeInBits() <= 32)
11586 Result = DAG.getNode(ISD::TRUNCATE, DL, OpInfo.ConstraintVT, CC);
11587 else
11588 Result = DAG.getNode(ISD::ZERO_EXTEND, DL, OpInfo.ConstraintVT, CC);
11589
11590 return Result;
11591 }
11592
11593 /// getConstraintType - Given a constraint letter, return the type of
11594 /// constraint it is for this target.
11595 AArch64TargetLowering::ConstraintType
11596 AArch64TargetLowering::getConstraintType(StringRef Constraint) const {
11597 if (Constraint.size() == 1) {
11598 switch (Constraint[0]) {
11599 default:
11600 break;
11601 case 'x':
11602 case 'w':
11603 case 'y':
11604 return C_RegisterClass;
11605 // An address with a single base register. Due to the way we
11606 // currently handle addresses it is the same as 'r'.
11607 case 'Q':
11608 return C_Memory;
11609 case 'I':
11610 case 'J':
11611 case 'K':
11612 case 'L':
11613 case 'M':
11614 case 'N':
11615 case 'Y':
11616 case 'Z':
11617 return C_Immediate;
11618 case 'z':
11619 case 'S': // A symbol or label reference with a constant offset
11620 return C_Other;
11621 }
11622 } else if (parsePredicateConstraint(Constraint))
11623 return C_RegisterClass;
11624 else if (parseReducedGprConstraint(Constraint))
11625 return C_RegisterClass;
11626 else if (parseConstraintCode(Constraint) != AArch64CC::Invalid)
11627 return C_Other;
11628 return TargetLowering::getConstraintType(Constraint);
11629 }
11630
11631 /// Examine constraint type and operand type and determine a weight value.
11632 /// This object must already have been set up with the operand type
11633 /// and the current alternative constraint selected.
11634 TargetLowering::ConstraintWeight
11635 AArch64TargetLowering::getSingleConstraintMatchWeight(
11636 AsmOperandInfo &info, const char *constraint) const {
11637 ConstraintWeight weight = CW_Invalid;
11638 Value *CallOperandVal = info.CallOperandVal;
11639 // If we don't have a value, we can't do a match,
11640 // but allow it at the lowest weight.
11641 if (!CallOperandVal)
11642 return CW_Default;
11643 Type *type = CallOperandVal->getType();
11644 // Look at the constraint type.
11645 switch (*constraint) {
11646 default:
11647 weight = TargetLowering::getSingleConstraintMatchWeight(info, constraint);
11648 break;
11649 case 'x':
11650 case 'w':
11651 case 'y':
11652 if (type->isFloatingPointTy() || type->isVectorTy())
11653 weight = CW_Register;
11654 break;
11655 case 'z':
11656 weight = CW_Constant;
11657 break;
11658 case 'U':
11659 if (parsePredicateConstraint(constraint) ||
11660 parseReducedGprConstraint(constraint))
11661 weight = CW_Register;
11662 break;
11663 }
11664 return weight;
11665 }
11666
11667 std::pair<unsigned, const TargetRegisterClass *>
11668 AArch64TargetLowering::getRegForInlineAsmConstraint(
11669 const TargetRegisterInfo *TRI, StringRef Constraint, MVT VT) const {
11670 if (Constraint.size() == 1) {
11671 switch (Constraint[0]) {
11672 case 'r':
11673 if (VT.isScalableVector())
11674 return std::make_pair(0U, nullptr);
11675 if (Subtarget->hasLS64() && VT.getSizeInBits() == 512)
11676 return std::make_pair(0U, &AArch64::GPR64x8ClassRegClass);
11677 if (VT.getFixedSizeInBits() == 64)
11678 return std::make_pair(0U, &AArch64::GPR64commonRegClass);
11679 return std::make_pair(0U, &AArch64::GPR32commonRegClass);
11680 case 'w': {
11681 if (!Subtarget->hasFPARMv8())
11682 break;
11683 if (VT.isScalableVector()) {
11684 if (VT.getVectorElementType() != MVT::i1)
11685 return std::make_pair(0U, &AArch64::ZPRRegClass);
11686 return std::make_pair(0U, nullptr);
11687 }
11688 if (VT == MVT::Other)
11689 break;
11690 uint64_t VTSize = VT.getFixedSizeInBits();
11691 if (VTSize == 16)
11692 return std::make_pair(0U, &AArch64::FPR16RegClass);
11693 if (VTSize == 32)
11694 return std::make_pair(0U, &AArch64::FPR32RegClass);
11695 if (VTSize == 64)
11696 return std::make_pair(0U, &AArch64::FPR64RegClass);
11697 if (VTSize == 128)
11698 return std::make_pair(0U, &AArch64::FPR128RegClass);
11699 break;
11700 }
11701 // The instructions that this constraint is designed for can
11702 // only take 128-bit registers so just use that regclass.
11703 case 'x':
11704 if (!Subtarget->hasFPARMv8())
11705 break;
11706 if (VT.isScalableVector())
11707 return std::make_pair(0U, &AArch64::ZPR_4bRegClass);
11708 if (VT.getSizeInBits() == 128)
11709 return std::make_pair(0U, &AArch64::FPR128_loRegClass);
11710 break;
11711 case 'y':
11712 if (!Subtarget->hasFPARMv8())
11713 break;
11714 if (VT.isScalableVector())
11715 return std::make_pair(0U, &AArch64::ZPR_3bRegClass);
11716 break;
11717 }
11718 } else {
11719 if (const auto PC = parsePredicateConstraint(Constraint))
11720 if (const auto *RegClass = getPredicateRegisterClass(*PC, VT))
11721 return std::make_pair(0U, RegClass);
11722
11723 if (const auto RGC = parseReducedGprConstraint(Constraint))
11724 if (const auto *RegClass = getReducedGprRegisterClass(*RGC, VT))
11725 return std::make_pair(0U, RegClass);
11726 }
11727 if (StringRef("{cc}").equals_insensitive(Constraint) ||
11728 parseConstraintCode(Constraint) != AArch64CC::Invalid)
11729 return std::make_pair(unsigned(AArch64::NZCV), &AArch64::CCRRegClass);
11730
11731 if (Constraint == "{za}") {
11732 return std::make_pair(unsigned(AArch64::ZA), &AArch64::MPRRegClass);
11733 }
11734
11735 if (Constraint == "{zt0}") {
11736 return std::make_pair(unsigned(AArch64::ZT0), &AArch64::ZTRRegClass);
11737 }
11738
11739 // Use the default implementation in TargetLowering to convert the register
11740 // constraint into a member of a register class.
11741 std::pair<unsigned, const TargetRegisterClass *> Res;
11742 Res = TargetLowering::getRegForInlineAsmConstraint(TRI, Constraint, VT);
11743
11744 // Not found as a standard register?
11745 if (!Res.second) {
11746 unsigned Size = Constraint.size();
11747 if ((Size == 4 || Size == 5) && Constraint[0] == '{' &&
11748 tolower(Constraint[1]) == 'v' && Constraint[Size - 1] == '}') {
11749 int RegNo;
11750 bool Failed = Constraint.slice(2, Size - 1).getAsInteger(10, RegNo);
11751 if (!Failed && RegNo >= 0 && RegNo <= 31) {
11752 // v0 - v31 are aliases of q0 - q31 or d0 - d31 depending on size.
11753 // By default we'll emit v0-v31 for this unless there's a modifier where
11754 // we'll emit the correct register as well.
11755 if (VT != MVT::Other && VT.getSizeInBits() == 64) {
11756 Res.first = AArch64::FPR64RegClass.getRegister(RegNo);
11757 Res.second = &AArch64::FPR64RegClass;
11758 } else {
11759 Res.first = AArch64::FPR128RegClass.getRegister(RegNo);
11760 Res.second = &AArch64::FPR128RegClass;
11761 }
11762 }
11763 }
11764 }
11765
11766 if (Res.second && !Subtarget->hasFPARMv8() &&
11767 !AArch64::GPR32allRegClass.hasSubClassEq(Res.second) &&
11768 !AArch64::GPR64allRegClass.hasSubClassEq(Res.second))
11769 return std::make_pair(0U, nullptr);
11770
11771 return Res;
11772 }
11773
11774 EVT AArch64TargetLowering::getAsmOperandValueType(const DataLayout &DL,
11775 llvm::Type *Ty,
11776 bool AllowUnknown) const {
11777 if (Subtarget->hasLS64() && Ty->isIntegerTy(512))
11778 return EVT(MVT::i64x8);
11779
11780 return TargetLowering::getAsmOperandValueType(DL, Ty, AllowUnknown);
11781 }
11782
11783 /// LowerAsmOperandForConstraint - Lower the specified operand into the Ops
11784 /// vector. If it is invalid, don't add anything to Ops.
11785 void AArch64TargetLowering::LowerAsmOperandForConstraint(
11786 SDValue Op, StringRef Constraint, std::vector<SDValue> &Ops,
11787 SelectionDAG &DAG) const {
11788 SDValue Result;
11789
11790 // Currently only support length 1 constraints.
11791 if (Constraint.size() != 1)
11792 return;
11793
11794 char ConstraintLetter = Constraint[0];
11795 switch (ConstraintLetter) {
11796 default:
11797 break;
11798
11799 // This set of constraints deals with valid constants for various instructions.
11800 // Validate and return a target constant for them if we can.
11801 case 'z': {
11802 // 'z' maps to xzr or wzr so it needs an input of 0.
11803 if (!isNullConstant(Op))
11804 return;
11805
11806 if (Op.getValueType() == MVT::i64)
11807 Result = DAG.getRegister(AArch64::XZR, MVT::i64);
11808 else
11809 Result = DAG.getRegister(AArch64::WZR, MVT::i32);
11810 break;
11811 }
11812 case 'S':
11813 // Use the generic code path for "s". In GCC's aarch64 port, "S" is
11814 // supported for PIC while "s" isn't, making "s" less useful. We implement
11815 // "S" but not "s".
11816 TargetLowering::LowerAsmOperandForConstraint(Op, "s", Ops, DAG);
11817 break;
11818
11819 case 'I':
11820 case 'J':
11821 case 'K':
11822 case 'L':
11823 case 'M':
11824 case 'N':
11825 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Op);
11826 if (!C)
11827 return;
11828
11829 // Grab the value and do some validation.
11830 uint64_t CVal = C->getZExtValue();
11831 switch (ConstraintLetter) {
11832 // The I constraint applies only to simple ADD or SUB immediate operands:
11833 // i.e. 0 to 4095 with optional shift by 12
11834 // The J constraint applies only to ADD or SUB immediates that would be
11835 // valid when negated, i.e. if [an add pattern] were to be output as a SUB
11836 // instruction [or vice versa], in other words -1 to -4095 with optional
11837 // left shift by 12.
11838 case 'I':
11839 if (isUInt<12>(CVal) || isShiftedUInt<12, 12>(CVal))
11840 break;
11841 return;
11842 case 'J': {
11843 uint64_t NVal = -C->getSExtValue();
11844 if (isUInt<12>(NVal) || isShiftedUInt<12, 12>(NVal)) {
11845 CVal = C->getSExtValue();
11846 break;
11847 }
11848 return;
11849 }
11850 // The K and L constraints apply *only* to logical immediates, including
11851 // what used to be the MOVI alias for ORR (though the MOVI alias has now
11852 // been removed and MOV should be used). So these constraints have to
11853 // distinguish between bit patterns that are valid 32-bit or 64-bit
11854 // "bitmask immediates": for example 0xaaaaaaaa is a valid bimm32 (K), but
11855 // not a valid bimm64 (L) where 0xaaaaaaaaaaaaaaaa would be valid, and vice
11856 // versa.
11857 case 'K':
11858 if (AArch64_AM::isLogicalImmediate(CVal, 32))
11859 break;
11860 return;
11861 case 'L':
11862 if (AArch64_AM::isLogicalImmediate(CVal, 64))
11863 break;
11864 return;
11865 // The M and N constraints are a superset of K and L respectively, for use
11866 // with the MOV (immediate) alias. As well as the logical immediates they
11867 // also match 32 or 64-bit immediates that can be loaded either using a
11868 // *single* MOVZ or MOVN, such as 32-bit 0x12340000, 0x00001234, 0xffffedca
11869 // (M) or 64-bit 0x1234000000000000 (N) etc.
11870 // As a note some of this code is liberally stolen from the asm parser.
11871 case 'M': {
11872 if (!isUInt<32>(CVal))
11873 return;
11874 if (AArch64_AM::isLogicalImmediate(CVal, 32))
11875 break;
11876 if ((CVal & 0xFFFF) == CVal)
11877 break;
11878 if ((CVal & 0xFFFF0000ULL) == CVal)
11879 break;
11880 uint64_t NCVal = ~(uint32_t)CVal;
11881 if ((NCVal & 0xFFFFULL) == NCVal)
11882 break;
11883 if ((NCVal & 0xFFFF0000ULL) == NCVal)
11884 break;
11885 return;
11886 }
11887 case 'N': {
11888 if (AArch64_AM::isLogicalImmediate(CVal, 64))
11889 break;
11890 if ((CVal & 0xFFFFULL) == CVal)
11891 break;
11892 if ((CVal & 0xFFFF0000ULL) == CVal)
11893 break;
11894 if ((CVal & 0xFFFF00000000ULL) == CVal)
11895 break;
11896 if ((CVal & 0xFFFF000000000000ULL) == CVal)
11897 break;
11898 uint64_t NCVal = ~CVal;
11899 if ((NCVal & 0xFFFFULL) == NCVal)
11900 break;
11901 if ((NCVal & 0xFFFF0000ULL) == NCVal)
11902 break;
11903 if ((NCVal & 0xFFFF00000000ULL) == NCVal)
11904 break;
11905 if ((NCVal & 0xFFFF000000000000ULL) == NCVal)
11906 break;
11907 return;
11908 }
11909 default:
11910 return;
11911 }
11912
11913 // All assembler immediates are 64-bit integers.
11914 Result = DAG.getTargetConstant(CVal, SDLoc(Op), MVT::i64);
11915 break;
11916 }
11917
11918 if (Result.getNode()) {
11919 Ops.push_back(Result);
11920 return;
11921 }
11922
11923 return TargetLowering::LowerAsmOperandForConstraint(Op, Constraint, Ops, DAG);
11924 }
11925
11926 //===----------------------------------------------------------------------===//
11927 // AArch64 Advanced SIMD Support
11928 //===----------------------------------------------------------------------===//
11929
11930 /// WidenVector - Given a value in the V64 register class, produce the
11931 /// equivalent value in the V128 register class.
11932 static SDValue WidenVector(SDValue V64Reg, SelectionDAG &DAG) {
11933 EVT VT = V64Reg.getValueType();
11934 unsigned NarrowSize = VT.getVectorNumElements();
11935 MVT EltTy = VT.getVectorElementType().getSimpleVT();
11936 MVT WideTy = MVT::getVectorVT(EltTy, 2 * NarrowSize);
11937 SDLoc DL(V64Reg);
11938
11939 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, WideTy, DAG.getUNDEF(WideTy),
11940 V64Reg, DAG.getConstant(0, DL, MVT::i64));
11941 }
11942
11943 /// getExtFactor - Determine the adjustment factor for the position when
11944 /// generating an "extract from vector registers" instruction.
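/// For example, a v8i16 source yields a factor of 2, since the EXT immediate
/// counts bytes rather than elements.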
11945 static unsigned getExtFactor(SDValue &V) {
11946 EVT EltType = V.getValueType().getVectorElementType();
11947 return EltType.getSizeInBits() / 8;
11948 }
11949
11950 // Check if a vector is built from one vector via extracted elements of
11951 // another together with an AND mask, ensuring that all elements fit
11952 // within range. This can be reconstructed using AND and NEON's TBL1.
11953 SDValue ReconstructShuffleWithRuntimeMask(SDValue Op, SelectionDAG &DAG) {
11954 assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
11955 SDLoc dl(Op);
11956 EVT VT = Op.getValueType();
11957 assert(!VT.isScalableVector() &&
11958 "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
11959
11960 // Can only recreate a shuffle with 16xi8 or 8xi8 elements, as they map
11961 // directly to TBL1.
11962 if (VT != MVT::v16i8 && VT != MVT::v8i8)
11963 return SDValue();
11964
11965 unsigned NumElts = VT.getVectorNumElements();
11966 assert((NumElts == 8 || NumElts == 16) &&
11967 "Need to have exactly 8 or 16 elements in vector.");
11968
11969 SDValue SourceVec;
11970 SDValue MaskSourceVec;
11971 SmallVector<SDValue, 16> AndMaskConstants;
11972
11973 for (unsigned i = 0; i < NumElts; ++i) {
11974 SDValue V = Op.getOperand(i);
11975 if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
11976 return SDValue();
11977
11978 SDValue OperandSourceVec = V.getOperand(0);
11979 if (!SourceVec)
11980 SourceVec = OperandSourceVec;
11981 else if (SourceVec != OperandSourceVec)
11982 return SDValue();
11983
11984 // This only looks at shuffles with elements that are
11985 // a) truncated by a constant AND mask extracted from a mask vector, or
11986 // b) extracted directly from a mask vector.
11987 SDValue MaskSource = V.getOperand(1);
11988 if (MaskSource.getOpcode() == ISD::AND) {
11989 if (!isa<ConstantSDNode>(MaskSource.getOperand(1)))
11990 return SDValue();
11991
11992 AndMaskConstants.push_back(MaskSource.getOperand(1));
11993 MaskSource = MaskSource->getOperand(0);
11994 } else if (!AndMaskConstants.empty()) {
11995 // Either all or no operands should have an AND mask.
11996 return SDValue();
11997 }
11998
11999 // An ANY_EXTEND may be inserted between the AND and the source vector
12000 // extraction. We don't care about that, so we can just skip it.
12001 if (MaskSource.getOpcode() == ISD::ANY_EXTEND)
12002 MaskSource = MaskSource.getOperand(0);
12003
12004 if (MaskSource.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
12005 return SDValue();
12006
12007 SDValue MaskIdx = MaskSource.getOperand(1);
12008 if (!isa<ConstantSDNode>(MaskIdx) ||
12009 !cast<ConstantSDNode>(MaskIdx)->getConstantIntValue()->equalsInt(i))
12010 return SDValue();
12011
12012 // We only apply this if all elements come from the same vector with the
12013 // same vector type.
12014 if (!MaskSourceVec) {
12015 MaskSourceVec = MaskSource->getOperand(0);
12016 if (MaskSourceVec.getValueType() != VT)
12017 return SDValue();
12018 } else if (MaskSourceVec != MaskSource->getOperand(0)) {
12019 return SDValue();
12020 }
12021 }
12022
12023 // We need a v16i8 for TBL, so we extend the source with a placeholder vector
12024 // for v8i8 to get a v16i8. As the pattern we are replacing is extract +
12025 // insert, we know that the index in the mask must be smaller than the number
12026 // of elements in the source, or we would have an out-of-bounds access.
12027 if (NumElts == 8)
12028 SourceVec = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, SourceVec,
12029 DAG.getUNDEF(VT));
12030
12031 // Preconditions met, so we can use a vector (AND +) TBL to build this vector.
12032 if (!AndMaskConstants.empty())
12033 MaskSourceVec = DAG.getNode(ISD::AND, dl, VT, MaskSourceVec,
12034 DAG.getBuildVector(VT, dl, AndMaskConstants));
12035
12036 return DAG.getNode(
12037 ISD::INTRINSIC_WO_CHAIN, dl, VT,
12038 DAG.getConstant(Intrinsic::aarch64_neon_tbl1, dl, MVT::i32), SourceVec,
12039 MaskSourceVec);
12040 }
12041
12042 // Gather data to see if the operation can be modelled as a
12043 // shuffle in combination with VEXTs.
12044 SDValue AArch64TargetLowering::ReconstructShuffle(SDValue Op,
12045 SelectionDAG &DAG) const {
12046 assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
12047 LLVM_DEBUG(dbgs() << "AArch64TargetLowering::ReconstructShuffle\n");
12048 SDLoc dl(Op);
12049 EVT VT = Op.getValueType();
12050 assert(!VT.isScalableVector() &&
12051 "Scalable vectors cannot be used with ISD::BUILD_VECTOR");
12052 unsigned NumElts = VT.getVectorNumElements();
12053
12054 struct ShuffleSourceInfo {
12055 SDValue Vec;
12056 unsigned MinElt;
12057 unsigned MaxElt;
12058
12059 // We may insert some combination of BITCASTs and VEXT nodes to force Vec to
12060 // be compatible with the shuffle we intend to construct. As a result
12061 // ShuffleVec will be some sliding window into the original Vec.
12062 SDValue ShuffleVec;
12063
12064 // Code should guarantee that element i in Vec starts at element
12065 // "WindowBase + i * WindowScale" in ShuffleVec.
12066 int WindowBase;
12067 int WindowScale;
12068
12069 ShuffleSourceInfo(SDValue Vec)
12070 : Vec(Vec), MinElt(std::numeric_limits<unsigned>::max()), MaxElt(0),
12071 ShuffleVec(Vec), WindowBase(0), WindowScale(1) {}
12072
12073 bool operator ==(SDValue OtherVec) { return Vec == OtherVec; }
12074 };
12075
12076 // First gather all vectors used as an immediate source for this BUILD_VECTOR
12077 // node.
12078 SmallVector<ShuffleSourceInfo, 2> Sources;
12079 for (unsigned i = 0; i < NumElts; ++i) {
12080 SDValue V = Op.getOperand(i);
12081 if (V.isUndef())
12082 continue;
12083 else if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12084 !isa<ConstantSDNode>(V.getOperand(1)) ||
12085 V.getOperand(0).getValueType().isScalableVector()) {
12086 LLVM_DEBUG(
12087 dbgs() << "Reshuffle failed: "
12088 "a shuffle can only come from building a vector from "
12089 "various elements of other fixed-width vectors, provided "
12090 "their indices are constant\n");
12091 return SDValue();
12092 }
12093
12094 // Add this element source to the list if it's not already there.
12095 SDValue SourceVec = V.getOperand(0);
12096 auto Source = find(Sources, SourceVec);
12097 if (Source == Sources.end())
12098 Source = Sources.insert(Sources.end(), ShuffleSourceInfo(SourceVec));
12099
12100 // Update the minimum and maximum lane number seen.
12101 unsigned EltNo = V.getConstantOperandVal(1);
12102 Source->MinElt = std::min(Source->MinElt, EltNo);
12103 Source->MaxElt = std::max(Source->MaxElt, EltNo);
12104 }
12105
12106 // If we have 3 or 4 sources, try to generate a TBL, which will at least be
12107 // better than moving to/from gpr registers for larger vectors.
12108 if ((Sources.size() == 3 || Sources.size() == 4) && NumElts > 4) {
12109 // Construct a mask for the tbl. We may need to adjust the index for types
12110 // larger than i8.
12111 SmallVector<unsigned, 16> Mask;
12112 unsigned OutputFactor = VT.getScalarSizeInBits() / 8;
12113 for (unsigned I = 0; I < NumElts; ++I) {
12114 SDValue V = Op.getOperand(I);
12115 if (V.isUndef()) {
12116 for (unsigned OF = 0; OF < OutputFactor; OF++)
12117 Mask.push_back(-1);
12118 continue;
12119 }
12120 // Set the Mask lanes adjusted for the size of the input and output
12121 // lanes. The Mask is always i8, so it will set OutputFactor lanes per
12122 // output element, adjusted in their positions per input and output types.
12123 unsigned Lane = V.getConstantOperandVal(1);
12124 for (unsigned S = 0; S < Sources.size(); S++) {
12125 if (V.getOperand(0) == Sources[S].Vec) {
12126 unsigned InputSize = Sources[S].Vec.getScalarValueSizeInBits();
12127 unsigned InputBase = 16 * S + Lane * InputSize / 8;
12128 for (unsigned OF = 0; OF < OutputFactor; OF++)
12129 Mask.push_back(InputBase + OF);
12130 break;
12131 }
12132 }
12133 }
12134
12135 // Construct the tbl3/tbl4 out of an intrinsic, the sources converted to
12136 // v16i8, and the TBLMask
12137 SmallVector<SDValue, 16> TBLOperands;
12138 TBLOperands.push_back(DAG.getConstant(Sources.size() == 3
12139 ? Intrinsic::aarch64_neon_tbl3
12140 : Intrinsic::aarch64_neon_tbl4,
12141 dl, MVT::i32));
12142 for (unsigned i = 0; i < Sources.size(); i++) {
12143 SDValue Src = Sources[i].Vec;
12144 EVT SrcVT = Src.getValueType();
12145 Src = DAG.getBitcast(SrcVT.is64BitVector() ? MVT::v8i8 : MVT::v16i8, Src);
12146 assert((SrcVT.is64BitVector() || SrcVT.is128BitVector()) &&
12147 "Expected a legally typed vector");
12148 if (SrcVT.is64BitVector())
12149 Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v16i8, Src,
12150 DAG.getUNDEF(MVT::v8i8));
12151 TBLOperands.push_back(Src);
12152 }
12153
12154 SmallVector<SDValue, 16> TBLMask;
12155 for (unsigned i = 0; i < Mask.size(); i++)
12156 TBLMask.push_back(DAG.getConstant(Mask[i], dl, MVT::i32));
12157 assert((Mask.size() == 8 || Mask.size() == 16) &&
12158 "Expected a v8i8 or v16i8 Mask");
12159 TBLOperands.push_back(
12160 DAG.getBuildVector(Mask.size() == 8 ? MVT::v8i8 : MVT::v16i8, dl, TBLMask));
12161
12162 SDValue Shuffle =
12163 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl,
12164 Mask.size() == 8 ? MVT::v8i8 : MVT::v16i8, TBLOperands);
12165 return DAG.getBitcast(VT, Shuffle);
12166 }
12167
12168 if (Sources.size() > 2) {
12169 LLVM_DEBUG(dbgs() << "Reshuffle failed: currently only do something "
12170 << "sensible when at most two source vectors are "
12171 << "involved\n");
12172 return SDValue();
12173 }
12174
12175 // Find out the smallest element size among result and two sources, and use
12176 // it as element size to build the shuffle_vector.
12177 EVT SmallestEltTy = VT.getVectorElementType();
12178 for (auto &Source : Sources) {
12179 EVT SrcEltTy = Source.Vec.getValueType().getVectorElementType();
12180 if (SrcEltTy.bitsLT(SmallestEltTy)) {
12181 SmallestEltTy = SrcEltTy;
12182 }
12183 }
12184 unsigned ResMultiplier =
12185 VT.getScalarSizeInBits() / SmallestEltTy.getFixedSizeInBits();
12186 uint64_t VTSize = VT.getFixedSizeInBits();
12187 NumElts = VTSize / SmallestEltTy.getFixedSizeInBits();
12188 EVT ShuffleVT = EVT::getVectorVT(*DAG.getContext(), SmallestEltTy, NumElts);
12189
12190 // If the source vector is too wide or too narrow, we may nevertheless be able
12191 // to construct a compatible shuffle either by concatenating it with UNDEF or
12192 // extracting a suitable range of elements.
12193 for (auto &Src : Sources) {
12194 EVT SrcVT = Src.ShuffleVec.getValueType();
12195
12196 TypeSize SrcVTSize = SrcVT.getSizeInBits();
12197 if (SrcVTSize == TypeSize::getFixed(VTSize))
12198 continue;
12199
12200 // This stage of the search produces a source with the same element type as
12201 // the original, but with a total width matching the BUILD_VECTOR output.
12202 EVT EltVT = SrcVT.getVectorElementType();
12203 unsigned NumSrcElts = VTSize / EltVT.getFixedSizeInBits();
12204 EVT DestVT = EVT::getVectorVT(*DAG.getContext(), EltVT, NumSrcElts);
12205
12206 if (SrcVTSize.getFixedValue() < VTSize) {
12207 assert(2 * SrcVTSize == VTSize);
12208 // We can pad out the smaller vector for free, so if it's part of a
12209 // shuffle...
12210 Src.ShuffleVec =
12211 DAG.getNode(ISD::CONCAT_VECTORS, dl, DestVT, Src.ShuffleVec,
12212 DAG.getUNDEF(Src.ShuffleVec.getValueType()));
12213 continue;
12214 }
12215
12216 if (SrcVTSize.getFixedValue() != 2 * VTSize) {
12217 LLVM_DEBUG(
12218 dbgs() << "Reshuffle failed: result vector too small to extract\n");
12219 return SDValue();
12220 }
12221
12222 if (Src.MaxElt - Src.MinElt >= NumSrcElts) {
12223 LLVM_DEBUG(
12224 dbgs() << "Reshuffle failed: span too large for a VEXT to cope\n");
12225 return SDValue();
12226 }
12227
12228 if (Src.MinElt >= NumSrcElts) {
12229 // The extraction can just take the second half
12230 Src.ShuffleVec =
12231 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12232 DAG.getConstant(NumSrcElts, dl, MVT::i64));
12233 Src.WindowBase = -NumSrcElts;
12234 } else if (Src.MaxElt < NumSrcElts) {
12235 // The extraction can just take the first half
12236 Src.ShuffleVec =
12237 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12238 DAG.getConstant(0, dl, MVT::i64));
12239 } else {
12240 // An actual VEXT is needed
12241 SDValue VEXTSrc1 =
12242 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12243 DAG.getConstant(0, dl, MVT::i64));
12244 SDValue VEXTSrc2 =
12245 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, DestVT, Src.ShuffleVec,
12246 DAG.getConstant(NumSrcElts, dl, MVT::i64));
12247 unsigned Imm = Src.MinElt * getExtFactor(VEXTSrc1);
12248
12249 if (!SrcVT.is64BitVector()) {
12250 LLVM_DEBUG(
12251 dbgs() << "Reshuffle failed: don't know how to lower AArch64ISD::EXT "
12252 "for SVE vectors.");
12253 return SDValue();
12254 }
12255
12256 Src.ShuffleVec = DAG.getNode(AArch64ISD::EXT, dl, DestVT, VEXTSrc1,
12257 VEXTSrc2,
12258 DAG.getConstant(Imm, dl, MVT::i32));
12259 Src.WindowBase = -Src.MinElt;
12260 }
12261 }
12262
12263 // Another possible incompatibility occurs from the vector element types. We
12264 // can fix this by bitcasting the source vectors to the same type we intend
12265 // for the shuffle.
12266 for (auto &Src : Sources) {
12267 EVT SrcEltTy = Src.ShuffleVec.getValueType().getVectorElementType();
12268 if (SrcEltTy == SmallestEltTy)
12269 continue;
12270 assert(ShuffleVT.getVectorElementType() == SmallestEltTy);
12271 if (DAG.getDataLayout().isBigEndian()) {
12272 Src.ShuffleVec =
12273 DAG.getNode(AArch64ISD::NVCAST, dl, ShuffleVT, Src.ShuffleVec);
12274 } else {
12275 Src.ShuffleVec = DAG.getNode(ISD::BITCAST, dl, ShuffleVT, Src.ShuffleVec);
12276 }
12277 Src.WindowScale =
12278 SrcEltTy.getFixedSizeInBits() / SmallestEltTy.getFixedSizeInBits();
12279 Src.WindowBase *= Src.WindowScale;
12280 }
12281
12282 // Final check before we try to actually produce a shuffle.
12283 LLVM_DEBUG(for (auto Src
12284 : Sources)
12285 assert(Src.ShuffleVec.getValueType() == ShuffleVT););
12286
12287 // The stars all align, our next step is to produce the mask for the shuffle.
12288 SmallVector<int, 8> Mask(ShuffleVT.getVectorNumElements(), -1);
12289 int BitsPerShuffleLane = ShuffleVT.getScalarSizeInBits();
12290 for (unsigned i = 0; i < VT.getVectorNumElements(); ++i) {
12291 SDValue Entry = Op.getOperand(i);
12292 if (Entry.isUndef())
12293 continue;
12294
12295 auto Src = find(Sources, Entry.getOperand(0));
12296 int EltNo = cast<ConstantSDNode>(Entry.getOperand(1))->getSExtValue();
12297
12298 // EXTRACT_VECTOR_ELT performs an implicit any_ext; BUILD_VECTOR an implicit
12299 // trunc. So only std::min(SrcBits, DestBits) actually get defined in this
12300 // segment.
12301 EVT OrigEltTy = Entry.getOperand(0).getValueType().getVectorElementType();
12302 int BitsDefined = std::min(OrigEltTy.getScalarSizeInBits(),
12303 VT.getScalarSizeInBits());
12304 int LanesDefined = BitsDefined / BitsPerShuffleLane;
12305
12306 // This source is expected to fill ResMultiplier lanes of the final shuffle,
12307 // starting at the appropriate offset.
12308 int *LaneMask = &Mask[i * ResMultiplier];
12309
12310 int ExtractBase = EltNo * Src->WindowScale + Src->WindowBase;
12311 ExtractBase += NumElts * (Src - Sources.begin());
12312 for (int j = 0; j < LanesDefined; ++j)
12313 LaneMask[j] = ExtractBase + j;
12314 }
12315
12316 // Final check before we try to produce nonsense...
12317 if (!isShuffleMaskLegal(Mask, ShuffleVT)) {
12318 LLVM_DEBUG(dbgs() << "Reshuffle failed: illegal shuffle mask\n");
12319 return SDValue();
12320 }
12321
12322 SDValue ShuffleOps[] = { DAG.getUNDEF(ShuffleVT), DAG.getUNDEF(ShuffleVT) };
12323 for (unsigned i = 0; i < Sources.size(); ++i)
12324 ShuffleOps[i] = Sources[i].ShuffleVec;
12325
12326 SDValue Shuffle = DAG.getVectorShuffle(ShuffleVT, dl, ShuffleOps[0],
12327 ShuffleOps[1], Mask);
12328 SDValue V;
12329 if (DAG.getDataLayout().isBigEndian()) {
12330 V = DAG.getNode(AArch64ISD::NVCAST, dl, VT, Shuffle);
12331 } else {
12332 V = DAG.getNode(ISD::BITCAST, dl, VT, Shuffle);
12333 }
12334
12335 LLVM_DEBUG(dbgs() << "Reshuffle, creating node: "; Shuffle.dump();
12336 dbgs() << "Reshuffle, creating node: "; V.dump(););
12337
12338 return V;
12339 }
12340
12341 // Check if an EXT instruction can handle the shuffle mask when the
12342 // vector sources of the shuffle are the same.
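// Illustrative example (not from the original source): for an <8 x i8> shuffle
// of a single source with mask <3, 4, 5, 6, 7, 0, 1, 2>, the indices rotate
// the vector by three lanes, so this returns true with Imm == 3 and the
// shuffle can be lowered to "ext v0.8b, v0.8b, v0.8b, #3".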
12343 static bool isSingletonEXTMask(ArrayRef<int> M, EVT VT, unsigned &Imm) {
12344 unsigned NumElts = VT.getVectorNumElements();
12345
12346 // Assume that the first shuffle index is not UNDEF. Fail if it is.
12347 if (M[0] < 0)
12348 return false;
12349
12350 Imm = M[0];
12351
12352 // If this is a VEXT shuffle, the immediate value is the index of the first
12353 // element. The other shuffle indices must be the successive elements after
12354 // the first one.
12355 unsigned ExpectedElt = Imm;
12356 for (unsigned i = 1; i < NumElts; ++i) {
12357 // Increment the expected index. If it wraps around, just follow it
12358 // back to index zero and keep going.
12359 ++ExpectedElt;
12360 if (ExpectedElt == NumElts)
12361 ExpectedElt = 0;
12362
12363 if (M[i] < 0)
12364 continue; // ignore UNDEF indices
12365 if (ExpectedElt != static_cast<unsigned>(M[i]))
12366 return false;
12367 }
12368
12369 return true;
12370 }
12371
12372 // Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
12373 // v4i32s. This is really a truncate, which we can construct out of (legal)
12374 // concats and truncate nodes.
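// Illustrative sketch of the rewrite performed below, assuming four v4i32
// inputs a, b, c and d feeding the 16 extracts:
//   ta..td = trunc a, b, c, d to v4i16
//   c0 = concat(ta, tb) : v8i16,  c1 = concat(tc, td) : v8i16
//   result = concat(trunc(c0) : v8i8, trunc(c1) : v8i8) : v16i8
// The v8i16 -> v8i8 truncates keep the even bytes, which is what UZP1 does.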
12375 static SDValue ReconstructTruncateFromBuildVector(SDValue V, SelectionDAG &DAG) {
12376 if (V.getValueType() != MVT::v16i8)
12377 return SDValue();
12378 assert(V.getNumOperands() == 16 && "Expected 16 operands on the BUILDVECTOR");
12379
12380 for (unsigned X = 0; X < 4; X++) {
12381 // Check the first item in each group is an extract from lane 0 of a v4i32
12382 // or v4i16.
12383 SDValue BaseExt = V.getOperand(X * 4);
12384 if (BaseExt.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12385 (BaseExt.getOperand(0).getValueType() != MVT::v4i16 &&
12386 BaseExt.getOperand(0).getValueType() != MVT::v4i32) ||
12387 !isa<ConstantSDNode>(BaseExt.getOperand(1)) ||
12388 BaseExt.getConstantOperandVal(1) != 0)
12389 return SDValue();
12390 SDValue Base = BaseExt.getOperand(0);
12391 // And check the other items are extracts from the same vector.
12392 for (unsigned Y = 1; Y < 4; Y++) {
12393 SDValue Ext = V.getOperand(X * 4 + Y);
12394 if (Ext.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
12395 Ext.getOperand(0) != Base ||
12396 !isa<ConstantSDNode>(Ext.getOperand(1)) ||
12397 Ext.getConstantOperandVal(1) != Y)
12398 return SDValue();
12399 }
12400 }
12401
12402 // Turn the buildvector into a series of truncates and concats, which will
12403 // become uzip1's. Any v4i32s we found get truncated to v4i16, which are
12404 // concatenated together to produce 2 v8i16 vectors. These are both truncated
12405 // and concatenated together.
12406 SDLoc DL(V);
12407 SDValue Trunc[4] = {
12408 V.getOperand(0).getOperand(0), V.getOperand(4).getOperand(0),
12409 V.getOperand(8).getOperand(0), V.getOperand(12).getOperand(0)};
12410 for (SDValue &V : Trunc)
12411 if (V.getValueType() == MVT::v4i32)
12412 V = DAG.getNode(ISD::TRUNCATE, DL, MVT::v4i16, V);
12413 SDValue Concat0 =
12414 DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[0], Trunc[1]);
12415 SDValue Concat1 =
12416 DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v8i16, Trunc[2], Trunc[3]);
12417 SDValue Trunc0 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat0);
12418 SDValue Trunc1 = DAG.getNode(ISD::TRUNCATE, DL, MVT::v8i8, Concat1);
12419 return DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, Trunc0, Trunc1);
12420 }
12421
12422 /// Check if a vector shuffle corresponds to a DUP instruction with a larger
12423 /// element width than the vector lane type. If that is the case, the function
12424 /// returns true and writes the value of the DUP instruction lane operand into
12425 /// DupLaneOp.
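// Illustrative example: for a v8i16 shuffle mask <2, 3, 2, 3, 2, 3, 2, 3> and
// BlockSize == 32, every 32-bit block repeats lanes {2, 3}, so this returns
// true with DupLaneOp == 1 (the second 32-bit lane of the source).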
12426 static bool isWideDUPMask(ArrayRef<int> M, EVT VT, unsigned BlockSize,
12427 unsigned &DupLaneOp) {
12428 assert((BlockSize == 16 || BlockSize == 32 || BlockSize == 64) &&
12429 "Only possible block sizes for wide DUP are: 16, 32, 64");
12430
12431 if (BlockSize <= VT.getScalarSizeInBits())
12432 return false;
12433 if (BlockSize % VT.getScalarSizeInBits() != 0)
12434 return false;
12435 if (VT.getSizeInBits() % BlockSize != 0)
12436 return false;
12437
12438 size_t SingleVecNumElements = VT.getVectorNumElements();
12439 size_t NumEltsPerBlock = BlockSize / VT.getScalarSizeInBits();
12440 size_t NumBlocks = VT.getSizeInBits() / BlockSize;
12441
12442 // We are looking for masks like
12443 // [0, 1, 0, 1] or [2, 3, 2, 3] or [4, 5, 6, 7, 4, 5, 6, 7] where any element
12444 // might be replaced by 'undefined'. BlockElts will eventually contain
12445 // lane indices of the duplicated block (i.e. [0, 1], [2, 3] and [4, 5, 6, 7]
12446 // for the above examples)
12447 SmallVector<int, 8> BlockElts(NumEltsPerBlock, -1);
12448 for (size_t BlockIndex = 0; BlockIndex < NumBlocks; BlockIndex++)
12449 for (size_t I = 0; I < NumEltsPerBlock; I++) {
12450 int Elt = M[BlockIndex * NumEltsPerBlock + I];
12451 if (Elt < 0)
12452 continue;
12453 // For now we don't support shuffles that use the second operand
12454 if ((unsigned)Elt >= SingleVecNumElements)
12455 return false;
12456 if (BlockElts[I] < 0)
12457 BlockElts[I] = Elt;
12458 else if (BlockElts[I] != Elt)
12459 return false;
12460 }
12461
12462 // We found a candidate block (possibly with some undefs). It must be a
12463 // sequence of consecutive integers starting with a value divisible by
12464 // NumEltsPerBlock with some values possibly replaced by undef-s.
12465
12466 // Find first non-undef element
12467 auto FirstRealEltIter = find_if(BlockElts, [](int Elt) { return Elt >= 0; });
12468 assert(FirstRealEltIter != BlockElts.end() &&
12469 "Shuffle with all-undefs must have been caught by previous cases, "
12470 "e.g. isSplat()");
12471 if (FirstRealEltIter == BlockElts.end()) {
12472 DupLaneOp = 0;
12473 return true;
12474 }
12475
12476 // Index of FirstRealElt in BlockElts
12477 size_t FirstRealIndex = FirstRealEltIter - BlockElts.begin();
12478
12479 if ((unsigned)*FirstRealEltIter < FirstRealIndex)
12480 return false;
12481 // BlockElts[0] must have the following value if it isn't undef:
12482 size_t Elt0 = *FirstRealEltIter - FirstRealIndex;
12483
12484 // Check the first element
12485 if (Elt0 % NumEltsPerBlock != 0)
12486 return false;
12487 // Check that the sequence indeed consists of consecutive integers (modulo
12488 // undefs)
12489 for (size_t I = 0; I < NumEltsPerBlock; I++)
12490 if (BlockElts[I] >= 0 && (unsigned)BlockElts[I] != Elt0 + I)
12491 return false;
12492
12493 DupLaneOp = Elt0 / NumEltsPerBlock;
12494 return true;
12495 }
12496
12497 // Check if an EXT instruction can handle the shuffle mask when the
12498 // vector sources of the shuffle are different.
12499 static bool isEXTMask(ArrayRef<int> M, EVT VT, bool &ReverseEXT,
12500 unsigned &Imm) {
12501 // Look for the first non-undef element.
12502 const int *FirstRealElt = find_if(M, [](int Elt) { return Elt >= 0; });
12503
12504 // Benefit from APInt to handle overflow when calculating the expected element.
12505 unsigned NumElts = VT.getVectorNumElements();
12506 unsigned MaskBits = APInt(32, NumElts * 2).logBase2();
12507 APInt ExpectedElt = APInt(MaskBits, *FirstRealElt + 1);
12508 // The following shuffle indices must be the successive elements after the
12509 // first real element.
12510 bool FoundWrongElt = std::any_of(FirstRealElt + 1, M.end(), [&](int Elt) {
12511 return Elt != ExpectedElt++ && Elt != -1;
12512 });
12513 if (FoundWrongElt)
12514 return false;
12515
12516 // The index of an EXT is the first element if it is not UNDEF.
12517 // Watch out for the beginning UNDEFs. The EXT index should be the expected
12518 // value of the first element. E.g.
12519 // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>.
12520 // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>.
12521 // ExpectedElt is the last mask index plus 1.
12522 Imm = ExpectedElt.getZExtValue();
12523
12524 // There are two different cases that require reversing the input vectors.
12525 // For example, for vector <4 x i32> we have the following cases,
12526 // Case 1: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, -1, 0>)
12527 // Case 2: shufflevector(<4 x i32>,<4 x i32>,<-1, -1, 7, 0>)
12528 // For both cases, we finally use mask <5, 6, 7, 0>, which requires
12529 // reversing the two input vectors.
12530 if (Imm < NumElts)
12531 ReverseEXT = true;
12532 else
12533 Imm -= NumElts;
12534
12535 return true;
12536 }
12537
12538 /// isZIP_v_undef_Mask - Special case of isZIPMask for canonical form of
12539 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12540 /// Mask is e.g., <0, 0, 1, 1> instead of <0, 4, 1, 5>.
12541 static bool isZIP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12542 unsigned NumElts = VT.getVectorNumElements();
12543 if (NumElts % 2 != 0)
12544 return false;
12545 WhichResult = (M[0] == 0 ? 0 : 1);
12546 unsigned Idx = WhichResult * NumElts / 2;
12547 for (unsigned i = 0; i != NumElts; i += 2) {
12548 if ((M[i] >= 0 && (unsigned)M[i] != Idx) ||
12549 (M[i + 1] >= 0 && (unsigned)M[i + 1] != Idx))
12550 return false;
12551 Idx += 1;
12552 }
12553
12554 return true;
12555 }
12556
12557 /// isUZP_v_undef_Mask - Special case of isUZPMask for canonical form of
12558 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12559 /// Mask is e.g., <0, 2, 0, 2> instead of <0, 2, 4, 6>.
12560 static bool isUZP_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12561 unsigned Half = VT.getVectorNumElements() / 2;
12562 WhichResult = (M[0] == 0 ? 0 : 1);
12563 for (unsigned j = 0; j != 2; ++j) {
12564 unsigned Idx = WhichResult;
12565 for (unsigned i = 0; i != Half; ++i) {
12566 int MIdx = M[i + j * Half];
12567 if (MIdx >= 0 && (unsigned)MIdx != Idx)
12568 return false;
12569 Idx += 2;
12570 }
12571 }
12572
12573 return true;
12574 }
12575
12576 /// isTRN_v_undef_Mask - Special case of isTRNMask for canonical form of
12577 /// "vector_shuffle v, v", i.e., "vector_shuffle v, undef".
12578 /// Mask is e.g., <0, 0, 2, 2> instead of <0, 4, 2, 6>.
12579 static bool isTRN_v_undef_Mask(ArrayRef<int> M, EVT VT, unsigned &WhichResult) {
12580 unsigned NumElts = VT.getVectorNumElements();
12581 if (NumElts % 2 != 0)
12582 return false;
12583 WhichResult = (M[0] == 0 ? 0 : 1);
12584 for (unsigned i = 0; i < NumElts; i += 2) {
12585 if ((M[i] >= 0 && (unsigned)M[i] != i + WhichResult) ||
12586 (M[i + 1] >= 0 && (unsigned)M[i + 1] != i + WhichResult))
12587 return false;
12588 }
12589 return true;
12590 }
12591
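// Detect a mask that is an identity copy of one input except for a single
// position (the "Anomaly") whose element comes from elsewhere; such shuffles
// are lowered below as a single lane insert (INS). Illustrative example: the
// v4i32 mask <0, 5, 2, 3> keeps LHS lanes 0, 2 and 3 and pulls lane 1 of the
// RHS into position 1, so DstIsLeft == true and Anomaly == 1.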
12592 static bool isINSMask(ArrayRef<int> M, int NumInputElements,
12593 bool &DstIsLeft, int &Anomaly) {
12594 if (M.size() != static_cast<size_t>(NumInputElements))
12595 return false;
12596
12597 int NumLHSMatch = 0, NumRHSMatch = 0;
12598 int LastLHSMismatch = -1, LastRHSMismatch = -1;
12599
12600 for (int i = 0; i < NumInputElements; ++i) {
12601 if (M[i] == -1) {
12602 ++NumLHSMatch;
12603 ++NumRHSMatch;
12604 continue;
12605 }
12606
12607 if (M[i] == i)
12608 ++NumLHSMatch;
12609 else
12610 LastLHSMismatch = i;
12611
12612 if (M[i] == i + NumInputElements)
12613 ++NumRHSMatch;
12614 else
12615 LastRHSMismatch = i;
12616 }
12617
12618 if (NumLHSMatch == NumInputElements - 1) {
12619 DstIsLeft = true;
12620 Anomaly = LastLHSMismatch;
12621 return true;
12622 } else if (NumRHSMatch == NumInputElements - 1) {
12623 DstIsLeft = false;
12624 Anomaly = LastRHSMismatch;
12625 return true;
12626 }
12627
12628 return false;
12629 }
12630
12631 static bool isConcatMask(ArrayRef<int> Mask, EVT VT, bool SplitLHS) {
12632 if (VT.getSizeInBits() != 128)
12633 return false;
12634
12635 unsigned NumElts = VT.getVectorNumElements();
12636
12637 for (int I = 0, E = NumElts / 2; I != E; I++) {
12638 if (Mask[I] != I)
12639 return false;
12640 }
12641
12642 int Offset = NumElts / 2;
12643 for (int I = NumElts / 2, E = NumElts; I != E; I++) {
12644 if (Mask[I] != I + SplitLHS * Offset)
12645 return false;
12646 }
12647
12648 return true;
12649 }
12650
12651 static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
12652 SDLoc DL(Op);
12653 EVT VT = Op.getValueType();
12654 SDValue V0 = Op.getOperand(0);
12655 SDValue V1 = Op.getOperand(1);
12656 ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
12657
12658 if (VT.getVectorElementType() != V0.getValueType().getVectorElementType() ||
12659 VT.getVectorElementType() != V1.getValueType().getVectorElementType())
12660 return SDValue();
12661
12662 bool SplitV0 = V0.getValueSizeInBits() == 128;
12663
12664 if (!isConcatMask(Mask, VT, SplitV0))
12665 return SDValue();
12666
12667 EVT CastVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
12668 if (SplitV0) {
12669 V0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V0,
12670 DAG.getConstant(0, DL, MVT::i64));
12671 }
12672 if (V1.getValueSizeInBits() == 128) {
12673 V1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, CastVT, V1,
12674 DAG.getConstant(0, DL, MVT::i64));
12675 }
12676 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
12677 }
12678
12679 /// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
12680 /// the specified operations to build the shuffle. ID is the perfect-shuffle
12681 /// ID, V1 and V2 are the original shuffle inputs. PFEntry is the perfect-shuffle
12682 /// table entry and LHS/RHS are the immediate inputs for this stage of the
12683 /// shuffle.
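/// A note on the encoding (derived from the code below): each 4-lane mask is
/// encoded as a base-9 number whose digits are the lane indices (0-7, with 8
/// meaning undef), e.g. <0,1,2,3> encodes as ((0*9+1)*9+2)*9+3 == 102. PFEntry
/// packs an opcode in bits [29:26], the LHS sub-shuffle ID in bits [25:13] and
/// the RHS sub-shuffle ID (or a lane number for OP_MOVLANE) in bits [12:0].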
12684 static SDValue GeneratePerfectShuffle(unsigned ID, SDValue V1,
12685 SDValue V2, unsigned PFEntry, SDValue LHS,
12686 SDValue RHS, SelectionDAG &DAG,
12687 const SDLoc &dl) {
12688 unsigned OpNum = (PFEntry >> 26) & 0x0F;
12689 unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
12690 unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
12691
12692 enum {
12693 OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
12694 OP_VREV,
12695 OP_VDUP0,
12696 OP_VDUP1,
12697 OP_VDUP2,
12698 OP_VDUP3,
12699 OP_VEXT1,
12700 OP_VEXT2,
12701 OP_VEXT3,
12702 OP_VUZPL, // VUZP, left result
12703 OP_VUZPR, // VUZP, right result
12704 OP_VZIPL, // VZIP, left result
12705 OP_VZIPR, // VZIP, right result
12706 OP_VTRNL, // VTRN, left result
12707 OP_VTRNR, // VTRN, right result
12708 OP_MOVLANE // Move lane. RHSID is the lane to move into
12709 };
12710
12711 if (OpNum == OP_COPY) {
12712 if (LHSID == (1 * 9 + 2) * 9 + 3)
12713 return LHS;
12714 assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
12715 return RHS;
12716 }
12717
12718 if (OpNum == OP_MOVLANE) {
12719 // Decompose a PerfectShuffle ID to get the Mask for lane Elt
12720 auto getPFIDLane = [](unsigned ID, int Elt) -> int {
12721 assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
12722 Elt = 3 - Elt;
12723 while (Elt > 0) {
12724 ID /= 9;
12725 Elt--;
12726 }
12727 return (ID % 9 == 8) ? -1 : ID % 9;
12728 };
12729
12730 // For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
12731 // get the lane to move from the PFID, which is always from the
12732 // original vectors (V1 or V2).
12733 SDValue OpLHS = GeneratePerfectShuffle(
12734 LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
12735 EVT VT = OpLHS.getValueType();
12736 assert(RHSID < 8 && "Expected a lane index for RHSID!");
12737 unsigned ExtLane = 0;
12738 SDValue Input;
12739
12740 // OP_MOVLANE is either a D mov (if bit 0x4 is set) or an S mov. D movs
12741 // convert into a higher type.
12742 if (RHSID & 0x4) {
12743 int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
12744 if (MaskElt == -1)
12745 MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
12746 assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
12747 ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
12748 Input = MaskElt < 2 ? V1 : V2;
12749 if (VT.getScalarSizeInBits() == 16) {
12750 Input = DAG.getBitcast(MVT::v2f32, Input);
12751 OpLHS = DAG.getBitcast(MVT::v2f32, OpLHS);
12752 } else {
12753 assert(VT.getScalarSizeInBits() == 32 &&
12754 "Expected 16 or 32 bit shuffle elements");
12755 Input = DAG.getBitcast(MVT::v2f64, Input);
12756 OpLHS = DAG.getBitcast(MVT::v2f64, OpLHS);
12757 }
12758 } else {
12759 int MaskElt = getPFIDLane(ID, RHSID);
12760 assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
12761 ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
12762 Input = MaskElt < 4 ? V1 : V2;
12763 // Be careful about creating illegal types. Use f16 instead of i16.
12764 if (VT == MVT::v4i16) {
12765 Input = DAG.getBitcast(MVT::v4f16, Input);
12766 OpLHS = DAG.getBitcast(MVT::v4f16, OpLHS);
12767 }
12768 }
12769 SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
12770 Input.getValueType().getVectorElementType(),
12771 Input, DAG.getVectorIdxConstant(ExtLane, dl));
12772 SDValue Ins =
12773 DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Input.getValueType(), OpLHS,
12774 Ext, DAG.getVectorIdxConstant(RHSID & 0x3, dl));
12775 return DAG.getBitcast(VT, Ins);
12776 }
12777
12778 SDValue OpLHS, OpRHS;
12779 OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
12780 RHS, DAG, dl);
12781 OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
12782 RHS, DAG, dl);
12783 EVT VT = OpLHS.getValueType();
12784
12785 switch (OpNum) {
12786 default:
12787 llvm_unreachable("Unknown shuffle opcode!");
12788 case OP_VREV:
12789 // VREV divides the vector in half and swaps within the half.
12790 if (VT.getVectorElementType() == MVT::i32 ||
12791 VT.getVectorElementType() == MVT::f32)
12792 return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
12793 // vrev <4 x i16> -> REV32
12794 if (VT.getVectorElementType() == MVT::i16 ||
12795 VT.getVectorElementType() == MVT::f16 ||
12796 VT.getVectorElementType() == MVT::bf16)
12797 return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
12798 // vrev <4 x i8> -> REV16
12799 assert(VT.getVectorElementType() == MVT::i8);
12800 return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
12801 case OP_VDUP0:
12802 case OP_VDUP1:
12803 case OP_VDUP2:
12804 case OP_VDUP3: {
12805 EVT EltTy = VT.getVectorElementType();
12806 unsigned Opcode;
12807 if (EltTy == MVT::i8)
12808 Opcode = AArch64ISD::DUPLANE8;
12809 else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
12810 Opcode = AArch64ISD::DUPLANE16;
12811 else if (EltTy == MVT::i32 || EltTy == MVT::f32)
12812 Opcode = AArch64ISD::DUPLANE32;
12813 else if (EltTy == MVT::i64 || EltTy == MVT::f64)
12814 Opcode = AArch64ISD::DUPLANE64;
12815 else
12816 llvm_unreachable("Invalid vector element type?");
12817
12818 if (VT.getSizeInBits() == 64)
12819 OpLHS = WidenVector(OpLHS, DAG);
12820 SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
12821 return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
12822 }
12823 case OP_VEXT1:
12824 case OP_VEXT2:
12825 case OP_VEXT3: {
12826 unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
12827 return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
12828 DAG.getConstant(Imm, dl, MVT::i32));
12829 }
12830 case OP_VUZPL:
12831 return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
12832 case OP_VUZPR:
12833 return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
12834 case OP_VZIPL:
12835 return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
12836 case OP_VZIPR:
12837 return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
12838 case OP_VTRNL:
12839 return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
12840 case OP_VTRNR:
12841 return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
12842 }
12843 }
12844
12845 static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
12846 SelectionDAG &DAG) {
12847 // Check to see if we can use the TBL instruction.
12848 SDValue V1 = Op.getOperand(0);
12849 SDValue V2 = Op.getOperand(1);
12850 SDLoc DL(Op);
12851
12852 EVT EltVT = Op.getValueType().getVectorElementType();
12853 unsigned BytesPerElt = EltVT.getSizeInBits() / 8;
12854
12855 bool Swap = false;
12856 if (V1.isUndef() || isZerosVector(V1.getNode())) {
12857 std::swap(V1, V2);
12858 Swap = true;
12859 }
12860
12861 // If the V2 source is undef or zero then we can use a tbl1, as tbl1 will fill
12862 // out of range values with 0s. We do need to make sure that any out-of-range
12863 // values are really out-of-range for a v16i8 vector.
12864 bool IsUndefOrZero = V2.isUndef() || isZerosVector(V2.getNode());
12865 MVT IndexVT = MVT::v8i8;
12866 unsigned IndexLen = 8;
12867 if (Op.getValueSizeInBits() == 128) {
12868 IndexVT = MVT::v16i8;
12869 IndexLen = 16;
12870 }
12871
12872 SmallVector<SDValue, 8> TBLMask;
12873 for (int Val : ShuffleMask) {
12874 for (unsigned Byte = 0; Byte < BytesPerElt; ++Byte) {
12875 unsigned Offset = Byte + Val * BytesPerElt;
12876 if (Swap)
12877 Offset = Offset < IndexLen ? Offset + IndexLen : Offset - IndexLen;
12878 if (IsUndefOrZero && Offset >= IndexLen)
12879 Offset = 255;
12880 TBLMask.push_back(DAG.getConstant(Offset, DL, MVT::i32));
12881 }
12882 }
12883
12884 SDValue V1Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V1);
12885 SDValue V2Cst = DAG.getNode(ISD::BITCAST, DL, IndexVT, V2);
12886
12887 SDValue Shuffle;
12888 if (IsUndefOrZero) {
12889 if (IndexLen == 8)
12890 V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V1Cst);
12891 Shuffle = DAG.getNode(
12892 ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12893 DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
12894 DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12895 } else {
12896 if (IndexLen == 8) {
12897 V1Cst = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v16i8, V1Cst, V2Cst);
12898 Shuffle = DAG.getNode(
12899 ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12900 DAG.getConstant(Intrinsic::aarch64_neon_tbl1, DL, MVT::i32), V1Cst,
12901 DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12902 } else {
12903 // FIXME: We cannot, for the moment, emit a TBL2 instruction because we
12904 // cannot currently represent the register constraints on the input
12905 // table registers.
12906 // Shuffle = DAG.getNode(AArch64ISD::TBL2, DL, IndexVT, V1Cst, V2Cst,
12907 // DAG.getBuildVector(IndexVT, DL, &TBLMask[0],
12908 // IndexLen));
12909 Shuffle = DAG.getNode(
12910 ISD::INTRINSIC_WO_CHAIN, DL, IndexVT,
12911 DAG.getConstant(Intrinsic::aarch64_neon_tbl2, DL, MVT::i32), V1Cst,
12912 V2Cst,
12913 DAG.getBuildVector(IndexVT, DL, ArrayRef(TBLMask.data(), IndexLen)));
12914 }
12915 }
12916 return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
12917 }
12918
12919 static unsigned getDUPLANEOp(EVT EltType) {
12920 if (EltType == MVT::i8)
12921 return AArch64ISD::DUPLANE8;
12922 if (EltType == MVT::i16 || EltType == MVT::f16 || EltType == MVT::bf16)
12923 return AArch64ISD::DUPLANE16;
12924 if (EltType == MVT::i32 || EltType == MVT::f32)
12925 return AArch64ISD::DUPLANE32;
12926 if (EltType == MVT::i64 || EltType == MVT::f64)
12927 return AArch64ISD::DUPLANE64;
12928
12929 llvm_unreachable("Invalid vector element type?");
12930 }
12931
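// Build an AArch64ISD::DUPLANE* node (Opcode) of type VT that broadcasts lane
// Lane of V, first looking through bitcasted EXTRACT_SUBVECTOR, plain
// EXTRACT_SUBVECTOR and CONCAT_VECTORS operands so the duplication can read
// straight from the wider 128-bit source, with the lane index adjusted to
// match.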
12932 static SDValue constructDup(SDValue V, int Lane, SDLoc dl, EVT VT,
12933 unsigned Opcode, SelectionDAG &DAG) {
12934 // Try to eliminate a bitcasted extract subvector before a DUPLANE.
12935 auto getScaledOffsetDup = [](SDValue BitCast, int &LaneC, MVT &CastVT) {
12936 // Match: dup (bitcast (extract_subv X, C)), LaneC
12937 if (BitCast.getOpcode() != ISD::BITCAST ||
12938 BitCast.getOperand(0).getOpcode() != ISD::EXTRACT_SUBVECTOR)
12939 return false;
12940
12941 // The extract index must align in the destination type. That may not
12942 // happen if the bitcast is from narrow to wide type.
12943 SDValue Extract = BitCast.getOperand(0);
12944 unsigned ExtIdx = Extract.getConstantOperandVal(1);
12945 unsigned SrcEltBitWidth = Extract.getScalarValueSizeInBits();
12946 unsigned ExtIdxInBits = ExtIdx * SrcEltBitWidth;
12947 unsigned CastedEltBitWidth = BitCast.getScalarValueSizeInBits();
12948 if (ExtIdxInBits % CastedEltBitWidth != 0)
12949 return false;
12950
12951 // Can't handle cases where vector size is not 128-bit
12952 if (!Extract.getOperand(0).getValueType().is128BitVector())
12953 return false;
12954
12955 // Update the lane value by offsetting with the scaled extract index.
12956 LaneC += ExtIdxInBits / CastedEltBitWidth;
12957
12958 // Determine the casted vector type of the wide vector input.
12959 // dup (bitcast (extract_subv X, C)), LaneC --> dup (bitcast X), LaneC'
12960 // Examples:
12961 // dup (bitcast (extract_subv v2f64 X, 1) to v2f32), 1 --> dup v4f32 X, 3
12962 // dup (bitcast (extract_subv v16i8 X, 8) to v4i16), 1 --> dup v8i16 X, 5
12963 unsigned SrcVecNumElts =
12964 Extract.getOperand(0).getValueSizeInBits() / CastedEltBitWidth;
12965 CastVT = MVT::getVectorVT(BitCast.getSimpleValueType().getScalarType(),
12966 SrcVecNumElts);
12967 return true;
12968 };
12969 MVT CastVT;
12970 if (getScaledOffsetDup(V, Lane, CastVT)) {
12971 V = DAG.getBitcast(CastVT, V.getOperand(0).getOperand(0));
12972 } else if (V.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
12973 V.getOperand(0).getValueType().is128BitVector()) {
12974 // The lane is incremented by the index of the extract.
12975 // Example: dup v2f32 (extract v4f32 X, 2), 1 --> dup v4f32 X, 3
12976 Lane += V.getConstantOperandVal(1);
12977 V = V.getOperand(0);
12978 } else if (V.getOpcode() == ISD::CONCAT_VECTORS) {
12979 // The lane is decremented if we are splatting from the 2nd operand.
12980 // Example: dup v4i32 (concat v2i32 X, v2i32 Y), 3 --> dup v4i32 Y, 1
12981 unsigned Idx = Lane >= (int)VT.getVectorNumElements() / 2;
12982 Lane -= Idx * VT.getVectorNumElements() / 2;
12983 V = WidenVector(V.getOperand(Idx), DAG);
12984 } else if (VT.getSizeInBits() == 64) {
12985 // Widen the operand to 128-bit register with undef.
12986 V = WidenVector(V, DAG);
12987 }
12988 return DAG.getNode(Opcode, dl, VT, V, DAG.getConstant(Lane, dl, MVT::i64));
12989 }
12990
12991 // Return true if we can get a new shuffle mask by checking the parameter mask
12992 // array to test whether every two adjacent mask values are consecutive and
12993 // start from an even number.
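// Illustrative example: the v4i32 mask <2, 3, 0, 1> pairs up as (2,3) and
// (0,1); each pair is consecutive and starts on an even index, so NewMask
// becomes <1, 0> over the corresponding v2i64 type.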
12994 static bool isWideTypeMask(ArrayRef<int> M, EVT VT,
12995 SmallVectorImpl<int> &NewMask) {
12996 unsigned NumElts = VT.getVectorNumElements();
12997 if (NumElts % 2 != 0)
12998 return false;
12999
13000 NewMask.clear();
13001 for (unsigned i = 0; i < NumElts; i += 2) {
13002 int M0 = M[i];
13003 int M1 = M[i + 1];
13004
13005 // If both elements are undef, new mask is undef too.
13006 if (M0 == -1 && M1 == -1) {
13007 NewMask.push_back(-1);
13008 continue;
13009 }
13010
13011 if (M0 == -1 && M1 != -1 && (M1 % 2) == 1) {
13012 NewMask.push_back(M1 / 2);
13013 continue;
13014 }
13015
13016 if (M0 != -1 && (M0 % 2) == 0 && ((M0 + 1) == M1 || M1 == -1)) {
13017 NewMask.push_back(M0 / 2);
13018 continue;
13019 }
13020
13021 NewMask.clear();
13022 return false;
13023 }
13024
13025 assert(NewMask.size() == NumElts / 2 && "Incorrect size for mask!");
13026 return true;
13027 }
13028
13029 // Try to widen element type to get a new mask value for a better permutation
13030 // sequence, so that we can use NEON shuffle instructions, such as zip1/2,
13031 // UZP1/2, TRN1/2, REV, INS, etc.
13032 // For example:
13033 // shufflevector <4 x i32> %a, <4 x i32> %b,
13034 // <4 x i32> <i32 6, i32 7, i32 2, i32 3>
13035 // is equivalent to:
13036 // shufflevector <2 x i64> %a, <2 x i64> %b, <2 x i32> <i32 3, i32 1>
13037 // Finally, we can get:
13038 // mov v0.d[0], v1.d[1]
13039 static SDValue tryWidenMaskForShuffle(SDValue Op, SelectionDAG &DAG) {
13040 SDLoc DL(Op);
13041 EVT VT = Op.getValueType();
13042 EVT ScalarVT = VT.getVectorElementType();
13043 unsigned ElementSize = ScalarVT.getFixedSizeInBits();
13044 SDValue V0 = Op.getOperand(0);
13045 SDValue V1 = Op.getOperand(1);
13046 ArrayRef<int> Mask = cast<ShuffleVectorSDNode>(Op)->getMask();
13047
13048 // If combining adjacent elements, like two i16's -> i32, two i32's -> i64 ...
13049 // We need to make sure the wider element type is legal. Thus, ElementSize
13050 // should not be larger than 32 bits, and the i1 type should also be excluded.
13051 if (ElementSize > 32 || ElementSize == 1)
13052 return SDValue();
13053
13054 SmallVector<int, 8> NewMask;
13055 if (isWideTypeMask(Mask, VT, NewMask)) {
13056 MVT NewEltVT = VT.isFloatingPoint()
13057 ? MVT::getFloatingPointVT(ElementSize * 2)
13058 : MVT::getIntegerVT(ElementSize * 2);
13059 MVT NewVT = MVT::getVectorVT(NewEltVT, VT.getVectorNumElements() / 2);
13060 if (DAG.getTargetLoweringInfo().isTypeLegal(NewVT)) {
13061 V0 = DAG.getBitcast(NewVT, V0);
13062 V1 = DAG.getBitcast(NewVT, V1);
13063 return DAG.getBitcast(VT,
13064 DAG.getVectorShuffle(NewVT, DL, V0, V1, NewMask));
13065 }
13066 }
13067
13068 return SDValue();
13069 }
13070
13071 // Try to fold shuffle (tbl2, tbl2) into a single tbl4.
13072 static SDValue tryToConvertShuffleOfTbl2ToTbl4(SDValue Op,
13073 ArrayRef<int> ShuffleMask,
13074 SelectionDAG &DAG) {
13075 SDValue Tbl1 = Op->getOperand(0);
13076 SDValue Tbl2 = Op->getOperand(1);
13077 SDLoc dl(Op);
13078 SDValue Tbl2ID =
13079 DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl2, dl, MVT::i64);
13080
13081 EVT VT = Op.getValueType();
13082 if (Tbl1->getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13083 Tbl1->getOperand(0) != Tbl2ID ||
13084 Tbl2->getOpcode() != ISD::INTRINSIC_WO_CHAIN ||
13085 Tbl2->getOperand(0) != Tbl2ID)
13086 return SDValue();
13087
13088 if (Tbl1->getValueType(0) != MVT::v16i8 ||
13089 Tbl2->getValueType(0) != MVT::v16i8)
13090 return SDValue();
13091
13092 SDValue Mask1 = Tbl1->getOperand(3);
13093 SDValue Mask2 = Tbl2->getOperand(3);
13094 SmallVector<SDValue, 16> TBLMaskParts(16, SDValue());
13095 for (unsigned I = 0; I < 16; I++) {
13096 if (ShuffleMask[I] < 16)
13097 TBLMaskParts[I] = Mask1->getOperand(ShuffleMask[I]);
13098 else {
13099 auto *C =
13100 dyn_cast<ConstantSDNode>(Mask2->getOperand(ShuffleMask[I] - 16));
13101 if (!C)
13102 return SDValue();
13103 TBLMaskParts[I] = DAG.getConstant(C->getSExtValue() + 32, dl, MVT::i32);
13104 }
13105 }
13106
13107 SDValue TBLMask = DAG.getBuildVector(VT, dl, TBLMaskParts);
13108 SDValue ID =
13109 DAG.getTargetConstant(Intrinsic::aarch64_neon_tbl4, dl, MVT::i64);
13110
13111 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, MVT::v16i8,
13112 {ID, Tbl1->getOperand(1), Tbl1->getOperand(2),
13113 Tbl2->getOperand(1), Tbl2->getOperand(2), TBLMask});
13114 }
13115
13116 // Baseline legalization for ZERO_EXTEND_VECTOR_INREG will blend-in zeros,
13117 // but we don't have an appropriate instruction,
13118 // so custom-lower it as ZIP1-with-zeros.
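// Illustrative example (little-endian): zero-extending the low eight i8 lanes
// of a v16i8 into a v8i16 is done as ZIP1(src, zeros), which interleaves each
// source byte with a zero byte; reinterpreted as v8i16, every lane then holds
// the original byte zero-extended.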
13119 SDValue
13120 AArch64TargetLowering::LowerZERO_EXTEND_VECTOR_INREG(SDValue Op,
13121 SelectionDAG &DAG) const {
13122 SDLoc dl(Op);
13123 EVT VT = Op.getValueType();
13124 SDValue SrcOp = Op.getOperand(0);
13125 EVT SrcVT = SrcOp.getValueType();
13126 assert(VT.getScalarSizeInBits() % SrcVT.getScalarSizeInBits() == 0 &&
13127 "Unexpected extension factor.");
13128 unsigned Scale = VT.getScalarSizeInBits() / SrcVT.getScalarSizeInBits();
13129 // FIXME: support multi-step zipping?
13130 if (Scale != 2)
13131 return SDValue();
13132 SDValue Zeros = DAG.getConstant(0, dl, SrcVT);
13133 return DAG.getBitcast(VT,
13134 DAG.getNode(AArch64ISD::ZIP1, dl, SrcVT, SrcOp, Zeros));
13135 }
13136
13137 SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
13138 SelectionDAG &DAG) const {
13139 SDLoc dl(Op);
13140 EVT VT = Op.getValueType();
13141
13142 ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
13143
13144 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
13145 return LowerFixedLengthVECTOR_SHUFFLEToSVE(Op, DAG);
13146
13147 // Convert shuffles that are directly supported on NEON to target-specific
13148 // DAG nodes, instead of keeping them as shuffles and matching them again
13149 // during code selection. This is more efficient and avoids the possibility
13150 // of inconsistencies between legalization and selection.
13151 ArrayRef<int> ShuffleMask = SVN->getMask();
13152
13153 SDValue V1 = Op.getOperand(0);
13154 SDValue V2 = Op.getOperand(1);
13155
13156 assert(V1.getValueType() == VT && "Unexpected VECTOR_SHUFFLE type!");
13157 assert(ShuffleMask.size() == VT.getVectorNumElements() &&
13158 "Unexpected VECTOR_SHUFFLE mask size!");
13159
13160 if (SDValue Res = tryToConvertShuffleOfTbl2ToTbl4(Op, ShuffleMask, DAG))
13161 return Res;
13162
13163 if (SVN->isSplat()) {
13164 int Lane = SVN->getSplatIndex();
13165 // If this is an undef splat, generate it via "just" vdup, if possible.
13166 if (Lane == -1)
13167 Lane = 0;
13168
13169 if (Lane == 0 && V1.getOpcode() == ISD::SCALAR_TO_VECTOR)
13170 return DAG.getNode(AArch64ISD::DUP, dl, V1.getValueType(),
13171 V1.getOperand(0));
13172 // Test if V1 is a BUILD_VECTOR and the lane being referenced is a non-
13173 // constant. If so, we can just reference the lane's definition directly.
13174 if (V1.getOpcode() == ISD::BUILD_VECTOR &&
13175 !isa<ConstantSDNode>(V1.getOperand(Lane)))
13176 return DAG.getNode(AArch64ISD::DUP, dl, VT, V1.getOperand(Lane));
13177
13178 // Otherwise, duplicate from the lane of the input vector.
13179 unsigned Opcode = getDUPLANEOp(V1.getValueType().getVectorElementType());
13180 return constructDup(V1, Lane, dl, VT, Opcode, DAG);
13181 }
13182
13183 // Check if the mask matches a DUP for a wider element
13184 for (unsigned LaneSize : {64U, 32U, 16U}) {
13185 unsigned Lane = 0;
13186 if (isWideDUPMask(ShuffleMask, VT, LaneSize, Lane)) {
13187 unsigned Opcode = LaneSize == 64 ? AArch64ISD::DUPLANE64
13188 : LaneSize == 32 ? AArch64ISD::DUPLANE32
13189 : AArch64ISD::DUPLANE16;
13190 // Cast V1 to an integer vector with required lane size
13191 MVT NewEltTy = MVT::getIntegerVT(LaneSize);
13192 unsigned NewEltCount = VT.getSizeInBits() / LaneSize;
13193 MVT NewVecTy = MVT::getVectorVT(NewEltTy, NewEltCount);
13194 V1 = DAG.getBitcast(NewVecTy, V1);
13195 // Construct the DUP instruction
13196 V1 = constructDup(V1, Lane, dl, NewVecTy, Opcode, DAG);
13197 // Cast back to the original type
13198 return DAG.getBitcast(VT, V1);
13199 }
13200 }
13201
13202 unsigned NumElts = VT.getVectorNumElements();
13203 unsigned EltSize = VT.getScalarSizeInBits();
13204 if (isREVMask(ShuffleMask, EltSize, NumElts, 64))
13205 return DAG.getNode(AArch64ISD::REV64, dl, V1.getValueType(), V1, V2);
13206 if (isREVMask(ShuffleMask, EltSize, NumElts, 32))
13207 return DAG.getNode(AArch64ISD::REV32, dl, V1.getValueType(), V1, V2);
13208 if (isREVMask(ShuffleMask, EltSize, NumElts, 16))
13209 return DAG.getNode(AArch64ISD::REV16, dl, V1.getValueType(), V1, V2);
13210
13211 if (((NumElts == 8 && EltSize == 16) || (NumElts == 16 && EltSize == 8)) &&
13212 ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size())) {
13213 SDValue Rev = DAG.getNode(AArch64ISD::REV64, dl, VT, V1);
13214 return DAG.getNode(AArch64ISD::EXT, dl, VT, Rev, Rev,
13215 DAG.getConstant(8, dl, MVT::i32));
13216 }
13217
13218 bool ReverseEXT = false;
13219 unsigned Imm;
13220 if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm)) {
13221 if (ReverseEXT)
13222 std::swap(V1, V2);
13223 Imm *= getExtFactor(V1);
13224 return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V2,
13225 DAG.getConstant(Imm, dl, MVT::i32));
13226 } else if (V2->isUndef() && isSingletonEXTMask(ShuffleMask, VT, Imm)) {
13227 Imm *= getExtFactor(V1);
13228 return DAG.getNode(AArch64ISD::EXT, dl, V1.getValueType(), V1, V1,
13229 DAG.getConstant(Imm, dl, MVT::i32));
13230 }
13231
13232 unsigned WhichResult;
13233 if (isZIPMask(ShuffleMask, NumElts, WhichResult)) {
13234 unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
13235 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13236 }
13237 if (isUZPMask(ShuffleMask, NumElts, WhichResult)) {
13238 unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
13239 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13240 }
13241 if (isTRNMask(ShuffleMask, NumElts, WhichResult)) {
13242 unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
13243 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V2);
13244 }
13245
13246 if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13247 unsigned Opc = (WhichResult == 0) ? AArch64ISD::ZIP1 : AArch64ISD::ZIP2;
13248 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13249 }
13250 if (isUZP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13251 unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
13252 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13253 }
13254 if (isTRN_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
13255 unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
13256 return DAG.getNode(Opc, dl, V1.getValueType(), V1, V1);
13257 }
13258
13259 if (SDValue Concat = tryFormConcatFromShuffle(Op, DAG))
13260 return Concat;
13261
13262 bool DstIsLeft;
13263 int Anomaly;
13264 int NumInputElements = V1.getValueType().getVectorNumElements();
13265 if (isINSMask(ShuffleMask, NumInputElements, DstIsLeft, Anomaly)) {
13266 SDValue DstVec = DstIsLeft ? V1 : V2;
13267 SDValue DstLaneV = DAG.getConstant(Anomaly, dl, MVT::i64);
13268
13269 SDValue SrcVec = V1;
13270 int SrcLane = ShuffleMask[Anomaly];
13271 if (SrcLane >= NumInputElements) {
13272 SrcVec = V2;
13273 SrcLane -= NumElts;
13274 }
13275 SDValue SrcLaneV = DAG.getConstant(SrcLane, dl, MVT::i64);
13276
13277 EVT ScalarVT = VT.getVectorElementType();
13278
13279 if (ScalarVT.getFixedSizeInBits() < 32 && ScalarVT.isInteger())
13280 ScalarVT = MVT::i32;
13281
13282 return DAG.getNode(
13283 ISD::INSERT_VECTOR_ELT, dl, VT, DstVec,
13284 DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, SrcVec, SrcLaneV),
13285 DstLaneV);
13286 }
13287
13288 if (SDValue NewSD = tryWidenMaskForShuffle(Op, DAG))
13289 return NewSD;
13290
13291 // If the shuffle is not directly supported and it has 4 elements, use
13292 // the PerfectShuffle-generated table to synthesize it from other shuffles.
13293 if (NumElts == 4) {
13294 unsigned PFIndexes[4];
13295 for (unsigned i = 0; i != 4; ++i) {
13296 if (ShuffleMask[i] < 0)
13297 PFIndexes[i] = 8;
13298 else
13299 PFIndexes[i] = ShuffleMask[i];
13300 }
13301
13302 // Compute the index in the perfect shuffle table.
13303 unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
13304 PFIndexes[2] * 9 + PFIndexes[3];
13305 unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
13306 return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
13307 dl);
13308 }
13309
13310 return GenerateTBL(Op, ShuffleMask, DAG);
13311 }
13312
13313 SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
13314 SelectionDAG &DAG) const {
13315 EVT VT = Op.getValueType();
13316
13317 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
13318 return LowerToScalableOp(Op, DAG);
13319
13320 assert(VT.isScalableVector() && VT.getVectorElementType() == MVT::i1 &&
13321 "Unexpected vector type!");
13322
13323 // We can handle the constant cases during isel.
13324 if (isa<ConstantSDNode>(Op.getOperand(0)))
13325 return Op;
13326
13327 // There isn't a natural way to handle the general i1 case, so we use some
13328 // trickery with whilelo.
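// The idea (as implemented below): sign-extend the i1 splat value to i64,
// giving 0 for false and all-ones for true. whilelo(0, 0) then yields an
// all-false predicate, while whilelo(0, UINT64_MAX) yields an all-true one,
// so the resulting predicate is the original bit splatted across all lanes.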
13329 SDLoc DL(Op);
13330 SDValue SplatVal = DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, MVT::i64);
13331 SplatVal = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, MVT::i64, SplatVal,
13332 DAG.getValueType(MVT::i1));
13333 SDValue ID =
13334 DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo, DL, MVT::i64);
13335 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
13336 if (VT == MVT::nxv1i1)
13337 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::nxv1i1,
13338 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::nxv2i1, ID,
13339 Zero, SplatVal),
13340 Zero);
13341 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT, ID, Zero, SplatVal);
13342 }
13343
13344 SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
13345 SelectionDAG &DAG) const {
13346 SDLoc DL(Op);
13347
13348 EVT VT = Op.getValueType();
13349 if (!isTypeLegal(VT) || !VT.isScalableVector())
13350 return SDValue();
13351
13352 // Current lowering only supports the SVE-ACLE types.
13353 if (VT.getSizeInBits().getKnownMinValue() != AArch64::SVEBitsPerBlock)
13354 return SDValue();
13355
13356 // The DUPQ operation is independent of element type, so normalise to i64s.
13357 SDValue Idx128 = Op.getOperand(2);
13358
13359 // DUPQ can be used when idx is in range.
13360 auto *CIdx = dyn_cast<ConstantSDNode>(Idx128);
13361 if (CIdx && (CIdx->getZExtValue() <= 3)) {
13362 SDValue CI = DAG.getTargetConstant(CIdx->getZExtValue(), DL, MVT::i64);
13363 return DAG.getNode(AArch64ISD::DUPLANE128, DL, VT, Op.getOperand(1), CI);
13364 }
13365
13366 SDValue V = DAG.getNode(ISD::BITCAST, DL, MVT::nxv2i64, Op.getOperand(1));
13367
13368 // The ACLE says this must produce the same result as:
13369 // svtbl(data, svadd_x(svptrue_b64(),
13370 // svand_x(svptrue_b64(), svindex_u64(0, 1), 1),
13371 // index * 2))
13372 SDValue One = DAG.getConstant(1, DL, MVT::i64);
13373 SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One);
13374
13375 // create the vector 0,1,0,1,...
13376 SDValue SV = DAG.getStepVector(DL, MVT::nxv2i64);
13377 SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne);
13378
13379 // create the vector idx64,idx64+1,idx64,idx64+1,...
13380 SDValue Idx64 = DAG.getNode(ISD::ADD, DL, MVT::i64, Idx128, Idx128);
13381 SDValue SplatIdx64 = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Idx64);
13382 SDValue ShuffleMask = DAG.getNode(ISD::ADD, DL, MVT::nxv2i64, SV, SplatIdx64);
13383
13384 // create the vector Val[idx64],Val[idx64+1],Val[idx64],Val[idx64+1],...
13385 SDValue TBL = DAG.getNode(AArch64ISD::TBL, DL, MVT::nxv2i64, V, ShuffleMask);
13386 return DAG.getNode(ISD::BITCAST, DL, VT, TBL);
13387 }
13388
13389
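// Collect the repeated constant bits of a splat BUILD_VECTOR into CnstBits
// (widened to the full vector width), along with the corresponding
// undefined-bit information in UndefBits. Returns false if BVN is not a
// constant splat.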
13390 static bool resolveBuildVector(BuildVectorSDNode *BVN, APInt &CnstBits,
13391 APInt &UndefBits) {
13392 EVT VT = BVN->getValueType(0);
13393 APInt SplatBits, SplatUndef;
13394 unsigned SplatBitSize;
13395 bool HasAnyUndefs;
13396 if (BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize, HasAnyUndefs)) {
13397 unsigned NumSplats = VT.getSizeInBits() / SplatBitSize;
13398
13399 for (unsigned i = 0; i < NumSplats; ++i) {
13400 CnstBits <<= SplatBitSize;
13401 UndefBits <<= SplatBitSize;
13402 CnstBits |= SplatBits.zextOrTrunc(VT.getSizeInBits());
13403 UndefBits |= (SplatBits ^ SplatUndef).zextOrTrunc(VT.getSizeInBits());
13404 }
13405
13406 return true;
13407 }
13408
13409 return false;
13410 }
13411
13412 // Try 64-bit splatted SIMD immediate.
13413 static SDValue tryAdvSIMDModImm64(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13414 const APInt &Bits) {
13415 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13416 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13417 EVT VT = Op.getValueType();
13418 MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v2i64 : MVT::f64;
13419
13420 if (AArch64_AM::isAdvSIMDModImmType10(Value)) {
13421 Value = AArch64_AM::encodeAdvSIMDModImmType10(Value);
13422
13423 SDLoc dl(Op);
13424 SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13425 DAG.getConstant(Value, dl, MVT::i32));
13426 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13427 }
13428 }
13429
13430 return SDValue();
13431 }
13432
13433 // Try 32-bit splatted SIMD immediate.
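// Illustrative example: a v4i32 splat of 0x00002A00 matches ModImmType2
// (Value 0x2a, Shift 8), so for a MOVI-style NewOp it can be materialized by
// a single move of a shifted 8-bit immediate.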
13434 static SDValue tryAdvSIMDModImm32(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13435 const APInt &Bits,
13436 const SDValue *LHS = nullptr) {
13437 EVT VT = Op.getValueType();
13438 if (VT.isFixedLengthVector() &&
13439 !DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable())
13440 return SDValue();
13441
13442 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13443 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13444 MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
13445 bool isAdvSIMDModImm = false;
13446 uint64_t Shift;
13447
13448 if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType1(Value))) {
13449 Value = AArch64_AM::encodeAdvSIMDModImmType1(Value);
13450 Shift = 0;
13451 }
13452 else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType2(Value))) {
13453 Value = AArch64_AM::encodeAdvSIMDModImmType2(Value);
13454 Shift = 8;
13455 }
13456 else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType3(Value))) {
13457 Value = AArch64_AM::encodeAdvSIMDModImmType3(Value);
13458 Shift = 16;
13459 }
13460 else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType4(Value))) {
13461 Value = AArch64_AM::encodeAdvSIMDModImmType4(Value);
13462 Shift = 24;
13463 }
13464
13465 if (isAdvSIMDModImm) {
13466 SDLoc dl(Op);
13467 SDValue Mov;
13468
13469 if (LHS)
13470 Mov = DAG.getNode(NewOp, dl, MovTy,
13471 DAG.getNode(AArch64ISD::NVCAST, dl, MovTy, *LHS),
13472 DAG.getConstant(Value, dl, MVT::i32),
13473 DAG.getConstant(Shift, dl, MVT::i32));
13474 else
13475 Mov = DAG.getNode(NewOp, dl, MovTy,
13476 DAG.getConstant(Value, dl, MVT::i32),
13477 DAG.getConstant(Shift, dl, MVT::i32));
13478
13479 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13480 }
13481 }
13482
13483 return SDValue();
13484 }
13485
13486 // Try 16-bit splatted SIMD immediate.
13487 static SDValue tryAdvSIMDModImm16(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13488 const APInt &Bits,
13489 const SDValue *LHS = nullptr) {
13490 EVT VT = Op.getValueType();
13491 if (VT.isFixedLengthVector() &&
13492 !DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable())
13493 return SDValue();
13494
13495 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13496 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13497 MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v8i16 : MVT::v4i16;
13498 bool isAdvSIMDModImm = false;
13499 uint64_t Shift;
13500
13501 if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType5(Value))) {
13502 Value = AArch64_AM::encodeAdvSIMDModImmType5(Value);
13503 Shift = 0;
13504 }
13505 else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType6(Value))) {
13506 Value = AArch64_AM::encodeAdvSIMDModImmType6(Value);
13507 Shift = 8;
13508 }
13509
13510 if (isAdvSIMDModImm) {
13511 SDLoc dl(Op);
13512 SDValue Mov;
13513
13514 if (LHS)
13515 Mov = DAG.getNode(NewOp, dl, MovTy,
13516 DAG.getNode(AArch64ISD::NVCAST, dl, MovTy, *LHS),
13517 DAG.getConstant(Value, dl, MVT::i32),
13518 DAG.getConstant(Shift, dl, MVT::i32));
13519 else
13520 Mov = DAG.getNode(NewOp, dl, MovTy,
13521 DAG.getConstant(Value, dl, MVT::i32),
13522 DAG.getConstant(Shift, dl, MVT::i32));
13523
13524 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13525 }
13526 }
13527
13528 return SDValue();
13529 }
13530
13531 // Try 32-bit splatted SIMD immediate with shifted ones.
13532 static SDValue tryAdvSIMDModImm321s(unsigned NewOp, SDValue Op,
13533 SelectionDAG &DAG, const APInt &Bits) {
13534 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13535 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13536 EVT VT = Op.getValueType();
13537 MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v4i32 : MVT::v2i32;
13538 bool isAdvSIMDModImm = false;
13539 uint64_t Shift;
13540
13541 if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType7(Value))) {
13542 Value = AArch64_AM::encodeAdvSIMDModImmType7(Value);
13543 Shift = 264;
13544 }
13545 else if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType8(Value))) {
13546 Value = AArch64_AM::encodeAdvSIMDModImmType8(Value);
13547 Shift = 272;
13548 }
13549
13550 if (isAdvSIMDModImm) {
13551 SDLoc dl(Op);
13552 SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13553 DAG.getConstant(Value, dl, MVT::i32),
13554 DAG.getConstant(Shift, dl, MVT::i32));
13555 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13556 }
13557 }
13558
13559 return SDValue();
13560 }
13561
13562 // Try 8-bit splatted SIMD immediate.
13563 static SDValue tryAdvSIMDModImm8(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13564 const APInt &Bits) {
13565 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13566 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13567 EVT VT = Op.getValueType();
13568 MVT MovTy = (VT.getSizeInBits() == 128) ? MVT::v16i8 : MVT::v8i8;
13569
13570 if (AArch64_AM::isAdvSIMDModImmType9(Value)) {
13571 Value = AArch64_AM::encodeAdvSIMDModImmType9(Value);
13572
13573 SDLoc dl(Op);
13574 SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13575 DAG.getConstant(Value, dl, MVT::i32));
13576 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13577 }
13578 }
13579
13580 return SDValue();
13581 }
13582
13583 // Try FP splatted SIMD immediate.
13584 static SDValue tryAdvSIMDModImmFP(unsigned NewOp, SDValue Op, SelectionDAG &DAG,
13585 const APInt &Bits) {
13586 if (Bits.getHiBits(64) == Bits.getLoBits(64)) {
13587 uint64_t Value = Bits.zextOrTrunc(64).getZExtValue();
13588 EVT VT = Op.getValueType();
13589 bool isWide = (VT.getSizeInBits() == 128);
13590 MVT MovTy;
13591 bool isAdvSIMDModImm = false;
13592
13593 if ((isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType11(Value))) {
13594 Value = AArch64_AM::encodeAdvSIMDModImmType11(Value);
13595 MovTy = isWide ? MVT::v4f32 : MVT::v2f32;
13596 }
13597 else if (isWide &&
13598 (isAdvSIMDModImm = AArch64_AM::isAdvSIMDModImmType12(Value))) {
13599 Value = AArch64_AM::encodeAdvSIMDModImmType12(Value);
13600 MovTy = MVT::v2f64;
13601 }
13602
13603 if (isAdvSIMDModImm) {
13604 SDLoc dl(Op);
13605 SDValue Mov = DAG.getNode(NewOp, dl, MovTy,
13606 DAG.getConstant(Value, dl, MVT::i32));
13607 return DAG.getNode(AArch64ISD::NVCAST, dl, VT, Mov);
13608 }
13609 }
13610
13611 return SDValue();
13612 }
13613
13614 // Specialized code to quickly find if PotentialBVec is a BuildVector that
13615 // consists of only the same constant int value, returned in reference arg
13616 // ConstVal.
13617 static bool isAllConstantBuildVector(const SDValue &PotentialBVec,
13618 uint64_t &ConstVal) {
13619 BuildVectorSDNode *Bvec = dyn_cast<BuildVectorSDNode>(PotentialBVec);
13620 if (!Bvec)
13621 return false;
13622 ConstantSDNode *FirstElt = dyn_cast<ConstantSDNode>(Bvec->getOperand(0));
13623 if (!FirstElt)
13624 return false;
13625 EVT VT = Bvec->getValueType(0);
13626 unsigned NumElts = VT.getVectorNumElements();
13627 for (unsigned i = 1; i < NumElts; ++i)
13628 if (dyn_cast<ConstantSDNode>(Bvec->getOperand(i)) != FirstElt)
13629 return false;
13630 ConstVal = FirstElt->getZExtValue();
13631 return true;
13632 }
13633
13634 static bool isAllInactivePredicate(SDValue N) {
13635 // Look through cast.
13636 while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST)
13637 N = N.getOperand(0);
13638
13639 return ISD::isConstantSplatVectorAllZeros(N.getNode());
13640 }
13641
13642 static bool isAllActivePredicate(SelectionDAG &DAG, SDValue N) {
13643 unsigned NumElts = N.getValueType().getVectorMinNumElements();
13644
13645 // Look through cast.
13646 while (N.getOpcode() == AArch64ISD::REINTERPRET_CAST) {
13647 N = N.getOperand(0);
13648 // When reinterpreting from a type with fewer elements the "new" elements
13649 // are not active, so bail if they're likely to be used.
13650 if (N.getValueType().getVectorMinNumElements() < NumElts)
13651 return false;
13652 }
13653
13654 if (ISD::isConstantSplatVectorAllOnes(N.getNode()))
13655 return true;
13656
13657 // "ptrue p.<ty>, all" can be considered all active when <ty> is the same size
13658 // or smaller than the implicit element type represented by N.
13659 // NOTE: A larger element count implies a smaller element type.
13660 if (N.getOpcode() == AArch64ISD::PTRUE &&
13661 N.getConstantOperandVal(0) == AArch64SVEPredPattern::all)
13662 return N.getValueType().getVectorMinNumElements() >= NumElts;
13663
13664 // If we're compiling for a specific vector-length, we can check if the
13665 // pattern's VL equals that of the scalable vector at runtime.
13666 if (N.getOpcode() == AArch64ISD::PTRUE) {
13667 const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
13668 unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
13669 unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
13670 if (MaxSVESize && MinSVESize == MaxSVESize) {
13671 unsigned VScale = MaxSVESize / AArch64::SVEBitsPerBlock;
13672 unsigned PatNumElts =
13673 getNumElementsFromSVEPredPattern(N.getConstantOperandVal(0));
13674 return PatNumElts == (NumElts * VScale);
13675 }
13676 }
13677
13678 return false;
13679 }
13680
13681 // Attempt to form a vector S[LR]I from (or (and X, BvecC1), (lsl Y, C2)),
13682 // to (SLI X, Y, C2), where X and Y have matching vector types, BvecC1 is a
13683 // BUILD_VECTOR with constant element C1, C2 is a constant, and:
13684 // - for the SLI case: C1 == ~(Ones(ElemSizeInBits) << C2)
13685 // - for the SRI case: C1 == ~(Ones(ElemSizeInBits) >> C2)
13686 // The (or (lsl Y, C2), (and X, BvecC1)) case is also handled.
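// Illustrative example for v4i32: (or (and X, splat 0x000000FF),
// (AArch64ISD::VSHL Y, #8)) becomes (AArch64ISD::VSLI X, Y, #8); the AND mask
// keeps exactly the low 8 bits of X that the left shift of Y leaves as zero,
// so SLI can insert the shifted Y bits while preserving those X bits.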
13687 static SDValue tryLowerToSLI(SDNode *N, SelectionDAG &DAG) {
13688 EVT VT = N->getValueType(0);
13689
13690 if (!VT.isVector())
13691 return SDValue();
13692
13693 SDLoc DL(N);
13694
13695 SDValue And;
13696 SDValue Shift;
13697
13698 SDValue FirstOp = N->getOperand(0);
13699 unsigned FirstOpc = FirstOp.getOpcode();
13700 SDValue SecondOp = N->getOperand(1);
13701 unsigned SecondOpc = SecondOp.getOpcode();
13702
13703 // Is one of the operands an AND or a BICi? The AND may have been optimised to
13704 // a BICi in order to use an immediate instead of a register.
13705 // Is the other operand a shl or lshr? This will have been turned into:
13706 // AArch64ISD::VSHL vector, #shift or AArch64ISD::VLSHR vector, #shift
13707 // or (AArch64ISD::SHL_PRED || AArch64ISD::SRL_PRED) mask, vector, #shiftVec.
13708 if ((FirstOpc == ISD::AND || FirstOpc == AArch64ISD::BICi) &&
13709 (SecondOpc == AArch64ISD::VSHL || SecondOpc == AArch64ISD::VLSHR ||
13710 SecondOpc == AArch64ISD::SHL_PRED ||
13711 SecondOpc == AArch64ISD::SRL_PRED)) {
13712 And = FirstOp;
13713 Shift = SecondOp;
13714
13715 } else if ((SecondOpc == ISD::AND || SecondOpc == AArch64ISD::BICi) &&
13716 (FirstOpc == AArch64ISD::VSHL || FirstOpc == AArch64ISD::VLSHR ||
13717 FirstOpc == AArch64ISD::SHL_PRED ||
13718 FirstOpc == AArch64ISD::SRL_PRED)) {
13719 And = SecondOp;
13720 Shift = FirstOp;
13721 } else
13722 return SDValue();
13723
13724 bool IsAnd = And.getOpcode() == ISD::AND;
13725 bool IsShiftRight = Shift.getOpcode() == AArch64ISD::VLSHR ||
13726 Shift.getOpcode() == AArch64ISD::SRL_PRED;
13727 bool ShiftHasPredOp = Shift.getOpcode() == AArch64ISD::SHL_PRED ||
13728 Shift.getOpcode() == AArch64ISD::SRL_PRED;
13729
13730 // Is the shift amount constant and are all lanes active?
13731 uint64_t C2;
13732 if (ShiftHasPredOp) {
13733 if (!isAllActivePredicate(DAG, Shift.getOperand(0)))
13734 return SDValue();
13735 APInt C;
13736 if (!ISD::isConstantSplatVector(Shift.getOperand(2).getNode(), C))
13737 return SDValue();
13738 C2 = C.getZExtValue();
13739 } else if (ConstantSDNode *C2node =
13740 dyn_cast<ConstantSDNode>(Shift.getOperand(1)))
13741 C2 = C2node->getZExtValue();
13742 else
13743 return SDValue();
13744
13745 APInt C1AsAPInt;
13746 unsigned ElemSizeInBits = VT.getScalarSizeInBits();
13747 if (IsAnd) {
13748 // Is the and mask vector all constant?
13749 if (!ISD::isConstantSplatVector(And.getOperand(1).getNode(), C1AsAPInt))
13750 return SDValue();
13751 } else {
13752 // Reconstruct the corresponding AND immediate from the two BICi immediates.
13753 ConstantSDNode *C1nodeImm = dyn_cast<ConstantSDNode>(And.getOperand(1));
13754 ConstantSDNode *C1nodeShift = dyn_cast<ConstantSDNode>(And.getOperand(2));
13755 assert(C1nodeImm && C1nodeShift);
13756 C1AsAPInt = ~(C1nodeImm->getAPIntValue() << C1nodeShift->getAPIntValue());
13757 C1AsAPInt = C1AsAPInt.zextOrTrunc(ElemSizeInBits);
13758 }
13759
13760 // Is C1 == ~(Ones(ElemSizeInBits) << C2) or
13761 // C1 == ~(Ones(ElemSizeInBits) >> C2), taking into account
13762 // how much one can shift elements of a particular size?
13763 if (C2 > ElemSizeInBits)
13764 return SDValue();
13765
13766 APInt RequiredC1 = IsShiftRight ? APInt::getHighBitsSet(ElemSizeInBits, C2)
13767 : APInt::getLowBitsSet(ElemSizeInBits, C2);
13768 if (C1AsAPInt != RequiredC1)
13769 return SDValue();
13770
13771 SDValue X = And.getOperand(0);
13772 SDValue Y = ShiftHasPredOp ? Shift.getOperand(1) : Shift.getOperand(0);
13773 SDValue Imm = ShiftHasPredOp ? DAG.getTargetConstant(C2, DL, MVT::i32)
13774 : Shift.getOperand(1);
13775
13776 unsigned Inst = IsShiftRight ? AArch64ISD::VSRI : AArch64ISD::VSLI;
13777 SDValue ResultSLI = DAG.getNode(Inst, DL, VT, X, Y, Imm);
13778
13779 LLVM_DEBUG(dbgs() << "aarch64-lower: transformed: \n");
13780 LLVM_DEBUG(N->dump(&DAG));
13781 LLVM_DEBUG(dbgs() << "into: \n");
13782 LLVM_DEBUG(ResultSLI->dump(&DAG));
13783
13784 ++NumShiftInserts;
13785 return ResultSLI;
13786 }
13787
SDValue AArch64TargetLowering::LowerVectorOR(SDValue Op,
                                             SelectionDAG &DAG) const {
13790 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
13791 !Subtarget->isNeonAvailable()))
13792 return LowerToScalableOp(Op, DAG);
13793
13794 // Attempt to form a vector S[LR]I from (or (and X, C1), (lsl Y, C2))
13795 if (SDValue Res = tryLowerToSLI(Op.getNode(), DAG))
13796 return Res;
13797
13798 EVT VT = Op.getValueType();
13799 if (VT.isScalableVector())
13800 return Op;
13801
13802 SDValue LHS = Op.getOperand(0);
13803 BuildVectorSDNode *BVN =
13804 dyn_cast<BuildVectorSDNode>(Op.getOperand(1).getNode());
13805 if (!BVN) {
13806 // OR commutes, so try swapping the operands.
13807 LHS = Op.getOperand(1);
13808 BVN = dyn_cast<BuildVectorSDNode>(Op.getOperand(0).getNode());
13809 }
13810 if (!BVN)
13811 return Op;
13812
13813 APInt DefBits(VT.getSizeInBits(), 0);
13814 APInt UndefBits(VT.getSizeInBits(), 0);
13815 if (resolveBuildVector(BVN, DefBits, UndefBits)) {
13816 SDValue NewOp;
13817
13818 if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
13819 DefBits, &LHS)) ||
13820 (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
13821 DefBits, &LHS)))
13822 return NewOp;
13823
13824 if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::ORRi, Op, DAG,
13825 UndefBits, &LHS)) ||
13826 (NewOp = tryAdvSIMDModImm16(AArch64ISD::ORRi, Op, DAG,
13827 UndefBits, &LHS)))
13828 return NewOp;
13829 }
13830
13831 // We can always fall back to a non-immediate OR.
13832 return Op;
13833 }
13834
13835 // Normalize the operands of BUILD_VECTOR. The value of constant operands will
13836 // be truncated to fit element width.
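// For example (illustrative): in a v8i8 BUILD_VECTOR whose lanes were promoted
// to i32 constants, only the low 8 bits of each constant are kept, so a lane
// value of 0x1FF is normalized to 0xFF (still held as an i32 operand).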
static SDValue NormalizeBuildVector(SDValue Op,
                                    SelectionDAG &DAG) {
13839 assert(Op.getOpcode() == ISD::BUILD_VECTOR && "Unknown opcode!");
13840 SDLoc dl(Op);
13841 EVT VT = Op.getValueType();
  EVT EltTy = VT.getVectorElementType();
13843
13844 if (EltTy.isFloatingPoint() || EltTy.getSizeInBits() > 16)
13845 return Op;
13846
13847 SmallVector<SDValue, 16> Ops;
13848 for (SDValue Lane : Op->ops()) {
13849 // For integer vectors, type legalization would have promoted the
13850 // operands already. Otherwise, if Op is a floating-point splat
13851 // (with operands cast to integers), then the only possibilities
13852 // are constants and UNDEFs.
13853 if (auto *CstLane = dyn_cast<ConstantSDNode>(Lane)) {
13854 APInt LowBits(EltTy.getSizeInBits(),
13855 CstLane->getZExtValue());
13856 Lane = DAG.getConstant(LowBits.getZExtValue(), dl, MVT::i32);
13857 } else if (Lane.getNode()->isUndef()) {
13858 Lane = DAG.getUNDEF(MVT::i32);
13859 } else {
13860 assert(Lane.getValueType() == MVT::i32 &&
13861 "Unexpected BUILD_VECTOR operand type");
13862 }
13863 Ops.push_back(Lane);
13864 }
13865 return DAG.getBuildVector(VT, dl, Ops);
13866 }
13867
static SDValue ConstantBuildVector(SDValue Op, SelectionDAG &DAG,
                                   const AArch64Subtarget *ST) {
13870 EVT VT = Op.getValueType();
13871 assert((VT.getSizeInBits() == 64 || VT.getSizeInBits() == 128) &&
13872 "Expected a legal NEON vector");
13873
13874 APInt DefBits(VT.getSizeInBits(), 0);
13875 APInt UndefBits(VT.getSizeInBits(), 0);
13876 BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
13877 if (resolveBuildVector(BVN, DefBits, UndefBits)) {
13878 auto TryMOVIWithBits = [&](APInt DefBits) {
13879 SDValue NewOp;
13880 if ((NewOp =
13881 tryAdvSIMDModImm64(AArch64ISD::MOVIedit, Op, DAG, DefBits)) ||
13882 (NewOp =
13883 tryAdvSIMDModImm32(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
13884 (NewOp =
13885 tryAdvSIMDModImm321s(AArch64ISD::MOVImsl, Op, DAG, DefBits)) ||
13886 (NewOp =
13887 tryAdvSIMDModImm16(AArch64ISD::MOVIshift, Op, DAG, DefBits)) ||
13888 (NewOp = tryAdvSIMDModImm8(AArch64ISD::MOVI, Op, DAG, DefBits)) ||
13889 (NewOp = tryAdvSIMDModImmFP(AArch64ISD::FMOV, Op, DAG, DefBits)))
13890 return NewOp;
13891
13892 APInt NotDefBits = ~DefBits;
13893 if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::MVNIshift, Op, DAG,
13894 NotDefBits)) ||
13895 (NewOp = tryAdvSIMDModImm321s(AArch64ISD::MVNImsl, Op, DAG,
13896 NotDefBits)) ||
13897 (NewOp =
13898 tryAdvSIMDModImm16(AArch64ISD::MVNIshift, Op, DAG, NotDefBits)))
13899 return NewOp;
13900 return SDValue();
13901 };
13902 if (SDValue R = TryMOVIWithBits(DefBits))
13903 return R;
13904 if (SDValue R = TryMOVIWithBits(UndefBits))
13905 return R;
13906
13907 // See if a fneg of the constant can be materialized with a MOVI, etc
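    // For example (illustrative): a v2f64 splat of -0.0 has only the per-lane
    // sign bits set; flipping those sign bits yields all-zero bits, which MOVI
    // can materialize, and the fneg then restores the original constant.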
13908 auto TryWithFNeg = [&](APInt DefBits, MVT FVT) {
13909 // FNegate each sub-element of the constant
13910 assert(VT.getSizeInBits() % FVT.getScalarSizeInBits() == 0);
13911 APInt Neg = APInt::getHighBitsSet(FVT.getSizeInBits(), 1)
13912 .zext(VT.getSizeInBits());
13913 APInt NegBits(VT.getSizeInBits(), 0);
13914 unsigned NumElts = VT.getSizeInBits() / FVT.getScalarSizeInBits();
13915 for (unsigned i = 0; i < NumElts; i++)
13916 NegBits |= Neg << (FVT.getScalarSizeInBits() * i);
13917 NegBits = DefBits ^ NegBits;
13918
13919 // Try to create the new constants with MOVI, and if so generate a fneg
13920 // for it.
13921 if (SDValue NewOp = TryMOVIWithBits(NegBits)) {
13922 SDLoc DL(Op);
13923 MVT VFVT = NumElts == 1 ? FVT : MVT::getVectorVT(FVT, NumElts);
13924 return DAG.getNode(
13925 AArch64ISD::NVCAST, DL, VT,
13926 DAG.getNode(ISD::FNEG, DL, VFVT,
13927 DAG.getNode(AArch64ISD::NVCAST, DL, VFVT, NewOp)));
13928 }
13929 return SDValue();
13930 };
13931 SDValue R;
13932 if ((R = TryWithFNeg(DefBits, MVT::f32)) ||
13933 (R = TryWithFNeg(DefBits, MVT::f64)) ||
13934 (ST->hasFullFP16() && (R = TryWithFNeg(DefBits, MVT::f16))))
13935 return R;
13936 }
13937
13938 return SDValue();
13939 }
13940
SDValue AArch64TargetLowering::LowerBUILD_VECTOR(SDValue Op,
                                                 SelectionDAG &DAG) const {
13943 EVT VT = Op.getValueType();
13944
13945 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
13946 if (auto SeqInfo = cast<BuildVectorSDNode>(Op)->isConstantSequence()) {
13947 SDLoc DL(Op);
13948 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
13949 SDValue Start = DAG.getConstant(SeqInfo->first, DL, ContainerVT);
13950 SDValue Steps = DAG.getStepVector(DL, ContainerVT, SeqInfo->second);
13951 SDValue Seq = DAG.getNode(ISD::ADD, DL, ContainerVT, Start, Steps);
13952 return convertFromScalableVector(DAG, Op.getValueType(), Seq);
13953 }
13954
13955 // Revert to common legalisation for all other variants.
13956 return SDValue();
13957 }
13958
13959 // Try to build a simple constant vector.
13960 Op = NormalizeBuildVector(Op, DAG);
  // Though this might return a non-BUILD_VECTOR (e.g. CONCAT_VECTORS); if so,
  // abort.
13963 if (Op.getOpcode() != ISD::BUILD_VECTOR)
13964 return SDValue();
13965
13966 // Certain vector constants, used to express things like logical NOT and
13967 // arithmetic NEG, are passed through unmodified. This allows special
13968 // patterns for these operations to match, which will lower these constants
13969 // to whatever is proven necessary.
13970 BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(Op.getNode());
13971 if (BVN->isConstant()) {
13972 if (ConstantSDNode *Const = BVN->getConstantSplatNode()) {
13973 unsigned BitSize = VT.getVectorElementType().getSizeInBits();
13974 APInt Val(BitSize,
13975 Const->getAPIntValue().zextOrTrunc(BitSize).getZExtValue());
13976 if (Val.isZero() || (VT.isInteger() && Val.isAllOnes()))
13977 return Op;
13978 }
13979 if (ConstantFPSDNode *Const = BVN->getConstantFPSplatNode())
13980 if (Const->isZero() && !Const->isNegative())
13981 return Op;
13982 }
13983
13984 if (SDValue V = ConstantBuildVector(Op, DAG, Subtarget))
13985 return V;
13986
13987 // Scan through the operands to find some interesting properties we can
13988 // exploit:
13989 // 1) If only one value is used, we can use a DUP, or
13990 // 2) if only the low element is not undef, we can just insert that, or
13991 // 3) if only one constant value is used (w/ some non-constant lanes),
13992 // we can splat the constant value into the whole vector then fill
13993 // in the non-constant lanes.
13994 // 4) FIXME: If different constant values are used, but we can intelligently
13995 // select the values we'll be overwriting for the non-constant
13996 // lanes such that we can directly materialize the vector
13997 // some other way (MOVI, e.g.), we can be sneaky.
13998 // 5) if all operands are EXTRACT_VECTOR_ELT, check for VUZP.
13999 SDLoc dl(Op);
14000 unsigned NumElts = VT.getVectorNumElements();
14001 bool isOnlyLowElement = true;
14002 bool usesOnlyOneValue = true;
14003 bool usesOnlyOneConstantValue = true;
14004 bool isConstant = true;
14005 bool AllLanesExtractElt = true;
14006 unsigned NumConstantLanes = 0;
14007 unsigned NumDifferentLanes = 0;
14008 unsigned NumUndefLanes = 0;
14009 SDValue Value;
14010 SDValue ConstantValue;
14011 SmallMapVector<SDValue, unsigned, 16> DifferentValueMap;
14012 unsigned ConsecutiveValCount = 0;
14013 SDValue PrevVal;
14014 for (unsigned i = 0; i < NumElts; ++i) {
14015 SDValue V = Op.getOperand(i);
14016 if (V.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
14017 AllLanesExtractElt = false;
14018 if (V.isUndef()) {
14019 ++NumUndefLanes;
14020 continue;
14021 }
14022 if (i > 0)
14023 isOnlyLowElement = false;
14024 if (!isIntOrFPConstant(V))
14025 isConstant = false;
14026
14027 if (isIntOrFPConstant(V)) {
14028 ++NumConstantLanes;
14029 if (!ConstantValue.getNode())
14030 ConstantValue = V;
14031 else if (ConstantValue != V)
14032 usesOnlyOneConstantValue = false;
14033 }
14034
14035 if (!Value.getNode())
14036 Value = V;
14037 else if (V != Value) {
14038 usesOnlyOneValue = false;
14039 ++NumDifferentLanes;
14040 }
14041
14042 if (PrevVal != V) {
14043 ConsecutiveValCount = 0;
14044 PrevVal = V;
14045 }
14046
    // Keep each different value and its last consecutive count. For example,
14048 //
14049 // t22: v16i8 = build_vector t23, t23, t23, t23, t23, t23, t23, t23,
14050 // t24, t24, t24, t24, t24, t24, t24, t24
14051 // t23 = consecutive count 8
14052 // t24 = consecutive count 8
14053 // ------------------------------------------------------------------
14054 // t22: v16i8 = build_vector t24, t24, t23, t23, t23, t23, t23, t24,
14055 // t24, t24, t24, t24, t24, t24, t24, t24
14056 // t23 = consecutive count 5
14057 // t24 = consecutive count 9
14058 DifferentValueMap[V] = ++ConsecutiveValCount;
14059 }
14060
14061 if (!Value.getNode()) {
14062 LLVM_DEBUG(
14063 dbgs() << "LowerBUILD_VECTOR: value undefined, creating undef node\n");
14064 return DAG.getUNDEF(VT);
14065 }
14066
14067 // Convert BUILD_VECTOR where all elements but the lowest are undef into
14068 // SCALAR_TO_VECTOR, except for when we have a single-element constant vector
14069 // as SimplifyDemandedBits will just turn that back into BUILD_VECTOR.
14070 if (isOnlyLowElement && !(NumElts == 1 && isIntOrFPConstant(Value))) {
14071 LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: only low element used, creating 1 "
14072 "SCALAR_TO_VECTOR node\n");
14073 return DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Value);
14074 }
14075
14076 if (AllLanesExtractElt) {
14077 SDNode *Vector = nullptr;
14078 bool Even = false;
14079 bool Odd = false;
14080 // Check whether the extract elements match the Even pattern <0,2,4,...> or
14081 // the Odd pattern <1,3,5,...>.
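    // For example (illustrative): a v4i16 built from lanes <0,2,4,6> of one
    // v8i16 source becomes UZP1 of that source's low and high halves, while
    // lanes <1,3,5,7> become UZP2.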
14082 for (unsigned i = 0; i < NumElts; ++i) {
14083 SDValue V = Op.getOperand(i);
14084 const SDNode *N = V.getNode();
14085 if (!isa<ConstantSDNode>(N->getOperand(1))) {
14086 Even = false;
14087 Odd = false;
14088 break;
14089 }
14090 SDValue N0 = N->getOperand(0);
14091
14092 // All elements are extracted from the same vector.
14093 if (!Vector) {
14094 Vector = N0.getNode();
14095 // Check that the type of EXTRACT_VECTOR_ELT matches the type of
14096 // BUILD_VECTOR.
14097 if (VT.getVectorElementType() !=
14098 N0.getValueType().getVectorElementType())
14099 break;
14100 } else if (Vector != N0.getNode()) {
14101 Odd = false;
14102 Even = false;
14103 break;
14104 }
14105
14106 // Extracted values are either at Even indices <0,2,4,...> or at Odd
14107 // indices <1,3,5,...>.
14108 uint64_t Val = N->getConstantOperandVal(1);
14109 if (Val == 2 * i) {
14110 Even = true;
14111 continue;
14112 }
14113 if (Val - 1 == 2 * i) {
14114 Odd = true;
14115 continue;
14116 }
14117
14118 // Something does not match: abort.
14119 Odd = false;
14120 Even = false;
14121 break;
14122 }
14123 if (Even || Odd) {
14124 SDValue LHS =
14125 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
14126 DAG.getConstant(0, dl, MVT::i64));
14127 SDValue RHS =
14128 DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, SDValue(Vector, 0),
14129 DAG.getConstant(NumElts, dl, MVT::i64));
14130
14131 if (Even && !Odd)
14132 return DAG.getNode(AArch64ISD::UZP1, dl, VT, LHS, RHS);
14133 if (Odd && !Even)
14134 return DAG.getNode(AArch64ISD::UZP2, dl, VT, LHS, RHS);
14135 }
14136 }
14137
14138 // Use DUP for non-constant splats. For f32 constant splats, reduce to
14139 // i32 and try again.
14140 if (usesOnlyOneValue) {
14141 if (!isConstant) {
14142 if (Value.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
14143 Value.getValueType() != VT) {
14144 LLVM_DEBUG(
14145 dbgs() << "LowerBUILD_VECTOR: use DUP for non-constant splats\n");
14146 return DAG.getNode(AArch64ISD::DUP, dl, VT, Value);
14147 }
14148
14149 // This is actually a DUPLANExx operation, which keeps everything vectory.
14150
14151 SDValue Lane = Value.getOperand(1);
14152 Value = Value.getOperand(0);
14153 if (Value.getValueSizeInBits() == 64) {
14154 LLVM_DEBUG(
14155 dbgs() << "LowerBUILD_VECTOR: DUPLANE works on 128-bit vectors, "
14156 "widening it\n");
14157 Value = WidenVector(Value, DAG);
14158 }
14159
14160 unsigned Opcode = getDUPLANEOp(VT.getVectorElementType());
14161 return DAG.getNode(Opcode, dl, VT, Value, Lane);
14162 }
14163
14164 if (VT.getVectorElementType().isFloatingPoint()) {
14165 SmallVector<SDValue, 8> Ops;
14166 EVT EltTy = VT.getVectorElementType();
14167 assert ((EltTy == MVT::f16 || EltTy == MVT::bf16 || EltTy == MVT::f32 ||
14168 EltTy == MVT::f64) && "Unsupported floating-point vector type");
14169 LLVM_DEBUG(
14170 dbgs() << "LowerBUILD_VECTOR: float constant splats, creating int "
14171 "BITCASTS, and try again\n");
14172 MVT NewType = MVT::getIntegerVT(EltTy.getSizeInBits());
14173 for (unsigned i = 0; i < NumElts; ++i)
14174 Ops.push_back(DAG.getNode(ISD::BITCAST, dl, NewType, Op.getOperand(i)));
14175 EVT VecVT = EVT::getVectorVT(*DAG.getContext(), NewType, NumElts);
14176 SDValue Val = DAG.getBuildVector(VecVT, dl, Ops);
14177 LLVM_DEBUG(dbgs() << "LowerBUILD_VECTOR: trying to lower new vector: ";
14178 Val.dump(););
14179 Val = LowerBUILD_VECTOR(Val, DAG);
14180 if (Val.getNode())
14181 return DAG.getNode(ISD::BITCAST, dl, VT, Val);
14182 }
14183 }
14184
14185 // If we need to insert a small number of different non-constant elements and
14186 // the vector width is sufficiently large, prefer using DUP with the common
14187 // value and INSERT_VECTOR_ELT for the different lanes. If DUP is preferred,
14188 // skip the constant lane handling below.
14189 bool PreferDUPAndInsert =
14190 !isConstant && NumDifferentLanes >= 1 &&
14191 NumDifferentLanes < ((NumElts - NumUndefLanes) / 2) &&
14192 NumDifferentLanes >= NumConstantLanes;
14193
  // If only a single constant value was used, possibly in more than one lane,
14195 // start by splatting that value, then replace the non-constant lanes. This
14196 // is better than the default, which will perform a separate initialization
14197 // for each lane.
14198 if (!PreferDUPAndInsert && NumConstantLanes > 0 && usesOnlyOneConstantValue) {
14199 // Firstly, try to materialize the splat constant.
14200 SDValue Val = DAG.getSplatBuildVector(VT, dl, ConstantValue);
14201 unsigned BitSize = VT.getScalarSizeInBits();
14202 APInt ConstantValueAPInt(1, 0);
14203 if (auto *C = dyn_cast<ConstantSDNode>(ConstantValue))
14204 ConstantValueAPInt = C->getAPIntValue().zextOrTrunc(BitSize);
14205 if (!isNullConstant(ConstantValue) && !isNullFPConstant(ConstantValue) &&
14206 !ConstantValueAPInt.isAllOnes()) {
14207 Val = ConstantBuildVector(Val, DAG, Subtarget);
14208 if (!Val)
14209 // Otherwise, materialize the constant and splat it.
14210 Val = DAG.getNode(AArch64ISD::DUP, dl, VT, ConstantValue);
14211 }
14212
14213 // Now insert the non-constant lanes.
14214 for (unsigned i = 0; i < NumElts; ++i) {
14215 SDValue V = Op.getOperand(i);
14216 SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
14217 if (!isIntOrFPConstant(V))
14218 // Note that type legalization likely mucked about with the VT of the
14219 // source operand, so we may have to convert it here before inserting.
14220 Val = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Val, V, LaneIdx);
14221 }
14222 return Val;
14223 }
14224
14225 // This will generate a load from the constant pool.
14226 if (isConstant) {
14227 LLVM_DEBUG(
14228 dbgs() << "LowerBUILD_VECTOR: all elements are constant, use default "
14229 "expansion\n");
14230 return SDValue();
14231 }
14232
14233 // Detect patterns of a0,a1,a2,a3,b0,b1,b2,b3,c0,c1,c2,c3,d0,d1,d2,d3 from
14234 // v4i32s. This is really a truncate, which we can construct out of (legal)
14235 // concats and truncate nodes.
14236 if (SDValue M = ReconstructTruncateFromBuildVector(Op, DAG))
14237 return M;
14238
14239 // Empirical tests suggest this is rarely worth it for vectors of length <= 2.
14240 if (NumElts >= 4) {
14241 if (SDValue Shuffle = ReconstructShuffle(Op, DAG))
14242 return Shuffle;
14243
14244 if (SDValue Shuffle = ReconstructShuffleWithRuntimeMask(Op, DAG))
14245 return Shuffle;
14246 }
14247
14248 if (PreferDUPAndInsert) {
14249 // First, build a constant vector with the common element.
14250 SmallVector<SDValue, 8> Ops(NumElts, Value);
14251 SDValue NewVector = LowerBUILD_VECTOR(DAG.getBuildVector(VT, dl, Ops), DAG);
14252 // Next, insert the elements that do not match the common value.
14253 for (unsigned I = 0; I < NumElts; ++I)
14254 if (Op.getOperand(I) != Value)
14255 NewVector =
14256 DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, NewVector,
14257 Op.getOperand(I), DAG.getConstant(I, dl, MVT::i64));
14258
14259 return NewVector;
14260 }
14261
14262 // If vector consists of two different values, try to generate two DUPs and
14263 // (CONCAT_VECTORS or VECTOR_SHUFFLE).
14264 if (DifferentValueMap.size() == 2 && NumUndefLanes == 0) {
14265 SmallVector<SDValue, 2> Vals;
    // Check whether each value's consecutive run is half the number of vector
    // elements. In that case, we can use CONCAT_VECTORS. For example,
14268 //
14269 // canUseVECTOR_CONCAT = true;
14270 // t22: v16i8 = build_vector t23, t23, t23, t23, t23, t23, t23, t23,
14271 // t24, t24, t24, t24, t24, t24, t24, t24
14272 //
14273 // canUseVECTOR_CONCAT = false;
14274 // t22: v16i8 = build_vector t23, t23, t23, t23, t23, t24, t24, t24,
14275 // t24, t24, t24, t24, t24, t24, t24, t24
14276 bool canUseVECTOR_CONCAT = true;
14277 for (auto Pair : DifferentValueMap) {
      // Check that each different value has a consecutive run of NumElts / 2.
14279 if (Pair.second != NumElts / 2)
14280 canUseVECTOR_CONCAT = false;
14281 Vals.push_back(Pair.first);
14282 }
14283
14284 // If canUseVECTOR_CONCAT is true, we can generate two DUPs and
14285 // CONCAT_VECTORs. For example,
14286 //
14287 // t22: v16i8 = BUILD_VECTOR t23, t23, t23, t23, t23, t23, t23, t23,
14288 // t24, t24, t24, t24, t24, t24, t24, t24
14289 // ==>
14290 // t26: v8i8 = AArch64ISD::DUP t23
14291 // t28: v8i8 = AArch64ISD::DUP t24
14292 // t29: v16i8 = concat_vectors t26, t28
14293 if (canUseVECTOR_CONCAT) {
14294 EVT SubVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
14295 if (isTypeLegal(SubVT) && SubVT.isVector() &&
14296 SubVT.getVectorNumElements() >= 2) {
14297 SmallVector<SDValue, 8> Ops1(NumElts / 2, Vals[0]);
14298 SmallVector<SDValue, 8> Ops2(NumElts / 2, Vals[1]);
14299 SDValue DUP1 =
14300 LowerBUILD_VECTOR(DAG.getBuildVector(SubVT, dl, Ops1), DAG);
14301 SDValue DUP2 =
14302 LowerBUILD_VECTOR(DAG.getBuildVector(SubVT, dl, Ops2), DAG);
14303 SDValue CONCAT_VECTORS =
14304 DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, DUP1, DUP2);
14305 return CONCAT_VECTORS;
14306 }
14307 }
14308
14309 // Let's try to generate VECTOR_SHUFFLE. For example,
14310 //
14311 // t24: v8i8 = BUILD_VECTOR t25, t25, t25, t25, t26, t26, t26, t26
14312 // ==>
14313 // t27: v8i8 = BUILD_VECTOR t26, t26, t26, t26, t26, t26, t26, t26
14314 // t28: v8i8 = BUILD_VECTOR t25, t25, t25, t25, t25, t25, t25, t25
14315 // t29: v8i8 = vector_shuffle<0,1,2,3,12,13,14,15> t27, t28
14316 if (NumElts >= 8) {
14317 SmallVector<int, 16> MaskVec;
      // Build the mask for VECTOR_SHUFFLE.
14319 SDValue FirstLaneVal = Op.getOperand(0);
14320 for (unsigned i = 0; i < NumElts; ++i) {
14321 SDValue Val = Op.getOperand(i);
14322 if (FirstLaneVal == Val)
14323 MaskVec.push_back(i);
14324 else
14325 MaskVec.push_back(i + NumElts);
14326 }
14327
14328 SmallVector<SDValue, 8> Ops1(NumElts, Vals[0]);
14329 SmallVector<SDValue, 8> Ops2(NumElts, Vals[1]);
14330 SDValue VEC1 = DAG.getBuildVector(VT, dl, Ops1);
14331 SDValue VEC2 = DAG.getBuildVector(VT, dl, Ops2);
14332 SDValue VECTOR_SHUFFLE =
14333 DAG.getVectorShuffle(VT, dl, VEC1, VEC2, MaskVec);
14334 return VECTOR_SHUFFLE;
14335 }
14336 }
14337
14338 // If all else fails, just use a sequence of INSERT_VECTOR_ELT when we
14339 // know the default expansion would otherwise fall back on something even
14340 // worse. For a vector with one or two non-undef values, that's
14341 // scalar_to_vector for the elements followed by a shuffle (provided the
14342 // shuffle is valid for the target) and materialization element by element
14343 // on the stack followed by a load for everything else.
14344 if (!isConstant && !usesOnlyOneValue) {
14345 LLVM_DEBUG(
14346 dbgs() << "LowerBUILD_VECTOR: alternatives failed, creating sequence "
14347 "of INSERT_VECTOR_ELT\n");
14348
14349 SDValue Vec = DAG.getUNDEF(VT);
14350 SDValue Op0 = Op.getOperand(0);
14351 unsigned i = 0;
14352
14353 // Use SCALAR_TO_VECTOR for lane zero to
14354 // a) Avoid a RMW dependency on the full vector register, and
14355 // b) Allow the register coalescer to fold away the copy if the
14356 // value is already in an S or D register, and we're forced to emit an
14357 // INSERT_SUBREG that we can't fold anywhere.
14358 //
14359 // We also allow types like i8 and i16 which are illegal scalar but legal
14360 // vector element types. After type-legalization the inserted value is
14361 // extended (i32) and it is safe to cast them to the vector type by ignoring
14362 // the upper bits of the lowest lane (e.g. v8i8, v4i16).
14363 if (!Op0.isUndef()) {
14364 LLVM_DEBUG(dbgs() << "Creating node for op0, it is not undefined:\n");
14365 Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, VT, Op0);
14366 ++i;
14367 }
14368 LLVM_DEBUG(if (i < NumElts) dbgs()
14369 << "Creating nodes for the other vector elements:\n";);
14370 for (; i < NumElts; ++i) {
14371 SDValue V = Op.getOperand(i);
14372 if (V.isUndef())
14373 continue;
14374 SDValue LaneIdx = DAG.getConstant(i, dl, MVT::i64);
14375 Vec = DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, Vec, V, LaneIdx);
14376 }
14377 return Vec;
14378 }
14379
14380 LLVM_DEBUG(
14381 dbgs() << "LowerBUILD_VECTOR: use default expansion, failed to find "
14382 "better alternative\n");
14383 return SDValue();
14384 }
14385
SDValue AArch64TargetLowering::LowerCONCAT_VECTORS(SDValue Op,
                                                   SelectionDAG &DAG) const {
14388 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
14389 !Subtarget->isNeonAvailable()))
14390 return LowerFixedLengthConcatVectorsToSVE(Op, DAG);
14391
14392 assert(Op.getValueType().isScalableVector() &&
14393 isTypeLegal(Op.getValueType()) &&
14394 "Expected legal scalable vector type!");
14395
14396 if (isTypeLegal(Op.getOperand(0).getValueType())) {
14397 unsigned NumOperands = Op->getNumOperands();
14398 assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
14399 "Unexpected number of operands in CONCAT_VECTORS");
14400
14401 if (NumOperands == 2)
14402 return Op;
14403
14404 // Concat each pair of subvectors and pack into the lower half of the array.
14405 SmallVector<SDValue> ConcatOps(Op->op_begin(), Op->op_end());
14406 while (ConcatOps.size() > 1) {
14407 for (unsigned I = 0, E = ConcatOps.size(); I != E; I += 2) {
14408 SDValue V1 = ConcatOps[I];
14409 SDValue V2 = ConcatOps[I + 1];
14410 EVT SubVT = V1.getValueType();
14411 EVT PairVT = SubVT.getDoubleNumVectorElementsVT(*DAG.getContext());
14412 ConcatOps[I / 2] =
14413 DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op), PairVT, V1, V2);
14414 }
14415 ConcatOps.resize(ConcatOps.size() / 2);
14416 }
14417 return ConcatOps[0];
14418 }
14419
14420 return SDValue();
14421 }
14422
SDValue AArch64TargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
                                                      SelectionDAG &DAG) const {
14425 assert(Op.getOpcode() == ISD::INSERT_VECTOR_ELT && "Unknown opcode!");
14426
14427 if (useSVEForFixedLengthVectorVT(Op.getValueType(),
14428 !Subtarget->isNeonAvailable()))
14429 return LowerFixedLengthInsertVectorElt(Op, DAG);
14430
14431 EVT VT = Op.getOperand(0).getValueType();
14432
14433 if (VT.getScalarType() == MVT::i1) {
14434 EVT VectorVT = getPromotedVTForPredicate(VT);
14435 SDLoc DL(Op);
14436 SDValue ExtendedVector =
14437 DAG.getAnyExtOrTrunc(Op.getOperand(0), DL, VectorVT);
14438 SDValue ExtendedValue =
14439 DAG.getAnyExtOrTrunc(Op.getOperand(1), DL,
14440 VectorVT.getScalarType().getSizeInBits() < 32
14441 ? MVT::i32
14442 : VectorVT.getScalarType());
14443 ExtendedVector =
14444 DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VectorVT, ExtendedVector,
14445 ExtendedValue, Op.getOperand(2));
14446 return DAG.getAnyExtOrTrunc(ExtendedVector, DL, VT);
14447 }
14448
14449 // Check for non-constant or out of range lane.
14450 ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(2));
14451 if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
14452 return SDValue();
14453
14454 return Op;
14455 }
14456
SDValue
AArch64TargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
                                               SelectionDAG &DAG) const {
14460 assert(Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT && "Unknown opcode!");
14461 EVT VT = Op.getOperand(0).getValueType();
14462
14463 if (VT.getScalarType() == MVT::i1) {
14464 // We can't directly extract from an SVE predicate; extend it first.
14465 // (This isn't the only possible lowering, but it's straightforward.)
14466 EVT VectorVT = getPromotedVTForPredicate(VT);
14467 SDLoc DL(Op);
14468 SDValue Extend =
14469 DAG.getNode(ISD::ANY_EXTEND, DL, VectorVT, Op.getOperand(0));
14470 MVT ExtractTy = VectorVT == MVT::nxv2i64 ? MVT::i64 : MVT::i32;
14471 SDValue Extract = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtractTy,
14472 Extend, Op.getOperand(1));
14473 return DAG.getAnyExtOrTrunc(Extract, DL, Op.getValueType());
14474 }
14475
14476 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14477 return LowerFixedLengthExtractVectorElt(Op, DAG);
14478
14479 // Check for non-constant or out of range lane.
14480 ConstantSDNode *CI = dyn_cast<ConstantSDNode>(Op.getOperand(1));
14481 if (!CI || CI->getZExtValue() >= VT.getVectorNumElements())
14482 return SDValue();
14483
14484 // Insertion/extraction are legal for V128 types.
14485 if (VT == MVT::v16i8 || VT == MVT::v8i16 || VT == MVT::v4i32 ||
14486 VT == MVT::v2i64 || VT == MVT::v4f32 || VT == MVT::v2f64 ||
14487 VT == MVT::v8f16 || VT == MVT::v8bf16)
14488 return Op;
14489
14490 if (VT != MVT::v8i8 && VT != MVT::v4i16 && VT != MVT::v2i32 &&
14491 VT != MVT::v1i64 && VT != MVT::v2f32 && VT != MVT::v4f16 &&
14492 VT != MVT::v4bf16)
14493 return SDValue();
14494
14495 // For V64 types, we perform extraction by expanding the value
14496 // to a V128 type and perform the extraction on that.
14497 SDLoc DL(Op);
14498 SDValue WideVec = WidenVector(Op.getOperand(0), DAG);
14499 EVT WideTy = WideVec.getValueType();
14500
14501 EVT ExtrTy = WideTy.getVectorElementType();
14502 if (ExtrTy == MVT::i16 || ExtrTy == MVT::i8)
14503 ExtrTy = MVT::i32;
14504
14505 // For extractions, we just return the result directly.
14506 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ExtrTy, WideVec,
14507 Op.getOperand(1));
14508 }
14509
SDValue AArch64TargetLowering::LowerEXTRACT_SUBVECTOR(SDValue Op,
                                                      SelectionDAG &DAG) const {
14512 EVT VT = Op.getValueType();
14513 assert(VT.isFixedLengthVector() &&
14514 "Only cases that extract a fixed length vector are supported!");
14515 EVT InVT = Op.getOperand(0).getValueType();
14516
14517 // If we don't have legal types yet, do nothing
14518 if (!isTypeLegal(InVT))
14519 return SDValue();
14520
14521 if (InVT.is128BitVector()) {
14522 assert(VT.is64BitVector() && "Extracting unexpected vector type!");
14523 unsigned Idx = Op.getConstantOperandVal(1);
14524
14525 // This will get lowered to an appropriate EXTRACT_SUBREG in ISel.
14526 if (Idx == 0)
14527 return Op;
14528
14529 // If this is extracting the upper 64-bits of a 128-bit vector, we match
14530 // that directly.
14531 if (Idx * InVT.getScalarSizeInBits() == 64 && Subtarget->isNeonAvailable())
14532 return Op;
14533 }
14534
14535 if (InVT.isScalableVector() ||
14536 useSVEForFixedLengthVectorVT(InVT, !Subtarget->isNeonAvailable())) {
14537 SDLoc DL(Op);
14538 SDValue Vec = Op.getOperand(0);
14539 SDValue Idx = Op.getOperand(1);
14540
14541 EVT PackedVT = getPackedSVEVectorVT(InVT.getVectorElementType());
14542 if (PackedVT != InVT) {
14543 // Pack input into the bottom part of an SVE register and try again.
14544 SDValue Container = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, PackedVT,
14545 DAG.getUNDEF(PackedVT), Vec,
14546 DAG.getVectorIdxConstant(0, DL));
14547 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Container, Idx);
14548 }
14549
14550 // This will get matched by custom code during ISelDAGToDAG.
14551 if (isNullConstant(Idx))
14552 return Op;
14553
14554 assert(InVT.isScalableVector() && "Unexpected vector type!");
14555 // Move requested subvector to the start of the vector and try again.
14556 SDValue Splice = DAG.getNode(ISD::VECTOR_SPLICE, DL, InVT, Vec, Vec, Idx);
14557 return convertFromScalableVector(DAG, VT, Splice);
14558 }
14559
14560 return SDValue();
14561 }
14562
SDValue AArch64TargetLowering::LowerINSERT_SUBVECTOR(SDValue Op,
                                                     SelectionDAG &DAG) const {
14565 assert(Op.getValueType().isScalableVector() &&
14566 "Only expect to lower inserts into scalable vectors!");
14567
14568 EVT InVT = Op.getOperand(1).getValueType();
14569 unsigned Idx = Op.getConstantOperandVal(2);
14570
14571 SDValue Vec0 = Op.getOperand(0);
14572 SDValue Vec1 = Op.getOperand(1);
14573 SDLoc DL(Op);
14574 EVT VT = Op.getValueType();
14575
14576 if (InVT.isScalableVector()) {
14577 if (!isTypeLegal(VT))
14578 return SDValue();
14579
14580 // Break down insert_subvector into simpler parts.
14581 if (VT.getVectorElementType() == MVT::i1) {
14582 unsigned NumElts = VT.getVectorMinNumElements();
14583 EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
14584
14585 SDValue Lo, Hi;
14586 Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
14587 DAG.getVectorIdxConstant(0, DL));
14588 Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, Vec0,
14589 DAG.getVectorIdxConstant(NumElts / 2, DL));
14590 if (Idx < (NumElts / 2))
14591 Lo = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Lo, Vec1,
14592 DAG.getVectorIdxConstant(Idx, DL));
14593 else
14594 Hi = DAG.getNode(ISD::INSERT_SUBVECTOR, DL, HalfVT, Hi, Vec1,
14595 DAG.getVectorIdxConstant(Idx - (NumElts / 2), DL));
14596
14597 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Lo, Hi);
14598 }
14599
14600 // Ensure the subvector is half the size of the main vector.
14601 if (VT.getVectorElementCount() != (InVT.getVectorElementCount() * 2))
14602 return SDValue();
14603
    // Here narrow and wide refer to the vector element types. After "casting"
14605 // both vectors must have the same bit length and so because the subvector
14606 // has fewer elements, those elements need to be bigger.
14607 EVT NarrowVT = getPackedSVEVectorVT(VT.getVectorElementCount());
14608 EVT WideVT = getPackedSVEVectorVT(InVT.getVectorElementCount());
14609
14610 // NOP cast operands to the largest legal vector of the same element count.
14611 if (VT.isFloatingPoint()) {
14612 Vec0 = getSVESafeBitCast(NarrowVT, Vec0, DAG);
14613 Vec1 = getSVESafeBitCast(WideVT, Vec1, DAG);
14614 } else {
14615 // Legal integer vectors are already their largest so Vec0 is fine as is.
14616 Vec1 = DAG.getNode(ISD::ANY_EXTEND, DL, WideVT, Vec1);
14617 }
14618
14619 // To replace the top/bottom half of vector V with vector SubV we widen the
14620 // preserved half of V, concatenate this to SubV (the order depending on the
14621 // half being replaced) and then narrow the result.
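    // For example (illustrative): inserting an nxv2i32 subvector into the low
    // half of an nxv4i32 vector any-extends the subvector to nxv2i64, takes
    // UUNPKHI of the original vector (the preserved top half, also nxv2i64)
    // and recombines the two with UZP1 into an nxv4i32 result.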
14622 SDValue Narrow;
14623 if (Idx == 0) {
14624 SDValue HiVec0 = DAG.getNode(AArch64ISD::UUNPKHI, DL, WideVT, Vec0);
14625 Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, Vec1, HiVec0);
14626 } else {
14627 assert(Idx == InVT.getVectorMinNumElements() &&
14628 "Invalid subvector index!");
14629 SDValue LoVec0 = DAG.getNode(AArch64ISD::UUNPKLO, DL, WideVT, Vec0);
14630 Narrow = DAG.getNode(AArch64ISD::UZP1, DL, NarrowVT, LoVec0, Vec1);
14631 }
14632
14633 return getSVESafeBitCast(VT, Narrow, DAG);
14634 }
14635
14636 if (Idx == 0 && isPackedVectorType(VT, DAG)) {
14637 // This will be matched by custom code during ISelDAGToDAG.
14638 if (Vec0.isUndef())
14639 return Op;
14640
14641 std::optional<unsigned> PredPattern =
14642 getSVEPredPatternFromNumElements(InVT.getVectorNumElements());
14643 auto PredTy = VT.changeVectorElementType(MVT::i1);
14644 SDValue PTrue = getPTrue(DAG, DL, PredTy, *PredPattern);
14645 SDValue ScalableVec1 = convertToScalableVector(DAG, VT, Vec1);
14646 return DAG.getNode(ISD::VSELECT, DL, VT, PTrue, ScalableVec1, Vec0);
14647 }
14648
14649 return SDValue();
14650 }
14651
static bool isPow2Splat(SDValue Op, uint64_t &SplatVal, bool &Negated) {
14653 if (Op.getOpcode() != AArch64ISD::DUP &&
14654 Op.getOpcode() != ISD::SPLAT_VECTOR &&
14655 Op.getOpcode() != ISD::BUILD_VECTOR)
14656 return false;
14657
14658 if (Op.getOpcode() == ISD::BUILD_VECTOR &&
14659 !isAllConstantBuildVector(Op, SplatVal))
14660 return false;
14661
14662 if (Op.getOpcode() != ISD::BUILD_VECTOR &&
14663 !isa<ConstantSDNode>(Op->getOperand(0)))
14664 return false;
14665
14666 SplatVal = Op->getConstantOperandVal(0);
14667 if (Op.getValueType().getVectorElementType() != MVT::i64)
14668 SplatVal = (int32_t)SplatVal;
14669
14670 Negated = false;
14671 if (isPowerOf2_64(SplatVal))
14672 return true;
14673
14674 Negated = true;
14675 if (isPowerOf2_64(-SplatVal)) {
14676 SplatVal = -SplatVal;
14677 return true;
14678 }
14679
14680 return false;
14681 }
14682
SDValue AArch64TargetLowering::LowerDIV(SDValue Op, SelectionDAG &DAG) const {
14684 EVT VT = Op.getValueType();
14685 SDLoc dl(Op);
14686
14687 if (useSVEForFixedLengthVectorVT(VT, /*OverrideNEON=*/true))
14688 return LowerFixedLengthVectorIntDivideToSVE(Op, DAG);
14689
14690 assert(VT.isScalableVector() && "Expected a scalable vector.");
14691
14692 bool Signed = Op.getOpcode() == ISD::SDIV;
14693 unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
14694
14695 bool Negated;
14696 uint64_t SplatVal;
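  // For example (illustrative): a signed divide by a power-of-two splat is
  // lowered to a predicated arithmetic shift, e.g. (sdiv x, splat(8)) becomes
  // an SRAD by 3, and (sdiv x, splat(-8)) additionally negates that result.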
14697 if (Signed && isPow2Splat(Op.getOperand(1), SplatVal, Negated)) {
14698 SDValue Pg = getPredicateForScalableVector(DAG, dl, VT);
14699 SDValue Res =
14700 DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, dl, VT, Pg, Op->getOperand(0),
14701 DAG.getTargetConstant(Log2_64(SplatVal), dl, MVT::i32));
14702 if (Negated)
14703 Res = DAG.getNode(ISD::SUB, dl, VT, DAG.getConstant(0, dl, VT), Res);
14704
14705 return Res;
14706 }
14707
14708 if (VT == MVT::nxv4i32 || VT == MVT::nxv2i64)
14709 return LowerToPredicatedOp(Op, DAG, PredOpcode);
14710
14711 // SVE doesn't have i8 and i16 DIV operations; widen them to 32-bit
14712 // operations, and truncate the result.
14713 EVT WidenedVT;
14714 if (VT == MVT::nxv16i8)
14715 WidenedVT = MVT::nxv8i16;
14716 else if (VT == MVT::nxv8i16)
14717 WidenedVT = MVT::nxv4i32;
14718 else
14719 llvm_unreachable("Unexpected Custom DIV operation");
14720
14721 unsigned UnpkLo = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
14722 unsigned UnpkHi = Signed ? AArch64ISD::SUNPKHI : AArch64ISD::UUNPKHI;
14723 SDValue Op0Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(0));
14724 SDValue Op1Lo = DAG.getNode(UnpkLo, dl, WidenedVT, Op.getOperand(1));
14725 SDValue Op0Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(0));
14726 SDValue Op1Hi = DAG.getNode(UnpkHi, dl, WidenedVT, Op.getOperand(1));
14727 SDValue ResultLo = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Lo, Op1Lo);
14728 SDValue ResultHi = DAG.getNode(Op.getOpcode(), dl, WidenedVT, Op0Hi, Op1Hi);
14729 return DAG.getNode(AArch64ISD::UZP1, dl, VT, ResultLo, ResultHi);
14730 }
14731
bool AArch64TargetLowering::shouldExpandBuildVectorWithShuffles(
    EVT VT, unsigned DefinedValues) const {
14734 if (!Subtarget->isNeonAvailable())
14735 return false;
14736 return TargetLowering::shouldExpandBuildVectorWithShuffles(VT, DefinedValues);
14737 }
14738
bool AArch64TargetLowering::isShuffleMaskLegal(ArrayRef<int> M, EVT VT) const {
14740 // Currently no fixed length shuffles that require SVE are legal.
14741 if (useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14742 return false;
14743
14744 if (VT.getVectorNumElements() == 4 &&
14745 (VT.is128BitVector() || VT.is64BitVector())) {
14746 unsigned Cost = getPerfectShuffleCost(M);
14747 if (Cost <= 1)
14748 return true;
14749 }
14750
14751 bool DummyBool;
14752 int DummyInt;
14753 unsigned DummyUnsigned;
14754
14755 unsigned EltSize = VT.getScalarSizeInBits();
14756 unsigned NumElts = VT.getVectorNumElements();
14757 return (ShuffleVectorSDNode::isSplatMask(&M[0], VT) ||
14758 isREVMask(M, EltSize, NumElts, 64) ||
14759 isREVMask(M, EltSize, NumElts, 32) ||
14760 isREVMask(M, EltSize, NumElts, 16) ||
14761 isEXTMask(M, VT, DummyBool, DummyUnsigned) ||
14762 isTRNMask(M, NumElts, DummyUnsigned) ||
14763 isUZPMask(M, NumElts, DummyUnsigned) ||
14764 isZIPMask(M, NumElts, DummyUnsigned) ||
14765 isTRN_v_undef_Mask(M, VT, DummyUnsigned) ||
14766 isUZP_v_undef_Mask(M, VT, DummyUnsigned) ||
14767 isZIP_v_undef_Mask(M, VT, DummyUnsigned) ||
14768 isINSMask(M, NumElts, DummyBool, DummyInt) ||
14769 isConcatMask(M, VT, VT.getSizeInBits() == 128));
14770 }
14771
bool AArch64TargetLowering::isVectorClearMaskLegal(ArrayRef<int> M,
                                                   EVT VT) const {
14774 // Just delegate to the generic legality, clear masks aren't special.
14775 return isShuffleMaskLegal(M, VT);
14776 }
14777
14778 /// getVShiftImm - Check if this is a valid build_vector for the immediate
14779 /// operand of a vector shift operation, where all the elements of the
14780 /// build_vector must have the same constant integer value.
static bool getVShiftImm(SDValue Op, unsigned ElementBits, int64_t &Cnt) {
14782 // Ignore bit_converts.
14783 while (Op.getOpcode() == ISD::BITCAST)
14784 Op = Op.getOperand(0);
14785 BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(Op.getNode());
14786 APInt SplatBits, SplatUndef;
14787 unsigned SplatBitSize;
14788 bool HasAnyUndefs;
14789 if (!BVN || !BVN->isConstantSplat(SplatBits, SplatUndef, SplatBitSize,
14790 HasAnyUndefs, ElementBits) ||
14791 SplatBitSize > ElementBits)
14792 return false;
14793 Cnt = SplatBits.getSExtValue();
14794 return true;
14795 }
14796
14797 /// isVShiftLImm - Check if this is a valid build_vector for the immediate
14798 /// operand of a vector shift left operation. That value must be in the range:
14799 /// 0 <= Value < ElementBits for a left shift; or
14800 /// 0 <= Value <= ElementBits for a long left shift.
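/// For example (illustrative): for 32-bit elements an ordinary left-shift
/// immediate must be in [0, 31], while a long (widening) shift may also use 32.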
static bool isVShiftLImm(SDValue Op, EVT VT, bool isLong, int64_t &Cnt) {
14802 assert(VT.isVector() && "vector shift count is not a vector type");
14803 int64_t ElementBits = VT.getScalarSizeInBits();
14804 if (!getVShiftImm(Op, ElementBits, Cnt))
14805 return false;
14806 return (Cnt >= 0 && (isLong ? Cnt - 1 : Cnt) < ElementBits);
14807 }
14808
14809 /// isVShiftRImm - Check if this is a valid build_vector for the immediate
14810 /// operand of a vector shift right operation. The value must be in the range:
/// 1 <= Value <= ElementBits for a right shift; or
/// 1 <= Value <= ElementBits/2 for a narrowing right shift.
static bool isVShiftRImm(SDValue Op, EVT VT, bool isNarrow, int64_t &Cnt) {
14813 assert(VT.isVector() && "vector shift count is not a vector type");
14814 int64_t ElementBits = VT.getScalarSizeInBits();
14815 if (!getVShiftImm(Op, ElementBits, Cnt))
14816 return false;
14817 return (Cnt >= 1 && Cnt <= (isNarrow ? ElementBits / 2 : ElementBits));
14818 }
14819
SDValue AArch64TargetLowering::LowerTRUNCATE(SDValue Op,
                                             SelectionDAG &DAG) const {
14822 EVT VT = Op.getValueType();
14823
14824 if (VT.getScalarType() == MVT::i1) {
14825 // Lower i1 truncate to `(x & 1) != 0`.
14826 SDLoc dl(Op);
14827 EVT OpVT = Op.getOperand(0).getValueType();
14828 SDValue Zero = DAG.getConstant(0, dl, OpVT);
14829 SDValue One = DAG.getConstant(1, dl, OpVT);
14830 SDValue And = DAG.getNode(ISD::AND, dl, OpVT, Op.getOperand(0), One);
14831 return DAG.getSetCC(dl, VT, And, Zero, ISD::SETNE);
14832 }
14833
14834 if (!VT.isVector() || VT.isScalableVector())
14835 return SDValue();
14836
14837 if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType(),
14838 !Subtarget->isNeonAvailable()))
14839 return LowerFixedLengthVectorTruncateToSVE(Op, DAG);
14840
14841 return SDValue();
14842 }
14843
// Check if we can lower this SRL to a rounding shift instruction. ResVT is
// possibly a truncated type; it tells how many bits of the value are to be
// used.
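// For example (illustrative): with a shift amount of 4, the pattern
//   (srl (add X, 8), 4)
// divides X by 16 with round-to-nearest (ties rounding up), which is exactly
// what a rounding shift such as URSHR computes directly.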
static bool canLowerSRLToRoundingShiftForVT(SDValue Shift, EVT ResVT,
                                            SelectionDAG &DAG,
                                            unsigned &ShiftValue,
                                            SDValue &RShOperand) {
14851 if (Shift->getOpcode() != ISD::SRL)
14852 return false;
14853
14854 EVT VT = Shift.getValueType();
14855 assert(VT.isScalableVT());
14856
14857 auto ShiftOp1 =
14858 dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Shift->getOperand(1)));
14859 if (!ShiftOp1)
14860 return false;
14861
14862 ShiftValue = ShiftOp1->getZExtValue();
14863 if (ShiftValue < 1 || ShiftValue > ResVT.getScalarSizeInBits())
14864 return false;
14865
14866 SDValue Add = Shift->getOperand(0);
14867 if (Add->getOpcode() != ISD::ADD || !Add->hasOneUse())
14868 return false;
14869
14870 assert(ResVT.getScalarSizeInBits() <= VT.getScalarSizeInBits() &&
14871 "ResVT must be truncated or same type as the shift.");
14872 // Check if an overflow can lead to incorrect results.
14873 uint64_t ExtraBits = VT.getScalarSizeInBits() - ResVT.getScalarSizeInBits();
14874 if (ShiftValue > ExtraBits && !Add->getFlags().hasNoUnsignedWrap())
14875 return false;
14876
14877 auto AddOp1 =
14878 dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(Add->getOperand(1)));
14879 if (!AddOp1)
14880 return false;
14881 uint64_t AddValue = AddOp1->getZExtValue();
14882 if (AddValue != 1ULL << (ShiftValue - 1))
14883 return false;
14884
14885 RShOperand = Add->getOperand(0);
14886 return true;
14887 }
14888
SDValue AArch64TargetLowering::LowerVectorSRA_SRL_SHL(SDValue Op,
                                                      SelectionDAG &DAG) const {
14891 EVT VT = Op.getValueType();
14892 SDLoc DL(Op);
14893 int64_t Cnt;
14894
14895 if (!Op.getOperand(1).getValueType().isVector())
14896 return Op;
14897 unsigned EltSize = VT.getScalarSizeInBits();
14898
14899 switch (Op.getOpcode()) {
14900 case ISD::SHL:
14901 if (VT.isScalableVector() ||
14902 useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable()))
14903 return LowerToPredicatedOp(Op, DAG, AArch64ISD::SHL_PRED);
14904
14905 if (isVShiftLImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize)
14906 return DAG.getNode(AArch64ISD::VSHL, DL, VT, Op.getOperand(0),
14907 DAG.getConstant(Cnt, DL, MVT::i32));
14908 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
14909 DAG.getConstant(Intrinsic::aarch64_neon_ushl, DL,
14910 MVT::i32),
14911 Op.getOperand(0), Op.getOperand(1));
14912 case ISD::SRA:
14913 case ISD::SRL:
14914 if (VT.isScalableVector() &&
14915 (Subtarget->hasSVE2() ||
14916 (Subtarget->hasSME() && Subtarget->isStreaming()))) {
14917 SDValue RShOperand;
14918 unsigned ShiftValue;
14919 if (canLowerSRLToRoundingShiftForVT(Op, VT, DAG, ShiftValue, RShOperand))
14920 return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, VT,
14921 getPredicateForVector(DAG, DL, VT), RShOperand,
14922 DAG.getTargetConstant(ShiftValue, DL, MVT::i32));
14923 }
14924
14925 if (VT.isScalableVector() ||
14926 useSVEForFixedLengthVectorVT(VT, !Subtarget->isNeonAvailable())) {
14927 unsigned Opc = Op.getOpcode() == ISD::SRA ? AArch64ISD::SRA_PRED
14928 : AArch64ISD::SRL_PRED;
14929 return LowerToPredicatedOp(Op, DAG, Opc);
14930 }
14931
14932 // Right shift immediate
14933 if (isVShiftRImm(Op.getOperand(1), VT, false, Cnt) && Cnt < EltSize) {
14934 unsigned Opc =
14935 (Op.getOpcode() == ISD::SRA) ? AArch64ISD::VASHR : AArch64ISD::VLSHR;
14936 return DAG.getNode(Opc, DL, VT, Op.getOperand(0),
14937 DAG.getConstant(Cnt, DL, MVT::i32), Op->getFlags());
14938 }
14939
    // Right shift register. Note that there is no shift right register
    // instruction, but the shift left register instruction takes a signed
14942 // value, where negative numbers specify a right shift.
14943 unsigned Opc = (Op.getOpcode() == ISD::SRA) ? Intrinsic::aarch64_neon_sshl
14944 : Intrinsic::aarch64_neon_ushl;
14945 // negate the shift amount
14946 SDValue NegShift = DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
14947 Op.getOperand(1));
14948 SDValue NegShiftLeft =
14949 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VT,
14950 DAG.getConstant(Opc, DL, MVT::i32), Op.getOperand(0),
14951 NegShift);
14952 return NegShiftLeft;
14953 }
14954
14955 llvm_unreachable("unexpected shift opcode");
14956 }
14957
static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                    AArch64CC::CondCode CC, bool NoNans, EVT VT,
                                    const SDLoc &dl, SelectionDAG &DAG) {
14961 EVT SrcVT = LHS.getValueType();
14962 assert(VT.getSizeInBits() == SrcVT.getSizeInBits() &&
14963 "function only supposed to emit natural comparisons");
14964
14965 APInt SplatValue;
14966 APInt SplatUndef;
14967 unsigned SplatBitSize = 0;
14968 bool HasAnyUndefs;
14969
14970 BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
14971 bool IsCnst = BVN && BVN->isConstantSplat(SplatValue, SplatUndef,
14972 SplatBitSize, HasAnyUndefs);
14973
14974 bool IsZero = IsCnst && SplatValue == 0;
14975 bool IsOne =
14976 IsCnst && SrcVT.getScalarSizeInBits() == SplatBitSize && SplatValue == 1;
14977 bool IsMinusOne = IsCnst && SplatValue.isAllOnes();
14978
14979 if (SrcVT.getVectorElementType().isFloatingPoint()) {
14980 switch (CC) {
14981 default:
14982 return SDValue();
14983 case AArch64CC::NE: {
14984 SDValue Fcmeq;
14985 if (IsZero)
14986 Fcmeq = DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
14987 else
14988 Fcmeq = DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
14989 return DAG.getNOT(dl, Fcmeq, VT);
14990 }
14991 case AArch64CC::EQ:
14992 if (IsZero)
14993 return DAG.getNode(AArch64ISD::FCMEQz, dl, VT, LHS);
14994 return DAG.getNode(AArch64ISD::FCMEQ, dl, VT, LHS, RHS);
14995 case AArch64CC::GE:
14996 if (IsZero)
14997 return DAG.getNode(AArch64ISD::FCMGEz, dl, VT, LHS);
14998 return DAG.getNode(AArch64ISD::FCMGE, dl, VT, LHS, RHS);
14999 case AArch64CC::GT:
15000 if (IsZero)
15001 return DAG.getNode(AArch64ISD::FCMGTz, dl, VT, LHS);
15002 return DAG.getNode(AArch64ISD::FCMGT, dl, VT, LHS, RHS);
15003 case AArch64CC::LE:
15004 if (!NoNans)
15005 return SDValue();
      // If we ignore NaNs then we can use the LS implementation.
15007 [[fallthrough]];
15008 case AArch64CC::LS:
15009 if (IsZero)
15010 return DAG.getNode(AArch64ISD::FCMLEz, dl, VT, LHS);
15011 return DAG.getNode(AArch64ISD::FCMGE, dl, VT, RHS, LHS);
15012 case AArch64CC::LT:
15013 if (!NoNans)
15014 return SDValue();
      // If we ignore NaNs then we can use the MI implementation.
15016 [[fallthrough]];
15017 case AArch64CC::MI:
15018 if (IsZero)
15019 return DAG.getNode(AArch64ISD::FCMLTz, dl, VT, LHS);
15020 return DAG.getNode(AArch64ISD::FCMGT, dl, VT, RHS, LHS);
15021 }
15022 }
15023
15024 switch (CC) {
15025 default:
15026 return SDValue();
15027 case AArch64CC::NE: {
15028 SDValue Cmeq;
15029 if (IsZero)
15030 Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
15031 else
15032 Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
15033 return DAG.getNOT(dl, Cmeq, VT);
15034 }
15035 case AArch64CC::EQ:
15036 if (IsZero)
15037 return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
15038 return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
15039 case AArch64CC::GE:
15040 if (IsZero)
15041 return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
15042 return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
15043 case AArch64CC::GT:
15044 if (IsZero)
15045 return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
15046 if (IsMinusOne)
      return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
15048 return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
15049 case AArch64CC::LE:
15050 if (IsZero)
15051 return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
15052 return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
15053 case AArch64CC::LS:
15054 return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
15055 case AArch64CC::LO:
15056 return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
15057 case AArch64CC::LT:
15058 if (IsZero)
15059 return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
15060 if (IsOne)
15061 return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
15062 return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
15063 case AArch64CC::HI:
15064 return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
15065 case AArch64CC::HS:
15066 return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
15067 }
15068 }
15069
SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
                                           SelectionDAG &DAG) const {
15072 if (Op.getValueType().isScalableVector())
15073 return LowerToPredicatedOp(Op, DAG, AArch64ISD::SETCC_MERGE_ZERO);
15074
15075 if (useSVEForFixedLengthVectorVT(Op.getOperand(0).getValueType(),
15076 !Subtarget->isNeonAvailable()))
15077 return LowerFixedLengthVectorSetccToSVE(Op, DAG);
15078
15079 ISD::CondCode CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
15080 SDValue LHS = Op.getOperand(0);
15081 SDValue RHS = Op.getOperand(1);
15082 EVT CmpVT = LHS.getValueType().changeVectorElementTypeToInteger();
15083 SDLoc dl(Op);
15084
15085 if (LHS.getValueType().getVectorElementType().isInteger()) {
15086 assert(LHS.getValueType() == RHS.getValueType());
15087 AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
15088 SDValue Cmp =
15089 EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
15090 return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
15091 }
15092
15093 // Lower isnan(x) | isnan(never-nan) to x != x.
15094 // Lower !isnan(x) & !isnan(never-nan) to x == x.
15095 if (CC == ISD::SETUO || CC == ISD::SETO) {
15096 bool OneNaN = false;
15097 if (LHS == RHS) {
15098 OneNaN = true;
15099 } else if (DAG.isKnownNeverNaN(RHS)) {
15100 OneNaN = true;
15101 RHS = LHS;
15102 } else if (DAG.isKnownNeverNaN(LHS)) {
15103 OneNaN = true;
15104 LHS = RHS;
15105 }
15106 if (OneNaN) {
15107 CC = CC == ISD::SETUO ? ISD::SETUNE : ISD::SETOEQ;
15108 }
15109 }
15110
15111 const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
15112
  // Make v4f16 (only) fcmp operations utilise vector instructions.
  // v8f16 support will be a little more complicated.
15115 if ((!FullFP16 && LHS.getValueType().getVectorElementType() == MVT::f16) ||
15116 LHS.getValueType().getVectorElementType() == MVT::bf16) {
15117 if (LHS.getValueType().getVectorNumElements() == 4) {
15118 LHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, LHS);
15119 RHS = DAG.getNode(ISD::FP_EXTEND, dl, MVT::v4f32, RHS);
15120 SDValue NewSetcc = DAG.getSetCC(dl, MVT::v4i16, LHS, RHS, CC);
15121 DAG.ReplaceAllUsesWith(Op, NewSetcc);
15122 CmpVT = MVT::v4i32;
15123 } else
15124 return SDValue();
15125 }
15126
15127 assert((!FullFP16 && LHS.getValueType().getVectorElementType() != MVT::f16) ||
15128 LHS.getValueType().getVectorElementType() != MVT::bf16 ||
15129 LHS.getValueType().getVectorElementType() != MVT::f128);
15130
15131 // Unfortunately, the mapping of LLVM FP CC's onto AArch64 CC's isn't totally
15132 // clean. Some of them require two branches to implement.
15133 AArch64CC::CondCode CC1, CC2;
15134 bool ShouldInvert;
15135 changeVectorFPCCToAArch64CC(CC, CC1, CC2, ShouldInvert);
15136
15137 bool NoNaNs = getTargetMachine().Options.NoNaNsFPMath || Op->getFlags().hasNoNaNs();
15138 SDValue Cmp =
15139 EmitVectorComparison(LHS, RHS, CC1, NoNaNs, CmpVT, dl, DAG);
15140 if (!Cmp.getNode())
15141 return SDValue();
15142
15143 if (CC2 != AArch64CC::AL) {
15144 SDValue Cmp2 =
15145 EmitVectorComparison(LHS, RHS, CC2, NoNaNs, CmpVT, dl, DAG);
15146 if (!Cmp2.getNode())
15147 return SDValue();
15148
15149 Cmp = DAG.getNode(ISD::OR, dl, CmpVT, Cmp, Cmp2);
15150 }
15151
15152 Cmp = DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
15153
15154 if (ShouldInvert)
15155 Cmp = DAG.getNOT(dl, Cmp, Cmp.getValueType());
15156
15157 return Cmp;
15158 }
15159
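// Build an AArch64 across-lane reduction over the vector operand of ScalarOp
// and return lane 0 as the scalar result, e.g.:
//   (i32 (vecreduce_add (v4i32 V)))
//     --> (extract_vector_elt (AArch64ISD::UADDV V), (i64 0))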
15160 static SDValue getReductionSDNode(unsigned Op, SDLoc DL, SDValue ScalarOp,
15161 SelectionDAG &DAG) {
15162 SDValue VecOp = ScalarOp.getOperand(0);
15163 auto Rdx = DAG.getNode(Op, DL, VecOp.getSimpleValueType(), VecOp);
15164 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarOp.getValueType(), Rdx,
15165 DAG.getConstant(0, DL, MVT::i64));
15166 }
15167
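// Lower VECREDUCE_{AND,OR,XOR}. Boolean (i1) vectors are extended and handled
// as UMIN/UMAX/ADD reductions. Wider element types are split down to 64 bits
// and then reduced with shifts on a scalar GPR, e.g. for
// (i32 (vecreduce_and (v4i32 V))): split V into two v2i32 halves and AND
// them, bitcast the result to i64, AND it with itself shifted right by 32,
// then truncate back to i32.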
15168 static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
15169 SDLoc DL, SelectionDAG &DAG) {
15170 unsigned ScalarOpcode;
15171 switch (Opcode) {
15172 case ISD::VECREDUCE_AND:
15173 ScalarOpcode = ISD::AND;
15174 break;
15175 case ISD::VECREDUCE_OR:
15176 ScalarOpcode = ISD::OR;
15177 break;
15178 case ISD::VECREDUCE_XOR:
15179 ScalarOpcode = ISD::XOR;
15180 break;
15181 default:
15182 llvm_unreachable("Expected bitwise vector reduction");
15183 return SDValue();
15184 }
15185
15186 EVT VecVT = Vec.getValueType();
15187 assert(VecVT.isFixedLengthVector() && VecVT.isPow2VectorType() &&
15188 "Expected power-of-2 length vector");
15189
15190 EVT ElemVT = VecVT.getVectorElementType();
15191
15192 SDValue Result;
15193 unsigned NumElems = VecVT.getVectorNumElements();
15194
15195 // Special case for boolean reductions
15196 if (ElemVT == MVT::i1) {
15197 // Split large vectors into smaller ones
15198 if (NumElems > 16) {
15199 SDValue Lo, Hi;
15200 std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
15201 EVT HalfVT = Lo.getValueType();
15202 SDValue HalfVec = DAG.getNode(ScalarOpcode, DL, HalfVT, Lo, Hi);
15203 return getVectorBitwiseReduce(Opcode, HalfVec, VT, DL, DAG);
15204 }
15205
15206 // Vectors that are less than 64 bits get widened to neatly fit a 64 bit
15207 // register, so e.g. <4 x i1> gets lowered to <4 x i16>. Sign extending to
15208 // this element size leads to the best codegen, since e.g. setcc results
15209 // might need to be truncated otherwise.
15210 EVT ExtendedVT = MVT::getIntegerVT(std::max(64u / NumElems, 8u));
15211
15212 // any_ext doesn't work with umin/umax, so only use it for uadd.
15213 unsigned ExtendOp =
15214 ScalarOpcode == ISD::XOR ? ISD::ANY_EXTEND : ISD::SIGN_EXTEND;
15215 SDValue Extended = DAG.getNode(
15216 ExtendOp, DL, VecVT.changeVectorElementType(ExtendedVT), Vec);
15217 switch (ScalarOpcode) {
15218 case ISD::AND:
15219 Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
15220 break;
15221 case ISD::OR:
15222 Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
15223 break;
15224 case ISD::XOR:
15225 Result = DAG.getNode(ISD::VECREDUCE_ADD, DL, ExtendedVT, Extended);
15226 break;
15227 default:
15228 llvm_unreachable("Unexpected Opcode");
15229 }
15230
15231 Result = DAG.getAnyExtOrTrunc(Result, DL, MVT::i1);
15232 } else {
15233 // Iteratively split the vector in half and combine using the bitwise
15234 // operation until it fits in a 64 bit register.
15235 while (VecVT.getSizeInBits() > 64) {
15236 SDValue Lo, Hi;
15237 std::tie(Lo, Hi) = DAG.SplitVector(Vec, DL);
15238 VecVT = Lo.getValueType();
15239 NumElems = VecVT.getVectorNumElements();
15240 Vec = DAG.getNode(ScalarOpcode, DL, VecVT, Lo, Hi);
15241 }
15242
15243 EVT ScalarVT = EVT::getIntegerVT(*DAG.getContext(), VecVT.getSizeInBits());
15244
15245 // Do the remaining work on a scalar since it allows the code generator to
15246 // combine the shift and bitwise operation into one instruction and since
15247 // integer instructions can have higher throughput than vector instructions.
15248 SDValue Scalar = DAG.getBitcast(ScalarVT, Vec);
15249
15250 // Iteratively combine the lower and upper halves of the scalar using the
15251 // bitwise operation, halving the relevant region of the scalar in each
15252 // iteration, until the relevant region is just one element of the original
15253 // vector.
15254 for (unsigned Shift = NumElems / 2; Shift > 0; Shift /= 2) {
15255 SDValue ShiftAmount =
15256 DAG.getConstant(Shift * ElemVT.getSizeInBits(), DL, MVT::i64);
15257 SDValue Shifted =
15258 DAG.getNode(ISD::SRL, DL, ScalarVT, Scalar, ShiftAmount);
15259 Scalar = DAG.getNode(ScalarOpcode, DL, ScalarVT, Scalar, Shifted);
15260 }
15261
15262 Result = DAG.getAnyExtOrTrunc(Scalar, DL, ElemVT);
15263 }
15264
15265 return DAG.getAnyExtOrTrunc(Result, DL, VT);
15266 }
15267
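// LowerVECREDUCE: use SVE predicated reductions for scalable vectors (and for
// fixed-length vectors where SVE is preferred); otherwise lower to the NEON
// across-lane reduction nodes, with NEON bitwise reductions handled by
// getVectorBitwiseReduce above.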
15268 SDValue AArch64TargetLowering::LowerVECREDUCE(SDValue Op,
15269 SelectionDAG &DAG) const {
15270 SDValue Src = Op.getOperand(0);
15271
15272 // Try to lower fixed length reductions to SVE.
15273 EVT SrcVT = Src.getValueType();
15274 bool OverrideNEON = !Subtarget->isNeonAvailable() ||
15275 Op.getOpcode() == ISD::VECREDUCE_AND ||
15276 Op.getOpcode() == ISD::VECREDUCE_OR ||
15277 Op.getOpcode() == ISD::VECREDUCE_XOR ||
15278 Op.getOpcode() == ISD::VECREDUCE_FADD ||
15279 (Op.getOpcode() != ISD::VECREDUCE_ADD &&
15280 SrcVT.getVectorElementType() == MVT::i64);
15281 if (SrcVT.isScalableVector() ||
15282 useSVEForFixedLengthVectorVT(
15283 SrcVT, OverrideNEON && Subtarget->useSVEForFixedLengthVectors())) {
15284
15285 if (SrcVT.getVectorElementType() == MVT::i1)
15286 return LowerPredReductionToSVE(Op, DAG);
15287
15288 switch (Op.getOpcode()) {
15289 case ISD::VECREDUCE_ADD:
15290 return LowerReductionToSVE(AArch64ISD::UADDV_PRED, Op, DAG);
15291 case ISD::VECREDUCE_AND:
15292 return LowerReductionToSVE(AArch64ISD::ANDV_PRED, Op, DAG);
15293 case ISD::VECREDUCE_OR:
15294 return LowerReductionToSVE(AArch64ISD::ORV_PRED, Op, DAG);
15295 case ISD::VECREDUCE_SMAX:
15296 return LowerReductionToSVE(AArch64ISD::SMAXV_PRED, Op, DAG);
15297 case ISD::VECREDUCE_SMIN:
15298 return LowerReductionToSVE(AArch64ISD::SMINV_PRED, Op, DAG);
15299 case ISD::VECREDUCE_UMAX:
15300 return LowerReductionToSVE(AArch64ISD::UMAXV_PRED, Op, DAG);
15301 case ISD::VECREDUCE_UMIN:
15302 return LowerReductionToSVE(AArch64ISD::UMINV_PRED, Op, DAG);
15303 case ISD::VECREDUCE_XOR:
15304 return LowerReductionToSVE(AArch64ISD::EORV_PRED, Op, DAG);
15305 case ISD::VECREDUCE_FADD:
15306 return LowerReductionToSVE(AArch64ISD::FADDV_PRED, Op, DAG);
15307 case ISD::VECREDUCE_FMAX:
15308 return LowerReductionToSVE(AArch64ISD::FMAXNMV_PRED, Op, DAG);
15309 case ISD::VECREDUCE_FMIN:
15310 return LowerReductionToSVE(AArch64ISD::FMINNMV_PRED, Op, DAG);
15311 case ISD::VECREDUCE_FMAXIMUM:
15312 return LowerReductionToSVE(AArch64ISD::FMAXV_PRED, Op, DAG);
15313 case ISD::VECREDUCE_FMINIMUM:
15314 return LowerReductionToSVE(AArch64ISD::FMINV_PRED, Op, DAG);
15315 default:
15316 llvm_unreachable("Unhandled fixed length reduction");
15317 }
15318 }
15319
15320 // Lower NEON reductions.
15321 SDLoc dl(Op);
15322 switch (Op.getOpcode()) {
15323 case ISD::VECREDUCE_AND:
15324 case ISD::VECREDUCE_OR:
15325 case ISD::VECREDUCE_XOR:
15326 return getVectorBitwiseReduce(Op.getOpcode(), Op.getOperand(0),
15327 Op.getValueType(), dl, DAG);
15328 case ISD::VECREDUCE_ADD:
15329 return getReductionSDNode(AArch64ISD::UADDV, dl, Op, DAG);
15330 case ISD::VECREDUCE_SMAX:
15331 return getReductionSDNode(AArch64ISD::SMAXV, dl, Op, DAG);
15332 case ISD::VECREDUCE_SMIN:
15333 return getReductionSDNode(AArch64ISD::SMINV, dl, Op, DAG);
15334 case ISD::VECREDUCE_UMAX:
15335 return getReductionSDNode(AArch64ISD::UMAXV, dl, Op, DAG);
15336 case ISD::VECREDUCE_UMIN:
15337 return getReductionSDNode(AArch64ISD::UMINV, dl, Op, DAG);
15338 default:
15339 llvm_unreachable("Unhandled reduction");
15340 }
15341 }
15342
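// Rewrite atomic load-and in terms of the LSE atomic load-clear, i.e.
//   (atomic_load_and addr, x) --> (atomic_load_clr addr, (xor x, -1))
// so it can select to LDCLR* (or the corresponding outlined-atomics helper).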
15343 SDValue AArch64TargetLowering::LowerATOMIC_LOAD_AND(SDValue Op,
15344 SelectionDAG &DAG) const {
15345 auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
15346 // No point replacing if we don't have the relevant instruction/libcall anyway
15347 if (!Subtarget.hasLSE() && !Subtarget.outlineAtomics())
15348 return SDValue();
15349
15350 // LSE has an atomic load-clear instruction, but not a load-and.
15351 SDLoc dl(Op);
15352 MVT VT = Op.getSimpleValueType();
15353 assert(VT != MVT::i128 && "Handled elsewhere, code replicated.");
15354 SDValue RHS = Op.getOperand(2);
15355 AtomicSDNode *AN = cast<AtomicSDNode>(Op.getNode());
15356 RHS = DAG.getNode(ISD::XOR, dl, VT, DAG.getConstant(-1ULL, dl, VT), RHS);
15357 return DAG.getAtomic(ISD::ATOMIC_LOAD_CLR, dl, AN->getMemoryVT(),
15358 Op.getOperand(0), Op.getOperand(1), RHS,
15359 AN->getMemOperand());
15360 }
15361
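// Windows dynamic stack allocation: with "no-stack-arg-probe" we simply drop
// SP; otherwise the requested size, in units of 16 bytes and passed in X15,
// is given to the chkstk helper so the guard page is probed before SP moves.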
15362 SDValue
15363 AArch64TargetLowering::LowerWindowsDYNAMIC_STACKALLOC(SDValue Op,
15364 SelectionDAG &DAG) const {
15365
15366 SDLoc dl(Op);
15367 // Get the inputs.
15368 SDNode *Node = Op.getNode();
15369 SDValue Chain = Op.getOperand(0);
15370 SDValue Size = Op.getOperand(1);
15371 MaybeAlign Align =
15372 cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
15373 EVT VT = Node->getValueType(0);
15374
15375 if (DAG.getMachineFunction().getFunction().hasFnAttribute(
15376 "no-stack-arg-probe")) {
15377 SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15378 Chain = SP.getValue(1);
15379 SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15380 if (Align)
15381 SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15382 DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15383 Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
15384 SDValue Ops[2] = {SP, Chain};
15385 return DAG.getMergeValues(Ops, dl);
15386 }
15387
15388 Chain = DAG.getCALLSEQ_START(Chain, 0, 0, dl);
15389
15390 EVT PtrVT = getPointerTy(DAG.getDataLayout());
15391 SDValue Callee = DAG.getTargetExternalSymbol(Subtarget->getChkStkName(),
15392 PtrVT, 0);
15393
15394 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
15395 const uint32_t *Mask = TRI->getWindowsStackProbePreservedMask();
15396 if (Subtarget->hasCustomCallingConv())
15397 TRI->UpdateCustomCallPreservedMask(DAG.getMachineFunction(), &Mask);
15398
15399 Size = DAG.getNode(ISD::SRL, dl, MVT::i64, Size,
15400 DAG.getConstant(4, dl, MVT::i64));
15401 Chain = DAG.getCopyToReg(Chain, dl, AArch64::X15, Size, SDValue());
15402 Chain =
15403 DAG.getNode(AArch64ISD::CALL, dl, DAG.getVTList(MVT::Other, MVT::Glue),
15404 Chain, Callee, DAG.getRegister(AArch64::X15, MVT::i64),
15405 DAG.getRegisterMask(Mask), Chain.getValue(1));
15406 // To match the actual intent better, we should read the output from X15 here
15407 // again (instead of potentially spilling it to the stack), but rereading Size
15408 // from X15 here doesn't work at -O0, since it thinks that X15 is undefined
15409 // here.
15410
15411 Size = DAG.getNode(ISD::SHL, dl, MVT::i64, Size,
15412 DAG.getConstant(4, dl, MVT::i64));
15413
15414 SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15415 Chain = SP.getValue(1);
15416 SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15417 if (Align)
15418 SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15419 DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15420 Chain = DAG.getCopyToReg(Chain, dl, AArch64::SP, SP);
15421
15422 Chain = DAG.getCALLSEQ_END(Chain, 0, 0, SDValue(), dl);
15423
15424 SDValue Ops[2] = {SP, Chain};
15425 return DAG.getMergeValues(Ops, dl);
15426 }
15427
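// Inline stack probing: compute the new SP value in a GPR, then emit a
// PROBED_ALLOCA pseudo that is later expanded into a loop probing each page
// before SP is finally lowered to the new value.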
15428 SDValue
15429 AArch64TargetLowering::LowerInlineDYNAMIC_STACKALLOC(SDValue Op,
15430 SelectionDAG &DAG) const {
15431 // Get the inputs.
15432 SDNode *Node = Op.getNode();
15433 SDValue Chain = Op.getOperand(0);
15434 SDValue Size = Op.getOperand(1);
15435
15436 MaybeAlign Align =
15437 cast<ConstantSDNode>(Op.getOperand(2))->getMaybeAlignValue();
15438 SDLoc dl(Op);
15439 EVT VT = Node->getValueType(0);
15440
15441 // Construct the new SP value in a GPR.
15442 SDValue SP = DAG.getCopyFromReg(Chain, dl, AArch64::SP, MVT::i64);
15443 Chain = SP.getValue(1);
15444 SP = DAG.getNode(ISD::SUB, dl, MVT::i64, SP, Size);
15445 if (Align)
15446 SP = DAG.getNode(ISD::AND, dl, VT, SP.getValue(0),
15447 DAG.getConstant(-(uint64_t)Align->value(), dl, VT));
15448
15449 // Set the real SP to the new value with a probing loop.
15450 Chain = DAG.getNode(AArch64ISD::PROBED_ALLOCA, dl, MVT::Other, Chain, SP);
15451 SDValue Ops[2] = {SP, Chain};
15452 return DAG.getMergeValues(Ops, dl);
15453 }
15454
15455 SDValue
15456 AArch64TargetLowering::LowerDYNAMIC_STACKALLOC(SDValue Op,
15457 SelectionDAG &DAG) const {
15458 MachineFunction &MF = DAG.getMachineFunction();
15459
15460 if (Subtarget->isTargetWindows())
15461 return LowerWindowsDYNAMIC_STACKALLOC(Op, DAG);
15462 else if (hasInlineStackProbe(MF))
15463 return LowerInlineDYNAMIC_STACKALLOC(Op, DAG);
15464 else
15465 return SDValue();
15466 }
15467
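// LowerAVG: with SVE2 the averaging nodes (AVGFLOOR*/AVGCEIL*) can use the
// predicated opcode supplied by the caller in NewOp (assumed to be one of the
// halving-add / rounding-halving-add predicated nodes); otherwise return an
// empty SDValue so the generic expansion is used.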
15468 SDValue AArch64TargetLowering::LowerAVG(SDValue Op, SelectionDAG &DAG,
15469 unsigned NewOp) const {
15470 if (Subtarget->hasSVE2())
15471 return LowerToPredicatedOp(Op, DAG, NewOp);
15472
15473 // Default to expand.
15474 return SDValue();
15475 }
15476
15477 SDValue AArch64TargetLowering::LowerVSCALE(SDValue Op,
15478 SelectionDAG &DAG) const {
15479 EVT VT = Op.getValueType();
15480 assert(VT != MVT::i64 && "Expected illegal VSCALE node");
15481
15482 SDLoc DL(Op);
15483 APInt MulImm = Op.getConstantOperandAPInt(0);
15484 return DAG.getZExtOrTrunc(DAG.getVScale(DL, MVT::i64, MulImm.sext(64)), DL,
15485 VT);
15486 }
15487
15488 /// Set the IntrinsicInfo for the `aarch64_sve_st<N>` intrinsics.
15489 template <unsigned NumVecs>
15490 static bool
15491 setInfoSVEStN(const AArch64TargetLowering &TLI, const DataLayout &DL,
15492 AArch64TargetLowering::IntrinsicInfo &Info, const CallInst &CI) {
15493 Info.opc = ISD::INTRINSIC_VOID;
15494 // Retrieve EC from first vector argument.
15495 const EVT VT = TLI.getMemValueType(DL, CI.getArgOperand(0)->getType());
15496 ElementCount EC = VT.getVectorElementCount();
15497 #ifndef NDEBUG
15498 // Check the assumption that all input vectors are the same type.
15499 for (unsigned I = 0; I < NumVecs; ++I)
15500 assert(VT == TLI.getMemValueType(DL, CI.getArgOperand(I)->getType()) &&
15501 "Invalid type.");
15502 #endif
15503 // memVT is `NumVecs * VT`.
15504 Info.memVT = EVT::getVectorVT(CI.getType()->getContext(), VT.getScalarType(),
15505 EC * NumVecs);
15506 Info.ptrVal = CI.getArgOperand(CI.arg_size() - 1);
15507 Info.offset = 0;
15508 Info.align.reset();
15509 Info.flags = MachineMemOperand::MOStore;
15510 return true;
15511 }
15512
15513 /// getTgtMemIntrinsic - Represent NEON load and store intrinsics as
15514 /// MemIntrinsicNodes. The associated MachineMemOperands record the alignment
15515 /// specified in the intrinsic calls.
15516 bool AArch64TargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
15517 const CallInst &I,
15518 MachineFunction &MF,
15519 unsigned Intrinsic) const {
15520 auto &DL = I.getDataLayout();
15521 switch (Intrinsic) {
15522 case Intrinsic::aarch64_sve_st2:
15523 return setInfoSVEStN<2>(*this, DL, Info, I);
15524 case Intrinsic::aarch64_sve_st3:
15525 return setInfoSVEStN<3>(*this, DL, Info, I);
15526 case Intrinsic::aarch64_sve_st4:
15527 return setInfoSVEStN<4>(*this, DL, Info, I);
15528 case Intrinsic::aarch64_neon_ld2:
15529 case Intrinsic::aarch64_neon_ld3:
15530 case Intrinsic::aarch64_neon_ld4:
15531 case Intrinsic::aarch64_neon_ld1x2:
15532 case Intrinsic::aarch64_neon_ld1x3:
15533 case Intrinsic::aarch64_neon_ld1x4: {
15534 Info.opc = ISD::INTRINSIC_W_CHAIN;
15535 uint64_t NumElts = DL.getTypeSizeInBits(I.getType()) / 64;
15536 Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
15537 Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15538 Info.offset = 0;
15539 Info.align.reset();
15540 // volatile loads with NEON intrinsics not supported
15541 Info.flags = MachineMemOperand::MOLoad;
15542 return true;
15543 }
15544 case Intrinsic::aarch64_neon_ld2lane:
15545 case Intrinsic::aarch64_neon_ld3lane:
15546 case Intrinsic::aarch64_neon_ld4lane:
15547 case Intrinsic::aarch64_neon_ld2r:
15548 case Intrinsic::aarch64_neon_ld3r:
15549 case Intrinsic::aarch64_neon_ld4r: {
15550 Info.opc = ISD::INTRINSIC_W_CHAIN;
15551 // The ldN intrinsics return a struct whose elements all have the same vector type
15552 Type *RetTy = I.getType();
15553 auto *StructTy = cast<StructType>(RetTy);
15554 unsigned NumElts = StructTy->getNumElements();
15555 Type *VecTy = StructTy->getElementType(0);
15556 MVT EleVT = MVT::getVT(VecTy).getVectorElementType();
15557 Info.memVT = EVT::getVectorVT(I.getType()->getContext(), EleVT, NumElts);
15558 Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15559 Info.offset = 0;
15560 Info.align.reset();
15561 // volatile loads with NEON intrinsics not supported
15562 Info.flags = MachineMemOperand::MOLoad;
15563 return true;
15564 }
15565 case Intrinsic::aarch64_neon_st2:
15566 case Intrinsic::aarch64_neon_st3:
15567 case Intrinsic::aarch64_neon_st4:
15568 case Intrinsic::aarch64_neon_st1x2:
15569 case Intrinsic::aarch64_neon_st1x3:
15570 case Intrinsic::aarch64_neon_st1x4: {
15571 Info.opc = ISD::INTRINSIC_VOID;
15572 unsigned NumElts = 0;
15573 for (const Value *Arg : I.args()) {
15574 Type *ArgTy = Arg->getType();
15575 if (!ArgTy->isVectorTy())
15576 break;
15577 NumElts += DL.getTypeSizeInBits(ArgTy) / 64;
15578 }
15579 Info.memVT = EVT::getVectorVT(I.getType()->getContext(), MVT::i64, NumElts);
15580 Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15581 Info.offset = 0;
15582 Info.align.reset();
15583 // volatile stores with NEON intrinsics not supported
15584 Info.flags = MachineMemOperand::MOStore;
15585 return true;
15586 }
15587 case Intrinsic::aarch64_neon_st2lane:
15588 case Intrinsic::aarch64_neon_st3lane:
15589 case Intrinsic::aarch64_neon_st4lane: {
15590 Info.opc = ISD::INTRINSIC_VOID;
15591 unsigned NumElts = 0;
15592 // All of the vector arguments have the same type
15593 Type *VecTy = I.getArgOperand(0)->getType();
15594 MVT EleVT = MVT::getVT(VecTy).getVectorElementType();
15595
15596 for (const Value *Arg : I.args()) {
15597 Type *ArgTy = Arg->getType();
15598 if (!ArgTy->isVectorTy())
15599 break;
15600 NumElts += 1;
15601 }
15602
15603 Info.memVT = EVT::getVectorVT(I.getType()->getContext(), EleVT, NumElts);
15604 Info.ptrVal = I.getArgOperand(I.arg_size() - 1);
15605 Info.offset = 0;
15606 Info.align.reset();
15607 // volatile stores with NEON intrinsics not supported
15608 Info.flags = MachineMemOperand::MOStore;
15609 return true;
15610 }
15611 case Intrinsic::aarch64_ldaxr:
15612 case Intrinsic::aarch64_ldxr: {
15613 Type *ValTy = I.getParamElementType(0);
15614 Info.opc = ISD::INTRINSIC_W_CHAIN;
15615 Info.memVT = MVT::getVT(ValTy);
15616 Info.ptrVal = I.getArgOperand(0);
15617 Info.offset = 0;
15618 Info.align = DL.getABITypeAlign(ValTy);
15619 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
15620 return true;
15621 }
15622 case Intrinsic::aarch64_stlxr:
15623 case Intrinsic::aarch64_stxr: {
15624 Type *ValTy = I.getParamElementType(1);
15625 Info.opc = ISD::INTRINSIC_W_CHAIN;
15626 Info.memVT = MVT::getVT(ValTy);
15627 Info.ptrVal = I.getArgOperand(1);
15628 Info.offset = 0;
15629 Info.align = DL.getABITypeAlign(ValTy);
15630 Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
15631 return true;
15632 }
15633 case Intrinsic::aarch64_ldaxp:
15634 case Intrinsic::aarch64_ldxp:
15635 Info.opc = ISD::INTRINSIC_W_CHAIN;
15636 Info.memVT = MVT::i128;
15637 Info.ptrVal = I.getArgOperand(0);
15638 Info.offset = 0;
15639 Info.align = Align(16);
15640 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MOVolatile;
15641 return true;
15642 case Intrinsic::aarch64_stlxp:
15643 case Intrinsic::aarch64_stxp:
15644 Info.opc = ISD::INTRINSIC_W_CHAIN;
15645 Info.memVT = MVT::i128;
15646 Info.ptrVal = I.getArgOperand(2);
15647 Info.offset = 0;
15648 Info.align = Align(16);
15649 Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MOVolatile;
15650 return true;
15651 case Intrinsic::aarch64_sve_ldnt1: {
15652 Type *ElTy = cast<VectorType>(I.getType())->getElementType();
15653 Info.opc = ISD::INTRINSIC_W_CHAIN;
15654 Info.memVT = MVT::getVT(I.getType());
15655 Info.ptrVal = I.getArgOperand(1);
15656 Info.offset = 0;
15657 Info.align = DL.getABITypeAlign(ElTy);
15658 Info.flags = MachineMemOperand::MOLoad | MachineMemOperand::MONonTemporal;
15659 return true;
15660 }
15661 case Intrinsic::aarch64_sve_stnt1: {
15662 Type *ElTy =
15663 cast<VectorType>(I.getArgOperand(0)->getType())->getElementType();
15664 Info.opc = ISD::INTRINSIC_W_CHAIN;
15665 Info.memVT = MVT::getVT(I.getOperand(0)->getType());
15666 Info.ptrVal = I.getArgOperand(2);
15667 Info.offset = 0;
15668 Info.align = DL.getABITypeAlign(ElTy);
15669 Info.flags = MachineMemOperand::MOStore | MachineMemOperand::MONonTemporal;
15670 return true;
15671 }
15672 case Intrinsic::aarch64_mops_memset_tag: {
15673 Value *Dst = I.getArgOperand(0);
15674 Value *Val = I.getArgOperand(1);
15675 Info.opc = ISD::INTRINSIC_W_CHAIN;
15676 Info.memVT = MVT::getVT(Val->getType());
15677 Info.ptrVal = Dst;
15678 Info.offset = 0;
15679 Info.align = I.getParamAlign(0).valueOrOne();
15680 Info.flags = MachineMemOperand::MOStore;
15681 // The size of the memory being operated on is unknown at this point
15682 Info.size = MemoryLocation::UnknownSize;
15683 return true;
15684 }
15685 default:
15686 break;
15687 }
15688
15689 return false;
15690 }
15691
15692 bool AArch64TargetLowering::shouldReduceLoadWidth(SDNode *Load,
15693 ISD::LoadExtType ExtTy,
15694 EVT NewVT) const {
15695 // TODO: This may be worth removing. Check regression tests for diffs.
15696 if (!TargetLoweringBase::shouldReduceLoadWidth(Load, ExtTy, NewVT))
15697 return false;
15698
15699 // If we're reducing the load width in order to avoid having to use an extra
15700 // instruction to do extension then it's probably a good idea.
15701 if (ExtTy != ISD::NON_EXTLOAD)
15702 return true;
15703 // Don't reduce load width if it would prevent us from combining a shift into
15704 // the offset.
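// For example, an i64 load from (add x, (shl y, 3)) can use the scaled
// addressing mode "ldr xN, [x, y, lsl #3]"; narrowing it to a smaller load
// would require a different shift amount, so the shift could no longer fold.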
15705 MemSDNode *Mem = dyn_cast<MemSDNode>(Load);
15706 assert(Mem);
15707 const SDValue &Base = Mem->getBasePtr();
15708 if (Base.getOpcode() == ISD::ADD &&
15709 Base.getOperand(1).getOpcode() == ISD::SHL &&
15710 Base.getOperand(1).hasOneUse() &&
15711 Base.getOperand(1).getOperand(1).getOpcode() == ISD::Constant) {
15712 // It's unknown whether a scalable vector has a power-of-2 bitwidth.
15713 if (Mem->getMemoryVT().isScalableVector())
15714 return false;
15715 // The shift can be combined if it matches the size of the value being
15716 // loaded (and so reducing the width would make it not match).
15717 uint64_t ShiftAmount = Base.getOperand(1).getConstantOperandVal(1);
15718 uint64_t LoadBytes = Mem->getMemoryVT().getSizeInBits()/8;
15719 if (ShiftAmount == Log2_32(LoadBytes))
15720 return false;
15721 }
15722 // We have no reason to disallow reducing the load width, so allow it.
15723 return true;
15724 }
15725
15726 // Treat a sext_inreg(extract(..)) as free if it has multiple uses.
15727 bool AArch64TargetLowering::shouldRemoveRedundantExtend(SDValue Extend) const {
15728 EVT VT = Extend.getValueType();
15729 if ((VT == MVT::i64 || VT == MVT::i32) && Extend->use_size()) {
15730 SDValue Extract = Extend.getOperand(0);
15731 if (Extract.getOpcode() == ISD::ANY_EXTEND && Extract.hasOneUse())
15732 Extract = Extract.getOperand(0);
15733 if (Extract.getOpcode() == ISD::EXTRACT_VECTOR_ELT && Extract.hasOneUse()) {
15734 EVT VecVT = Extract.getOperand(0).getValueType();
15735 if (VecVT.getScalarType() == MVT::i8 || VecVT.getScalarType() == MVT::i16)
15736 return false;
15737 }
15738 }
15739 return true;
15740 }
15741
15742 // Truncations from a 64-bit GPR to a 32-bit GPR are free.
15743 bool AArch64TargetLowering::isTruncateFree(Type *Ty1, Type *Ty2) const {
15744 if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
15745 return false;
15746 uint64_t NumBits1 = Ty1->getPrimitiveSizeInBits().getFixedValue();
15747 uint64_t NumBits2 = Ty2->getPrimitiveSizeInBits().getFixedValue();
15748 return NumBits1 > NumBits2;
15749 }
15750 bool AArch64TargetLowering::isTruncateFree(EVT VT1, EVT VT2) const {
15751 if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
15752 return false;
15753 uint64_t NumBits1 = VT1.getFixedSizeInBits();
15754 uint64_t NumBits2 = VT2.getFixedSizeInBits();
15755 return NumBits1 > NumBits2;
15756 }
15757
15758 /// Check if it is profitable to hoist instruction in then/else to if.
15759 /// Not profitable if I and its user can form an FMA instruction
15760 /// because we prefer FMSUB/FMADD.
15761 bool AArch64TargetLowering::isProfitableToHoist(Instruction *I) const {
15762 if (I->getOpcode() != Instruction::FMul)
15763 return true;
15764
15765 if (!I->hasOneUse())
15766 return true;
15767
15768 Instruction *User = I->user_back();
15769
15770 if (!(User->getOpcode() == Instruction::FSub ||
15771 User->getOpcode() == Instruction::FAdd))
15772 return true;
15773
15774 const TargetOptions &Options = getTargetMachine().Options;
15775 const Function *F = I->getFunction();
15776 const DataLayout &DL = F->getDataLayout();
15777 Type *Ty = User->getOperand(0)->getType();
15778
15779 return !(isFMAFasterThanFMulAndFAdd(*F, Ty) &&
15780 isOperationLegalOrCustom(ISD::FMA, getValueType(DL, Ty)) &&
15781 (Options.AllowFPOpFusion == FPOpFusion::Fast ||
15782 Options.UnsafeFPMath));
15783 }
15784
15785 // All 32-bit GPR operations implicitly zero the high-half of the corresponding
15786 // 64-bit GPR.
15787 bool AArch64TargetLowering::isZExtFree(Type *Ty1, Type *Ty2) const {
15788 if (!Ty1->isIntegerTy() || !Ty2->isIntegerTy())
15789 return false;
15790 unsigned NumBits1 = Ty1->getPrimitiveSizeInBits();
15791 unsigned NumBits2 = Ty2->getPrimitiveSizeInBits();
15792 return NumBits1 == 32 && NumBits2 == 64;
15793 }
15794 bool AArch64TargetLowering::isZExtFree(EVT VT1, EVT VT2) const {
15795 if (VT1.isVector() || VT2.isVector() || !VT1.isInteger() || !VT2.isInteger())
15796 return false;
15797 unsigned NumBits1 = VT1.getSizeInBits();
15798 unsigned NumBits2 = VT2.getSizeInBits();
15799 return NumBits1 == 32 && NumBits2 == 64;
15800 }
15801
15802 bool AArch64TargetLowering::isZExtFree(SDValue Val, EVT VT2) const {
15803 EVT VT1 = Val.getValueType();
15804 if (isZExtFree(VT1, VT2)) {
15805 return true;
15806 }
15807
15808 if (Val.getOpcode() != ISD::LOAD)
15809 return false;
15810
15811 // 8-, 16-, and 32-bit integer loads all implicitly zero-extend.
15812 return (VT1.isSimple() && !VT1.isVector() && VT1.isInteger() &&
15813 VT2.isSimple() && !VT2.isVector() && VT2.isInteger() &&
15814 VT1.getSizeInBits() <= 32);
15815 }
15816
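// A scalar zext/sext is considered free when it can be folded into the
// extended-register form of add/sub/cmp or into a scaled addressing mode,
// e.g. "add x0, x1, w2, sxtw #2" or "ldr x0, [x1, w2, uxtw #3]" (examples,
// not an exhaustive list).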
15817 bool AArch64TargetLowering::isExtFreeImpl(const Instruction *Ext) const {
15818 if (isa<FPExtInst>(Ext))
15819 return false;
15820
15821 // Vector types are not free.
15822 if (Ext->getType()->isVectorTy())
15823 return false;
15824
15825 for (const Use &U : Ext->uses()) {
15826 // The extension is free if we can fold it with a left shift in an
15827 // addressing mode or an arithmetic operation: add, sub, and cmp.
15828
15829 // Is there a shift?
15830 const Instruction *Instr = cast<Instruction>(U.getUser());
15831
15832 // Is this a constant shift?
15833 switch (Instr->getOpcode()) {
15834 case Instruction::Shl:
15835 if (!isa<ConstantInt>(Instr->getOperand(1)))
15836 return false;
15837 break;
15838 case Instruction::GetElementPtr: {
15839 gep_type_iterator GTI = gep_type_begin(Instr);
15840 auto &DL = Ext->getDataLayout();
15841 std::advance(GTI, U.getOperandNo()-1);
15842 Type *IdxTy = GTI.getIndexedType();
15843 // This extension will end up with a shift because of the scaling factor.
15844 // 8-bit sized types have a scaling factor of 1, thus a shift amount of 0.
15845 // Get the shift amount based on the scaling factor:
15846 // log2(sizeof(IdxTy)) - log2(8).
15847 if (IdxTy->isScalableTy())
15848 return false;
15849 uint64_t ShiftAmt =
15850 llvm::countr_zero(DL.getTypeStoreSizeInBits(IdxTy).getFixedValue()) -
15851 3;
15852 // Is the constant foldable in the shift of the addressing mode?
15853 // I.e., shift amount is between 1 and 4 inclusive.
15854 if (ShiftAmt == 0 || ShiftAmt > 4)
15855 return false;
15856 break;
15857 }
15858 case Instruction::Trunc:
15859 // Check if this is a noop.
15860 // trunc(sext ty1 to ty2) to ty1.
15861 if (Instr->getType() == Ext->getOperand(0)->getType())
15862 continue;
15863 [[fallthrough]];
15864 default:
15865 return false;
15866 }
15867
15868 // At this point we can use the bfm family, so this extension is free
15869 // for that use.
15870 }
15871 return true;
15872 }
15873
15874 static bool isSplatShuffle(Value *V) {
15875 if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
15876 return all_equal(Shuf->getShuffleMask());
15877 return false;
15878 }
15879
15880 /// Check if both Op1 and Op2 are shufflevector extracts of either the lower
15881 /// or upper half of the vector elements.
15882 static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
15883 bool AllowSplat = false) {
15884 auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
15885 auto *FullTy = FullV->getType();
15886 auto *HalfTy = HalfV->getType();
15887 return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
15888 2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
15889 };
15890
15891 auto extractHalf = [](Value *FullV, Value *HalfV) {
15892 auto *FullVT = cast<FixedVectorType>(FullV->getType());
15893 auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
15894 return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
15895 };
15896
15897 ArrayRef<int> M1, M2;
15898 Value *S1Op1 = nullptr, *S2Op1 = nullptr;
15899 if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
15900 !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
15901 return false;
15902
15903 // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that
15904 // it is not checked as an extract below.
15905 if (AllowSplat && isSplatShuffle(Op1))
15906 S1Op1 = nullptr;
15907 if (AllowSplat && isSplatShuffle(Op2))
15908 S2Op1 = nullptr;
15909
15910 // Check that the operands are half as wide as the result and we extract
15911 // half of the elements of the input vectors.
15912 if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
15913 (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
15914 return false;
15915
15916 // Check the mask extracts either the lower or upper half of vector
15917 // elements.
15918 int M1Start = 0;
15919 int M2Start = 0;
15920 int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
15921 if ((S1Op1 &&
15922 !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
15923 (S2Op1 &&
15924 !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
15925 return false;
15926
15927 if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
15928 (M2Start != 0 && M2Start != (NumElements / 2)))
15929 return false;
15930 if (S1Op1 && S2Op1 && M1Start != M2Start)
15931 return false;
15932
15933 return true;
15934 }
15935
15936 /// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
15937 /// of the vector elements.
15938 static bool areExtractExts(Value *Ext1, Value *Ext2) {
15939 auto areExtDoubled = [](Instruction *Ext) {
15940 return Ext->getType()->getScalarSizeInBits() ==
15941 2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
15942 };
15943
15944 if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
15945 !match(Ext2, m_ZExtOrSExt(m_Value())) ||
15946 !areExtDoubled(cast<Instruction>(Ext1)) ||
15947 !areExtDoubled(cast<Instruction>(Ext2)))
15948 return false;
15949
15950 return true;
15951 }
15952
15953 /// Check if Op could be used with vmull_high_p64 intrinsic.
15954 static bool isOperandOfVmullHighP64(Value *Op) {
15955 Value *VectorOperand = nullptr;
15956 ConstantInt *ElementIndex = nullptr;
15957 return match(Op, m_ExtractElt(m_Value(VectorOperand),
15958 m_ConstantInt(ElementIndex))) &&
15959 ElementIndex->getValue() == 1 &&
15960 isa<FixedVectorType>(VectorOperand->getType()) &&
15961 cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
15962 }
15963
15964 /// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
15965 static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
15966 return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
15967 }
15968
15969 static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
15970 // Restrict ourselves to the form CodeGenPrepare typically constructs.
15971 auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
15972 if (!GEP || GEP->getNumOperands() != 2)
15973 return false;
15974
15975 Value *Base = GEP->getOperand(0);
15976 Value *Offsets = GEP->getOperand(1);
15977
15978 // We only care about scalar_base+vector_offsets.
15979 if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
15980 return false;
15981
15982 // Sink extends that would allow us to use 32-bit offset vectors.
15983 if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
15984 auto *OffsetsInst = cast<Instruction>(Offsets);
15985 if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
15986 OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
15987 Ops.push_back(&GEP->getOperandUse(1));
15988 }
15989
15990 // Sink the GEP.
15991 return true;
15992 }
15993
15994 /// We want to sink the following cases:
15995 /// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
15996 /// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
15997 static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
15998 if (match(Op, m_VScale()))
15999 return true;
16000 if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
16001 match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
16002 Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
16003 return true;
16004 }
16005 if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
16006 match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
16007 Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
16008 Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
16009 Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
16010 return true;
16011 }
16012 return false;
16013 }
16014
16015 /// Check if sinking \p I's operands to I's basic block is profitable, because
16016 /// the operands can be folded into a target instruction, e.g.
16017 /// shufflevector extracts and/or sext/zext can be folded into (u,s)subl(2).
16018 bool AArch64TargetLowering::shouldSinkOperands(
16019 Instruction *I, SmallVectorImpl<Use *> &Ops) const {
16020 if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
16021 switch (II->getIntrinsicID()) {
16022 case Intrinsic::aarch64_neon_smull:
16023 case Intrinsic::aarch64_neon_umull:
16024 if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
16025 /*AllowSplat=*/true)) {
16026 Ops.push_back(&II->getOperandUse(0));
16027 Ops.push_back(&II->getOperandUse(1));
16028 return true;
16029 }
16030 [[fallthrough]];
16031
16032 case Intrinsic::fma:
16033 if (isa<VectorType>(I->getType()) &&
16034 cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
16035 !Subtarget->hasFullFP16())
16036 return false;
16037 [[fallthrough]];
16038 case Intrinsic::aarch64_neon_sqdmull:
16039 case Intrinsic::aarch64_neon_sqdmulh:
16040 case Intrinsic::aarch64_neon_sqrdmulh:
16041 // Sink splats for index lane variants
16042 if (isSplatShuffle(II->getOperand(0)))
16043 Ops.push_back(&II->getOperandUse(0));
16044 if (isSplatShuffle(II->getOperand(1)))
16045 Ops.push_back(&II->getOperandUse(1));
16046 return !Ops.empty();
16047 case Intrinsic::aarch64_neon_fmlal:
16048 case Intrinsic::aarch64_neon_fmlal2:
16049 case Intrinsic::aarch64_neon_fmlsl:
16050 case Intrinsic::aarch64_neon_fmlsl2:
16051 // Sink splats for index lane variants
16052 if (isSplatShuffle(II->getOperand(1)))
16053 Ops.push_back(&II->getOperandUse(1));
16054 if (isSplatShuffle(II->getOperand(2)))
16055 Ops.push_back(&II->getOperandUse(2));
16056 return !Ops.empty();
16057 case Intrinsic::aarch64_sve_ptest_first:
16058 case Intrinsic::aarch64_sve_ptest_last:
16059 if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
16060 if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
16061 Ops.push_back(&II->getOperandUse(0));
16062 return !Ops.empty();
16063 case Intrinsic::aarch64_sme_write_horiz:
16064 case Intrinsic::aarch64_sme_write_vert:
16065 case Intrinsic::aarch64_sme_writeq_horiz:
16066 case Intrinsic::aarch64_sme_writeq_vert: {
16067 auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
16068 if (!Idx || Idx->getOpcode() != Instruction::Add)
16069 return false;
16070 Ops.push_back(&II->getOperandUse(1));
16071 return true;
16072 }
16073 case Intrinsic::aarch64_sme_read_horiz:
16074 case Intrinsic::aarch64_sme_read_vert:
16075 case Intrinsic::aarch64_sme_readq_horiz:
16076 case Intrinsic::aarch64_sme_readq_vert:
16077 case Intrinsic::aarch64_sme_ld1b_vert:
16078 case Intrinsic::aarch64_sme_ld1h_vert:
16079 case Intrinsic::aarch64_sme_ld1w_vert:
16080 case Intrinsic::aarch64_sme_ld1d_vert:
16081 case Intrinsic::aarch64_sme_ld1q_vert:
16082 case Intrinsic::aarch64_sme_st1b_vert:
16083 case Intrinsic::aarch64_sme_st1h_vert:
16084 case Intrinsic::aarch64_sme_st1w_vert:
16085 case Intrinsic::aarch64_sme_st1d_vert:
16086 case Intrinsic::aarch64_sme_st1q_vert:
16087 case Intrinsic::aarch64_sme_ld1b_horiz:
16088 case Intrinsic::aarch64_sme_ld1h_horiz:
16089 case Intrinsic::aarch64_sme_ld1w_horiz:
16090 case Intrinsic::aarch64_sme_ld1d_horiz:
16091 case Intrinsic::aarch64_sme_ld1q_horiz:
16092 case Intrinsic::aarch64_sme_st1b_horiz:
16093 case Intrinsic::aarch64_sme_st1h_horiz:
16094 case Intrinsic::aarch64_sme_st1w_horiz:
16095 case Intrinsic::aarch64_sme_st1d_horiz:
16096 case Intrinsic::aarch64_sme_st1q_horiz: {
16097 auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
16098 if (!Idx || Idx->getOpcode() != Instruction::Add)
16099 return false;
16100 Ops.push_back(&II->getOperandUse(3));
16101 return true;
16102 }
16103 case Intrinsic::aarch64_neon_pmull:
16104 if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
16105 return false;
16106 Ops.push_back(&II->getOperandUse(0));
16107 Ops.push_back(&II->getOperandUse(1));
16108 return true;
16109 case Intrinsic::aarch64_neon_pmull64:
16110 if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
16111 II->getArgOperand(1)))
16112 return false;
16113 Ops.push_back(&II->getArgOperandUse(0));
16114 Ops.push_back(&II->getArgOperandUse(1));
16115 return true;
16116 case Intrinsic::masked_gather:
16117 if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
16118 return false;
16119 Ops.push_back(&II->getArgOperandUse(0));
16120 return true;
16121 case Intrinsic::masked_scatter:
16122 if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
16123 return false;
16124 Ops.push_back(&II->getArgOperandUse(1));
16125 return true;
16126 default:
16127 return false;
16128 }
16129 }
16130
16131 // Sink vscales closer to uses for better isel
16132 switch (I->getOpcode()) {
16133 case Instruction::GetElementPtr:
16134 case Instruction::Add:
16135 case Instruction::Sub:
16136 for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
16137 if (shouldSinkVScale(I->getOperand(Op), Ops)) {
16138 Ops.push_back(&I->getOperandUse(Op));
16139 return true;
16140 }
16141 }
16142 break;
16143 default:
16144 break;
16145 }
16146
16147 if (!I->getType()->isVectorTy())
16148 return false;
16149
16150 switch (I->getOpcode()) {
16151 case Instruction::Sub:
16152 case Instruction::Add: {
16153 if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
16154 return false;
16155
16156 // If the exts' operands extract either the lower or upper elements, we
16157 // can sink them too.
16158 auto Ext1 = cast<Instruction>(I->getOperand(0));
16159 auto Ext2 = cast<Instruction>(I->getOperand(1));
16160 if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
16161 Ops.push_back(&Ext1->getOperandUse(0));
16162 Ops.push_back(&Ext2->getOperandUse(0));
16163 }
16164
16165 Ops.push_back(&I->getOperandUse(0));
16166 Ops.push_back(&I->getOperandUse(1));
16167
16168 return true;
16169 }
16170 case Instruction::Or: {
16171 // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
16172 // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
16173 if (Subtarget->hasNEON()) {
16174 Instruction *OtherAnd, *IA, *IB;
16175 Value *MaskValue;
16176 // MainAnd refers to the And instruction that has 'Not' as one of its operands
16177 if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
16178 m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
16179 m_Instruction(IA)))))) {
16180 if (match(OtherAnd,
16181 m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
16182 Instruction *MainAnd = I->getOperand(0) == OtherAnd
16183 ? cast<Instruction>(I->getOperand(1))
16184 : cast<Instruction>(I->getOperand(0));
16185
16186 // Both Ands should be in same basic block as Or
16187 if (I->getParent() != MainAnd->getParent() ||
16188 I->getParent() != OtherAnd->getParent())
16189 return false;
16190
16191 // Non-mask operands of both Ands should also be in same basic block
16192 if (I->getParent() != IA->getParent() ||
16193 I->getParent() != IB->getParent())
16194 return false;
16195
16196 Ops.push_back(&MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
16197 Ops.push_back(&I->getOperandUse(0));
16198 Ops.push_back(&I->getOperandUse(1));
16199
16200 return true;
16201 }
16202 }
16203 }
16204
16205 return false;
16206 }
16207 case Instruction::Mul: {
16208 int NumZExts = 0, NumSExts = 0;
16209 for (auto &Op : I->operands()) {
16210 // Make sure we are not already sinking this operand
16211 if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
16212 continue;
16213
16214 if (match(&Op, m_SExt(m_Value()))) {
16215 NumSExts++;
16216 continue;
16217 } else if (match(&Op, m_ZExt(m_Value()))) {
16218 NumZExts++;
16219 continue;
16220 }
16221
16222 ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
16223
16224 // If the Shuffle is a splat and the operand is a zext/sext, sinking the
16225 // operand and the s/zext can help create indexed s/umull. This is
16226 // especially useful to prevent i64 mul being scalarized.
16227 if (Shuffle && isSplatShuffle(Shuffle) &&
16228 match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
16229 Ops.push_back(&Shuffle->getOperandUse(0));
16230 Ops.push_back(&Op);
16231 if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
16232 NumSExts++;
16233 else
16234 NumZExts++;
16235 continue;
16236 }
16237
16238 if (!Shuffle)
16239 continue;
16240
16241 Value *ShuffleOperand = Shuffle->getOperand(0);
16242 InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
16243 if (!Insert)
16244 continue;
16245
16246 Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
16247 if (!OperandInstr)
16248 continue;
16249
16250 ConstantInt *ElementConstant =
16251 dyn_cast<ConstantInt>(Insert->getOperand(2));
16252 // Check that the insertelement is inserting into element 0
16253 if (!ElementConstant || !ElementConstant->isZero())
16254 continue;
16255
16256 unsigned Opcode = OperandInstr->getOpcode();
16257 if (Opcode == Instruction::SExt)
16258 NumSExts++;
16259 else if (Opcode == Instruction::ZExt)
16260 NumZExts++;
16261 else {
16262 // If we find that the top bits are known 0, then we can sink and allow
16263 // the backend to generate a umull.
16264 unsigned Bitwidth = I->getType()->getScalarSizeInBits();
16265 APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
16266 const DataLayout &DL = I->getDataLayout();
16267 if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
16268 continue;
16269 NumZExts++;
16270 }
16271
16272 Ops.push_back(&Shuffle->getOperandUse(0));
16273 Ops.push_back(&Op);
16274 }
16275
16276 // It is only profitable to sink if we found two extends of the same type.
16277 return !Ops.empty() && (NumSExts == 2 || NumZExts == 2);
16278 }
16279 default:
16280 return false;
16281 }
16282 return false;
16283 }
16284
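// Build the shuffle mask used for TBL-style widening: source element I lands
// at position I * Factor (offset by Factor - 1 on big-endian) and every other
// position reads lane NumElts, i.e. the zero element inserted as the second
// shuffle operand. For example, SrcWidth=8, DstWidth=32, NumElts=4 on
// little-endian gives the mask {0,4,4,4, 1,4,4,4, 2,4,4,4, 3,4,4,4}.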
16285 static bool createTblShuffleMask(unsigned SrcWidth, unsigned DstWidth,
16286 unsigned NumElts, bool IsLittleEndian,
16287 SmallVectorImpl<int> &Mask) {
16288 if (DstWidth % 8 != 0 || DstWidth <= 16 || DstWidth >= 64)
16289 return false;
16290
16291 assert(DstWidth % SrcWidth == 0 &&
16292 "TBL lowering is not supported for a conversion instruction with this "
16293 "source and destination element type.");
16294
16295 unsigned Factor = DstWidth / SrcWidth;
16296 unsigned MaskLen = NumElts * Factor;
16297
16298 Mask.clear();
16299 Mask.resize(MaskLen, NumElts);
16300
16301 unsigned SrcIndex = 0;
16302 for (unsigned I = IsLittleEndian ? 0 : Factor - 1; I < MaskLen; I += Factor)
16303 Mask[I] = SrcIndex++;
16304
16305 return true;
16306 }
16307
16308 static Value *createTblShuffleForZExt(IRBuilderBase &Builder, Value *Op,
16309 FixedVectorType *ZExtTy,
16310 FixedVectorType *DstTy,
16311 bool IsLittleEndian) {
16312 auto *SrcTy = cast<FixedVectorType>(Op->getType());
16313 unsigned NumElts = SrcTy->getNumElements();
16314 auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16315 auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16316
16317 SmallVector<int> Mask;
16318 if (!createTblShuffleMask(SrcWidth, DstWidth, NumElts, IsLittleEndian, Mask))
16319 return nullptr;
16320
16321 auto *FirstEltZero = Builder.CreateInsertElement(
16322 PoisonValue::get(SrcTy), Builder.getInt8(0), uint64_t(0));
16323 Value *Result = Builder.CreateShuffleVector(Op, FirstEltZero, Mask);
16324 Result = Builder.CreateBitCast(Result, DstTy);
16325 if (DstTy != ZExtTy)
16326 Result = Builder.CreateZExt(Result, ZExtTy);
16327 return Result;
16328 }
16329
16330 static Value *createTblShuffleForSExt(IRBuilderBase &Builder, Value *Op,
16331 FixedVectorType *DstTy,
16332 bool IsLittleEndian) {
16333 auto *SrcTy = cast<FixedVectorType>(Op->getType());
16334 auto SrcWidth = cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16335 auto DstWidth = cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16336
16337 SmallVector<int> Mask;
16338 if (!createTblShuffleMask(SrcWidth, DstWidth, SrcTy->getNumElements(),
16339 !IsLittleEndian, Mask))
16340 return nullptr;
16341
16342 auto *FirstEltZero = Builder.CreateInsertElement(
16343 PoisonValue::get(SrcTy), Builder.getInt8(0), uint64_t(0));
16344
16345 return Builder.CreateShuffleVector(Op, FirstEltZero, Mask);
16346 }
16347
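// Lower a vector truncate to i8 by treating the source vector as a table of
// bytes and using TBL to select every TruncFactor'th byte (the lowest byte of
// each lane on little-endian, the highest on big-endian). For example,
// truncating <8 x i32> to <8 x i8> on little-endian selects bytes
// 0,4,8,...,28; the remaining mask entries are 255, which is out of range for
// TBL and therefore yields zero.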
16348 static void createTblForTrunc(TruncInst *TI, bool IsLittleEndian) {
16349 IRBuilder<> Builder(TI);
16350 SmallVector<Value *> Parts;
16351 int NumElements = cast<FixedVectorType>(TI->getType())->getNumElements();
16352 auto *SrcTy = cast<FixedVectorType>(TI->getOperand(0)->getType());
16353 auto *DstTy = cast<FixedVectorType>(TI->getType());
16354 assert(SrcTy->getElementType()->isIntegerTy() &&
16355 "Non-integer type source vector element is not supported");
16356 assert(DstTy->getElementType()->isIntegerTy(8) &&
16357 "Unsupported destination vector element type");
16358 unsigned SrcElemTySz =
16359 cast<IntegerType>(SrcTy->getElementType())->getBitWidth();
16360 unsigned DstElemTySz =
16361 cast<IntegerType>(DstTy->getElementType())->getBitWidth();
16362 assert((SrcElemTySz % DstElemTySz == 0) &&
16363 "Cannot lower truncate to tbl instructions for a source element size "
16364 "that is not divisible by the destination element size");
16365 unsigned TruncFactor = SrcElemTySz / DstElemTySz;
16366 assert((SrcElemTySz == 16 || SrcElemTySz == 32 || SrcElemTySz == 64) &&
16367 "Unsupported source vector element type size");
16368 Type *VecTy = FixedVectorType::get(Builder.getInt8Ty(), 16);
16369
16370 // Create a mask to choose every nth byte from the source vector table of
16371 // bytes to create the truncated destination vector, where 'n' is the truncate
16372 // ratio. For example, for a truncate from Yxi64 to Yxi8, choose
16373 // 0,8,16,..Y*8th bytes for the little-endian format
16374 SmallVector<Constant *, 16> MaskConst;
16375 for (int Itr = 0; Itr < 16; Itr++) {
16376 if (Itr < NumElements)
16377 MaskConst.push_back(Builder.getInt8(
16378 IsLittleEndian ? Itr * TruncFactor
16379 : Itr * TruncFactor + (TruncFactor - 1)));
16380 else
16381 MaskConst.push_back(Builder.getInt8(255));
16382 }
16383
16384 int MaxTblSz = 128 * 4;
16385 int MaxSrcSz = SrcElemTySz * NumElements;
16386 int ElemsPerTbl =
16387 (MaxTblSz > MaxSrcSz) ? NumElements : (MaxTblSz / SrcElemTySz);
16388 assert(ElemsPerTbl <= 16 &&
16389 "Maximum elements selected using TBL instruction cannot exceed 16!");
16390
16391 int ShuffleCount = 128 / SrcElemTySz;
16392 SmallVector<int> ShuffleLanes;
16393 for (int i = 0; i < ShuffleCount; ++i)
16394 ShuffleLanes.push_back(i);
16395
16396 // Create TBL's table of bytes in 1,2,3 or 4 FP/SIMD registers using shuffles
16397 // over the source vector. If TBL's maximum 4 FP/SIMD registers are saturated,
16398 // call TBL & save the result in a vector of TBL results for combining later.
16399 SmallVector<Value *> Results;
16400 while (ShuffleLanes.back() < NumElements) {
16401 Parts.push_back(Builder.CreateBitCast(
16402 Builder.CreateShuffleVector(TI->getOperand(0), ShuffleLanes), VecTy));
16403
16404 if (Parts.size() == 4) {
16405 auto *F = Intrinsic::getDeclaration(TI->getModule(),
16406 Intrinsic::aarch64_neon_tbl4, VecTy);
16407 Parts.push_back(ConstantVector::get(MaskConst));
16408 Results.push_back(Builder.CreateCall(F, Parts));
16409 Parts.clear();
16410 }
16411
16412 for (int i = 0; i < ShuffleCount; ++i)
16413 ShuffleLanes[i] += ShuffleCount;
16414 }
16415
16416 assert((Parts.empty() || Results.empty()) &&
16417 "Lowering trunc for vectors requiring different TBL instructions is "
16418 "not supported!");
16419 // Call TBL for the residual table bytes present in 1,2, or 3 FP/SIMD
16420 // registers
16421 if (!Parts.empty()) {
16422 Intrinsic::ID TblID;
16423 switch (Parts.size()) {
16424 case 1:
16425 TblID = Intrinsic::aarch64_neon_tbl1;
16426 break;
16427 case 2:
16428 TblID = Intrinsic::aarch64_neon_tbl2;
16429 break;
16430 case 3:
16431 TblID = Intrinsic::aarch64_neon_tbl3;
16432 break;
16433 }
16434
16435 auto *F = Intrinsic::getDeclaration(TI->getModule(), TblID, VecTy);
16436 Parts.push_back(ConstantVector::get(MaskConst));
16437 Results.push_back(Builder.CreateCall(F, Parts));
16438 }
16439
16440 // Extract the destination vector from TBL result(s) after combining them
16441 // where applicable. Currently, at most two TBLs are supported.
16442 assert(Results.size() <= 2 && "Trunc lowering does not support generation of "
16443 "more than 2 tbl instructions!");
16444 Value *FinalResult = Results[0];
16445 if (Results.size() == 1) {
16446 if (ElemsPerTbl < 16) {
16447 SmallVector<int> FinalMask(ElemsPerTbl);
16448 std::iota(FinalMask.begin(), FinalMask.end(), 0);
16449 FinalResult = Builder.CreateShuffleVector(Results[0], FinalMask);
16450 }
16451 } else {
16452 SmallVector<int> FinalMask(ElemsPerTbl * Results.size());
16453 if (ElemsPerTbl < 16) {
16454 std::iota(FinalMask.begin(), FinalMask.begin() + ElemsPerTbl, 0);
16455 std::iota(FinalMask.begin() + ElemsPerTbl, FinalMask.end(), 16);
16456 } else {
16457 std::iota(FinalMask.begin(), FinalMask.end(), 0);
16458 }
16459 FinalResult =
16460 Builder.CreateShuffleVector(Results[0], Results[1], FinalMask);
16461 }
16462
16463 TI->replaceAllUsesWith(FinalResult);
16464 TI->eraseFromParent();
16465 }
16466
16467 bool AArch64TargetLowering::optimizeExtendOrTruncateConversion(
16468 Instruction *I, Loop *L, const TargetTransformInfo &TTI) const {
16469 // shuffle_vector instructions are serialized when targeting SVE,
16470 // see LowerSPLAT_VECTOR. This peephole is not beneficial.
16471 if (!EnableExtToTBL || Subtarget->useSVEForFixedLengthVectors())
16472 return false;
16473
16474 // Try to optimize conversions using tbl. This requires materializing constant
16475 // index vectors, which can increase code size and add loads. Skip the
16476 // transform unless the conversion is in a loop block guaranteed to execute
16477 // and we are not optimizing for size.
16478 Function *F = I->getParent()->getParent();
16479 if (!L || L->getHeader() != I->getParent() || F->hasMinSize() ||
16480 F->hasOptSize())
16481 return false;
16482
16483 auto *SrcTy = dyn_cast<FixedVectorType>(I->getOperand(0)->getType());
16484 auto *DstTy = dyn_cast<FixedVectorType>(I->getType());
16485 if (!SrcTy || !DstTy)
16486 return false;
16487
16488 // Convert 'zext <Y x i8> %x to <Y x i8X>' to a shuffle that can be
16489 // lowered to tbl instructions to insert the original i8 elements
16490 // into i8x lanes. This is enabled for cases where it is beneficial.
16491 auto *ZExt = dyn_cast<ZExtInst>(I);
16492 if (ZExt && SrcTy->getElementType()->isIntegerTy(8)) {
16493 auto DstWidth = DstTy->getElementType()->getScalarSizeInBits();
16494 if (DstWidth % 8 != 0)
16495 return false;
16496
16497 auto *TruncDstType =
16498 cast<FixedVectorType>(VectorType::getTruncatedElementVectorType(DstTy));
16499 // If the ZExt can be lowered to a single ZExt to the next power-of-2 and
16500 // the remaining ZExt folded into the user, don't use tbl lowering.
16501 auto SrcWidth = SrcTy->getElementType()->getScalarSizeInBits();
16502 if (TTI.getCastInstrCost(I->getOpcode(), DstTy, TruncDstType,
16503 TargetTransformInfo::getCastContextHint(I),
16504 TTI::TCK_SizeAndLatency, I) == TTI::TCC_Free) {
16505 if (SrcWidth * 2 >= TruncDstType->getElementType()->getScalarSizeInBits())
16506 return false;
16507
16508 DstTy = TruncDstType;
16509 }
16510 IRBuilder<> Builder(ZExt);
16511 Value *Result = createTblShuffleForZExt(
16512 Builder, ZExt->getOperand(0), cast<FixedVectorType>(ZExt->getType()),
16513 DstTy, Subtarget->isLittleEndian());
16514 if (!Result)
16515 return false;
16516 ZExt->replaceAllUsesWith(Result);
16517 ZExt->eraseFromParent();
16518 return true;
16519 }
16520
16521 auto *UIToFP = dyn_cast<UIToFPInst>(I);
16522 if (UIToFP && SrcTy->getElementType()->isIntegerTy(8) &&
16523 DstTy->getElementType()->isFloatTy()) {
16524 IRBuilder<> Builder(I);
16525 Value *ZExt = createTblShuffleForZExt(
16526 Builder, I->getOperand(0), FixedVectorType::getInteger(DstTy),
16527 FixedVectorType::getInteger(DstTy), Subtarget->isLittleEndian());
16528 assert(ZExt && "Cannot fail for the i8 to float conversion");
16529 auto *UI = Builder.CreateUIToFP(ZExt, DstTy);
16530 I->replaceAllUsesWith(UI);
16531 I->eraseFromParent();
16532 return true;
16533 }
16534
16535 auto *SIToFP = dyn_cast<SIToFPInst>(I);
16536 if (SIToFP && SrcTy->getElementType()->isIntegerTy(8) &&
16537 DstTy->getElementType()->isFloatTy()) {
16538 IRBuilder<> Builder(I);
16539 auto *Shuffle = createTblShuffleForSExt(Builder, I->getOperand(0),
16540 FixedVectorType::getInteger(DstTy),
16541 Subtarget->isLittleEndian());
16542 assert(Shuffle && "Cannot fail for the i8 to float conversion");
16543 auto *Cast = Builder.CreateBitCast(Shuffle, VectorType::getInteger(DstTy));
16544 auto *AShr = Builder.CreateAShr(Cast, 24, "", true);
16545 auto *SI = Builder.CreateSIToFP(AShr, DstTy);
16546 I->replaceAllUsesWith(SI);
16547 I->eraseFromParent();
16548 return true;
16549 }
16550
16551 // Convert 'fptoui <(8|16) x float> to <(8|16) x i8>' to a wide fptoui
16552 // followed by a truncate lowered to using tbl.4.
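  // For example (sketch only):
  //   %r = fptoui <8 x float> %x to <8 x i8>
  // becomes
  //   %w = fptoui <8 x float> %x to <8 x i32>
  //   %r = trunc <8 x i32> %w to <8 x i8>
  // and the truncate is then handled by createTblForTrunc below.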
16553 auto *FPToUI = dyn_cast<FPToUIInst>(I);
16554 if (FPToUI &&
16555 (SrcTy->getNumElements() == 8 || SrcTy->getNumElements() == 16) &&
16556 SrcTy->getElementType()->isFloatTy() &&
16557 DstTy->getElementType()->isIntegerTy(8)) {
16558 IRBuilder<> Builder(I);
16559 auto *WideConv = Builder.CreateFPToUI(FPToUI->getOperand(0),
16560 VectorType::getInteger(SrcTy));
16561 auto *TruncI = Builder.CreateTrunc(WideConv, DstTy);
16562 I->replaceAllUsesWith(TruncI);
16563 I->eraseFromParent();
16564 createTblForTrunc(cast<TruncInst>(TruncI), Subtarget->isLittleEndian());
16565 return true;
16566 }
16567
16568 // Convert 'trunc <(8|16) x (i32|i64)> %x to <(8|16) x i8>' to an appropriate
16569 // tbl instruction selecting the lowest/highest (little/big endian) 8 bits
16570 // per lane of the input that is represented using 1,2,3 or 4 128-bit table
16571 // registers
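  // Illustrative sketch: for
  //   %t = trunc <16 x i32> %x to <16 x i8>
  // the four 128-bit quarters of %x act as the tbl.4 table registers, and the
  // constant index vector selects the least significant byte of each 32-bit
  // lane (byte 0 per lane on little-endian, byte 3 on big-endian).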
16572 auto *TI = dyn_cast<TruncInst>(I);
16573 if (TI && DstTy->getElementType()->isIntegerTy(8) &&
16574 ((SrcTy->getElementType()->isIntegerTy(32) ||
16575 SrcTy->getElementType()->isIntegerTy(64)) &&
16576 (SrcTy->getNumElements() == 16 || SrcTy->getNumElements() == 8))) {
16577 createTblForTrunc(TI, Subtarget->isLittleEndian());
16578 return true;
16579 }
16580
16581 return false;
16582 }
16583
bool AArch64TargetLowering::hasPairedLoad(EVT LoadedType,
                                          Align &RequiredAlignment) const {
  if (!LoadedType.isSimple() ||
      (!LoadedType.isInteger() && !LoadedType.isFloatingPoint()))
    return false;
  // Cyclone supports unaligned accesses.
  RequiredAlignment = Align(1);
  unsigned NumBits = LoadedType.getSizeInBits();
  return NumBits == 32 || NumBits == 64;
}
16594
16595 /// A helper function for determining the number of interleaved accesses we
16596 /// will generate when lowering accesses of the given type.
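/// For example, assuming 128-bit NEON registers, a <16 x i32> access is
/// 16 * 32 = 512 bits and is therefore lowered as 512 / 128 = 4 interleaved
/// accesses, while anything of 128 bits or fewer counts as a single access.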
unsigned AArch64TargetLowering::getNumInterleavedAccesses(
    VectorType *VecTy, const DataLayout &DL, bool UseScalable) const {
16599 unsigned VecSize = 128;
16600 unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
16601 unsigned MinElts = VecTy->getElementCount().getKnownMinValue();
16602 if (UseScalable && isa<FixedVectorType>(VecTy))
16603 VecSize = std::max(Subtarget->getMinSVEVectorSizeInBits(), 128u);
16604 return std::max<unsigned>(1, (MinElts * ElSize + 127) / VecSize);
16605 }
16606
MachineMemOperand::Flags
AArch64TargetLowering::getTargetMMOFlags(const Instruction &I) const {
16609 if (Subtarget->getProcFamily() == AArch64Subtarget::Falkor &&
16610 I.hasMetadata(FALKOR_STRIDED_ACCESS_MD))
16611 return MOStridedAccess;
16612 return MachineMemOperand::MONone;
16613 }
16614
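/// Check whether \p VecTy can be lowered with the ldN/stN interleaved-access
/// intrinsics, setting \p UseScalable when the SVE forms should be preferred.
/// For example, with plain NEON a <4 x i32> (128 bits) is accepted, an
/// <8 x i32> (256 bits) is accepted and later split into two accesses, and a
/// <3 x i32> (96 bits) is rejected.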
bool AArch64TargetLowering::isLegalInterleavedAccessType(
    VectorType *VecTy, const DataLayout &DL, bool &UseScalable) const {
16617 unsigned ElSize = DL.getTypeSizeInBits(VecTy->getElementType());
16618 auto EC = VecTy->getElementCount();
16619 unsigned MinElts = EC.getKnownMinValue();
16620
16621 UseScalable = false;
16622
16623 if (isa<FixedVectorType>(VecTy) && !Subtarget->isNeonAvailable() &&
16624 (!Subtarget->useSVEForFixedLengthVectors() ||
16625 !getSVEPredPatternFromNumElements(MinElts)))
16626 return false;
16627
16628 if (isa<ScalableVectorType>(VecTy) &&
16629 !Subtarget->isSVEorStreamingSVEAvailable())
16630 return false;
16631
16632 // Ensure the number of vector elements is greater than 1.
16633 if (MinElts < 2)
16634 return false;
16635
16636 // Ensure the element type is legal.
16637 if (ElSize != 8 && ElSize != 16 && ElSize != 32 && ElSize != 64)
16638 return false;
16639
16640 if (EC.isScalable()) {
16641 UseScalable = true;
16642 return isPowerOf2_32(MinElts) && (MinElts * ElSize) % 128 == 0;
16643 }
16644
16645 unsigned VecSize = DL.getTypeSizeInBits(VecTy);
16646 if (Subtarget->useSVEForFixedLengthVectors()) {
16647 unsigned MinSVEVectorSize =
16648 std::max(Subtarget->getMinSVEVectorSizeInBits(), 128u);
16649 if (VecSize % MinSVEVectorSize == 0 ||
16650 (VecSize < MinSVEVectorSize && isPowerOf2_32(MinElts) &&
16651 (!Subtarget->isNeonAvailable() || VecSize > 128))) {
16652 UseScalable = true;
16653 return true;
16654 }
16655 }
16656
16657 // Ensure the total vector size is 64 or a multiple of 128. Types larger than
16658 // 128 will be split into multiple interleaved accesses.
16659 return Subtarget->isNeonAvailable() && (VecSize == 64 || VecSize % 128 == 0);
16660 }
16661
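// Map a fixed-length vector's element type to the corresponding 128-bit
// scalable container type used for SVE structured accesses, e.g. float maps
// to <vscale x 4 x float> and i8 maps to <vscale x 16 x i8>.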
static ScalableVectorType *getSVEContainerIRType(FixedVectorType *VTy) {
16663 if (VTy->getElementType() == Type::getDoubleTy(VTy->getContext()))
16664 return ScalableVectorType::get(VTy->getElementType(), 2);
16665
16666 if (VTy->getElementType() == Type::getFloatTy(VTy->getContext()))
16667 return ScalableVectorType::get(VTy->getElementType(), 4);
16668
16669 if (VTy->getElementType() == Type::getBFloatTy(VTy->getContext()))
16670 return ScalableVectorType::get(VTy->getElementType(), 8);
16671
16672 if (VTy->getElementType() == Type::getHalfTy(VTy->getContext()))
16673 return ScalableVectorType::get(VTy->getElementType(), 8);
16674
16675 if (VTy->getElementType() == Type::getInt64Ty(VTy->getContext()))
16676 return ScalableVectorType::get(VTy->getElementType(), 2);
16677
16678 if (VTy->getElementType() == Type::getInt32Ty(VTy->getContext()))
16679 return ScalableVectorType::get(VTy->getElementType(), 4);
16680
16681 if (VTy->getElementType() == Type::getInt16Ty(VTy->getContext()))
16682 return ScalableVectorType::get(VTy->getElementType(), 8);
16683
16684 if (VTy->getElementType() == Type::getInt8Ty(VTy->getContext()))
16685 return ScalableVectorType::get(VTy->getElementType(), 16);
16686
16687 llvm_unreachable("Cannot handle input vector type");
16688 }
16689
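// Return the declaration of the ldN intrinsic for the given interleave factor:
// llvm.aarch64.sve.ldN.sret when Scalable is set, llvm.aarch64.neon.ldN
// otherwise. Rough usage sketch (placeholder values):
//   Function *Ld2 = getStructuredLoadFunction(M, /*Factor=*/2,
//                                             /*Scalable=*/false, VTy, PtrTy);
//   Builder.CreateCall(Ld2, BaseAddr, "ldN");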
static Function *getStructuredLoadFunction(Module *M, unsigned Factor,
                                           bool Scalable, Type *LDVTy,
                                           Type *PtrTy) {
16693 assert(Factor >= 2 && Factor <= 4 && "Invalid interleave factor");
16694 static const Intrinsic::ID SVELoads[3] = {Intrinsic::aarch64_sve_ld2_sret,
16695 Intrinsic::aarch64_sve_ld3_sret,
16696 Intrinsic::aarch64_sve_ld4_sret};
16697 static const Intrinsic::ID NEONLoads[3] = {Intrinsic::aarch64_neon_ld2,
16698 Intrinsic::aarch64_neon_ld3,
16699 Intrinsic::aarch64_neon_ld4};
16700 if (Scalable)
16701 return Intrinsic::getDeclaration(M, SVELoads[Factor - 2], {LDVTy});
16702
16703 return Intrinsic::getDeclaration(M, NEONLoads[Factor - 2], {LDVTy, PtrTy});
16704 }
16705
static Function *getStructuredStoreFunction(Module *M, unsigned Factor,
                                            bool Scalable, Type *STVTy,
                                            Type *PtrTy) {
16709 assert(Factor >= 2 && Factor <= 4 && "Invalid interleave factor");
16710 static const Intrinsic::ID SVEStores[3] = {Intrinsic::aarch64_sve_st2,
16711 Intrinsic::aarch64_sve_st3,
16712 Intrinsic::aarch64_sve_st4};
16713 static const Intrinsic::ID NEONStores[3] = {Intrinsic::aarch64_neon_st2,
16714 Intrinsic::aarch64_neon_st3,
16715 Intrinsic::aarch64_neon_st4};
16716 if (Scalable)
16717 return Intrinsic::getDeclaration(M, SVEStores[Factor - 2], {STVTy});
16718
16719 return Intrinsic::getDeclaration(M, NEONStores[Factor - 2], {STVTy, PtrTy});
16720 }
16721
16722 /// Lower an interleaved load into a ldN intrinsic.
16723 ///
16724 /// E.g. Lower an interleaved load (Factor = 2):
16725 /// %wide.vec = load <8 x i32>, <8 x i32>* %ptr
16726 /// %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6> ; Extract even elements
16727 /// %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7> ; Extract odd elements
16728 ///
16729 /// Into:
16730 /// %ld2 = { <4 x i32>, <4 x i32> } call llvm.aarch64.neon.ld2(%ptr)
16731 /// %vec0 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 0
16732 /// %vec1 = extractelement { <4 x i32>, <4 x i32> } %ld2, i32 1
bool AArch64TargetLowering::lowerInterleavedLoad(
    LoadInst *LI, ArrayRef<ShuffleVectorInst *> Shuffles,
    ArrayRef<unsigned> Indices, unsigned Factor) const {
16736 assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
16737 "Invalid interleave factor");
16738 assert(!Shuffles.empty() && "Empty shufflevector input");
16739 assert(Shuffles.size() == Indices.size() &&
16740 "Unmatched number of shufflevectors and indices");
16741
16742 const DataLayout &DL = LI->getDataLayout();
16743
16744 VectorType *VTy = Shuffles[0]->getType();
16745
16746 // Skip if we do not have NEON and skip illegal vector types. We can
16747 // "legalize" wide vector types into multiple interleaved accesses as long as
16748 // the vector types are divisible by 128.
16749 bool UseScalable;
16750 if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
16751 return false;
16752
16753 unsigned NumLoads = getNumInterleavedAccesses(VTy, DL, UseScalable);
16754
16755 auto *FVTy = cast<FixedVectorType>(VTy);
16756
16757 // A pointer vector can not be the return type of the ldN intrinsics. Need to
16758 // load integer vectors first and then convert to pointer vectors.
16759 Type *EltTy = FVTy->getElementType();
16760 if (EltTy->isPointerTy())
16761 FVTy =
16762 FixedVectorType::get(DL.getIntPtrType(EltTy), FVTy->getNumElements());
16763
16764 // If we're going to generate more than one load, reset the sub-vector type
16765 // to something legal.
16766 FVTy = FixedVectorType::get(FVTy->getElementType(),
16767 FVTy->getNumElements() / NumLoads);
16768
16769 auto *LDVTy =
16770 UseScalable ? cast<VectorType>(getSVEContainerIRType(FVTy)) : FVTy;
16771
16772 IRBuilder<> Builder(LI);
16773
16774 // The base address of the load.
16775 Value *BaseAddr = LI->getPointerOperand();
16776
16777 Type *PtrTy = LI->getPointerOperandType();
16778 Type *PredTy = VectorType::get(Type::getInt1Ty(LDVTy->getContext()),
16779 LDVTy->getElementCount());
16780
16781 Function *LdNFunc = getStructuredLoadFunction(LI->getModule(), Factor,
16782 UseScalable, LDVTy, PtrTy);
16783
16784 // Holds sub-vectors extracted from the load intrinsic return values. The
16785 // sub-vectors are associated with the shufflevector instructions they will
16786 // replace.
16787 DenseMap<ShuffleVectorInst *, SmallVector<Value *, 4>> SubVecs;
16788
16789 Value *PTrue = nullptr;
16790 if (UseScalable) {
16791 std::optional<unsigned> PgPattern =
16792 getSVEPredPatternFromNumElements(FVTy->getNumElements());
16793 if (Subtarget->getMinSVEVectorSizeInBits() ==
16794 Subtarget->getMaxSVEVectorSizeInBits() &&
16795 Subtarget->getMinSVEVectorSizeInBits() == DL.getTypeSizeInBits(FVTy))
16796 PgPattern = AArch64SVEPredPattern::all;
16797
16798 auto *PTruePat =
16799 ConstantInt::get(Type::getInt32Ty(LDVTy->getContext()), *PgPattern);
16800 PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
16801 {PTruePat});
16802 }
16803
16804 for (unsigned LoadCount = 0; LoadCount < NumLoads; ++LoadCount) {
16805
16806 // If we're generating more than one load, compute the base address of
16807 // subsequent loads as an offset from the previous.
16808 if (LoadCount > 0)
16809 BaseAddr = Builder.CreateConstGEP1_32(LDVTy->getElementType(), BaseAddr,
16810 FVTy->getNumElements() * Factor);
16811
16812 CallInst *LdN;
16813 if (UseScalable)
16814 LdN = Builder.CreateCall(LdNFunc, {PTrue, BaseAddr}, "ldN");
16815 else
16816 LdN = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
16817
16818 // Extract and store the sub-vectors returned by the load intrinsic.
16819 for (unsigned i = 0; i < Shuffles.size(); i++) {
16820 ShuffleVectorInst *SVI = Shuffles[i];
16821 unsigned Index = Indices[i];
16822
16823 Value *SubVec = Builder.CreateExtractValue(LdN, Index);
16824
16825 if (UseScalable)
16826 SubVec = Builder.CreateExtractVector(
16827 FVTy, SubVec,
16828 ConstantInt::get(Type::getInt64Ty(VTy->getContext()), 0));
16829
16830 // Convert the integer vector to pointer vector if the element is pointer.
16831 if (EltTy->isPointerTy())
16832 SubVec = Builder.CreateIntToPtr(
16833 SubVec, FixedVectorType::get(SVI->getType()->getElementType(),
16834 FVTy->getNumElements()));
16835
16836 SubVecs[SVI].push_back(SubVec);
16837 }
16838 }
16839
16840 // Replace uses of the shufflevector instructions with the sub-vectors
16841 // returned by the load intrinsic. If a shufflevector instruction is
16842 // associated with more than one sub-vector, those sub-vectors will be
16843 // concatenated into a single wide vector.
16844 for (ShuffleVectorInst *SVI : Shuffles) {
16845 auto &SubVec = SubVecs[SVI];
16846 auto *WideVec =
16847 SubVec.size() > 1 ? concatenateVectors(Builder, SubVec) : SubVec[0];
16848 SVI->replaceAllUsesWith(WideVec);
16849 }
16850
16851 return true;
16852 }
16853
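// Scan up to ~20 instructions from It in the given direction for a store whose
// address is exactly 16 bytes away from Ptr, i.e. one that could later form an
// stp pair with a store to Ptr.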
template <typename Iter>
bool hasNearbyPairedStore(Iter It, Iter End, Value *Ptr, const DataLayout &DL) {
16856 int MaxLookupDist = 20;
16857 unsigned IdxWidth = DL.getIndexSizeInBits(0);
16858 APInt OffsetA(IdxWidth, 0), OffsetB(IdxWidth, 0);
16859 const Value *PtrA1 =
16860 Ptr->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA);
16861
16862 while (++It != End) {
16863 if (It->isDebugOrPseudoInst())
16864 continue;
16865 if (MaxLookupDist-- == 0)
16866 break;
16867 if (const auto *SI = dyn_cast<StoreInst>(&*It)) {
16868 const Value *PtrB1 =
16869 SI->getPointerOperand()->stripAndAccumulateInBoundsConstantOffsets(
16870 DL, OffsetB);
16871 if (PtrA1 == PtrB1 &&
16872 (OffsetA.sextOrTrunc(IdxWidth) - OffsetB.sextOrTrunc(IdxWidth))
16873 .abs() == 16)
16874 return true;
16875 }
16876 }
16877
16878 return false;
16879 }
16880
16881 /// Lower an interleaved store into a stN intrinsic.
16882 ///
16883 /// E.g. Lower an interleaved store (Factor = 3):
16884 /// %i.vec = shuffle <8 x i32> %v0, <8 x i32> %v1,
16885 /// <0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11>
16886 /// store <12 x i32> %i.vec, <12 x i32>* %ptr
16887 ///
16888 /// Into:
16889 /// %sub.v0 = shuffle <8 x i32> %v0, <8 x i32> v1, <0, 1, 2, 3>
16890 /// %sub.v1 = shuffle <8 x i32> %v0, <8 x i32> v1, <4, 5, 6, 7>
16891 /// %sub.v2 = shuffle <8 x i32> %v0, <8 x i32> v1, <8, 9, 10, 11>
16892 /// call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
16893 ///
16894 /// Note that the new shufflevectors will be removed and we'll only generate one
16895 /// st3 instruction in CodeGen.
16896 ///
16897 /// Example for a more general valid mask (Factor 3). Lower:
16898 /// %i.vec = shuffle <32 x i32> %v0, <32 x i32> %v1,
16899 /// <4, 32, 16, 5, 33, 17, 6, 34, 18, 7, 35, 19>
16900 /// store <12 x i32> %i.vec, <12 x i32>* %ptr
16901 ///
16902 /// Into:
16903 /// %sub.v0 = shuffle <32 x i32> %v0, <32 x i32> v1, <4, 5, 6, 7>
16904 /// %sub.v1 = shuffle <32 x i32> %v0, <32 x i32> v1, <32, 33, 34, 35>
16905 /// %sub.v2 = shuffle <32 x i32> %v0, <32 x i32> v1, <16, 17, 18, 19>
16906 /// call void llvm.aarch64.neon.st3(%sub.v0, %sub.v1, %sub.v2, %ptr)
bool AArch64TargetLowering::lowerInterleavedStore(StoreInst *SI,
                                                  ShuffleVectorInst *SVI,
                                                  unsigned Factor) const {
16910
16911 assert(Factor >= 2 && Factor <= getMaxSupportedInterleaveFactor() &&
16912 "Invalid interleave factor");
16913
16914 auto *VecTy = cast<FixedVectorType>(SVI->getType());
16915 assert(VecTy->getNumElements() % Factor == 0 && "Invalid interleaved store");
16916
16917 unsigned LaneLen = VecTy->getNumElements() / Factor;
16918 Type *EltTy = VecTy->getElementType();
16919 auto *SubVecTy = FixedVectorType::get(EltTy, LaneLen);
16920
16921 const DataLayout &DL = SI->getDataLayout();
16922 bool UseScalable;
16923
16924 // Skip if we do not have NEON and skip illegal vector types. We can
16925 // "legalize" wide vector types into multiple interleaved accesses as long as
16926 // the vector types are divisible by 128.
16927 if (!isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
16928 return false;
16929
16930 unsigned NumStores = getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
16931
16932 Value *Op0 = SVI->getOperand(0);
16933 Value *Op1 = SVI->getOperand(1);
16934 IRBuilder<> Builder(SI);
16935
16936 // StN intrinsics don't support pointer vectors as arguments. Convert pointer
16937 // vectors to integer vectors.
16938 if (EltTy->isPointerTy()) {
16939 Type *IntTy = DL.getIntPtrType(EltTy);
16940 unsigned NumOpElts =
16941 cast<FixedVectorType>(Op0->getType())->getNumElements();
16942
16943 // Convert to the corresponding integer vector.
16944 auto *IntVecTy = FixedVectorType::get(IntTy, NumOpElts);
16945 Op0 = Builder.CreatePtrToInt(Op0, IntVecTy);
16946 Op1 = Builder.CreatePtrToInt(Op1, IntVecTy);
16947
16948 SubVecTy = FixedVectorType::get(IntTy, LaneLen);
16949 }
16950
16951 // If we're going to generate more than one store, reset the lane length
16952 // and sub-vector type to something legal.
16953 LaneLen /= NumStores;
16954 SubVecTy = FixedVectorType::get(SubVecTy->getElementType(), LaneLen);
16955
16956 auto *STVTy = UseScalable ? cast<VectorType>(getSVEContainerIRType(SubVecTy))
16957 : SubVecTy;
16958
16959 // The base address of the store.
16960 Value *BaseAddr = SI->getPointerOperand();
16961
16962 auto Mask = SVI->getShuffleMask();
16963
  // Bail out if every mask index is undefined (PoisonMaskElem). If the whole
  // mask is poison it may be a vector of -1s, and the lane-selection logic
  // below would then read out of bounds.
  if (llvm::all_of(Mask, [](int Idx) { return Idx == PoisonMaskElem; })) {
    return false;
  }
  // A 64-bit st2 which does not start at element 0 will involve adding extra
  // ext instructions to shuffle the elements, making the st2 unprofitable; and
  // if there is a nearby store that points to BaseAddr+16 or BaseAddr-16, the
  // sequence can be better left as a zip;stp pair, which has higher
  // throughput.
16974 if (Factor == 2 && SubVecTy->getPrimitiveSizeInBits() == 64 &&
16975 (Mask[0] != 0 ||
16976 hasNearbyPairedStore(SI->getIterator(), SI->getParent()->end(), BaseAddr,
16977 DL) ||
16978 hasNearbyPairedStore(SI->getReverseIterator(), SI->getParent()->rend(),
16979 BaseAddr, DL)))
16980 return false;
16981
16982 Type *PtrTy = SI->getPointerOperandType();
16983 Type *PredTy = VectorType::get(Type::getInt1Ty(STVTy->getContext()),
16984 STVTy->getElementCount());
16985
16986 Function *StNFunc = getStructuredStoreFunction(SI->getModule(), Factor,
16987 UseScalable, STVTy, PtrTy);
16988
16989 Value *PTrue = nullptr;
16990 if (UseScalable) {
16991 std::optional<unsigned> PgPattern =
16992 getSVEPredPatternFromNumElements(SubVecTy->getNumElements());
16993 if (Subtarget->getMinSVEVectorSizeInBits() ==
16994 Subtarget->getMaxSVEVectorSizeInBits() &&
16995 Subtarget->getMinSVEVectorSizeInBits() ==
16996 DL.getTypeSizeInBits(SubVecTy))
16997 PgPattern = AArch64SVEPredPattern::all;
16998
16999 auto *PTruePat =
17000 ConstantInt::get(Type::getInt32Ty(STVTy->getContext()), *PgPattern);
17001 PTrue = Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, {PredTy},
17002 {PTruePat});
17003 }
17004
17005 for (unsigned StoreCount = 0; StoreCount < NumStores; ++StoreCount) {
17006
17007 SmallVector<Value *, 5> Ops;
17008
17009 // Split the shufflevector operands into sub vectors for the new stN call.
17010 for (unsigned i = 0; i < Factor; i++) {
17011 Value *Shuffle;
17012 unsigned IdxI = StoreCount * LaneLen * Factor + i;
17013 if (Mask[IdxI] >= 0) {
17014 Shuffle = Builder.CreateShuffleVector(
17015 Op0, Op1, createSequentialMask(Mask[IdxI], LaneLen, 0));
17016 } else {
17017 unsigned StartMask = 0;
17018 for (unsigned j = 1; j < LaneLen; j++) {
17019 unsigned IdxJ = StoreCount * LaneLen * Factor + j * Factor + i;
17020 if (Mask[IdxJ] >= 0) {
17021 StartMask = Mask[IdxJ] - j;
17022 break;
17023 }
17024 }
        // Note: Filling undef gaps with random elements is ok, since those
        // elements were being written anyway (with undefs).
        // In the case of all undefs we default to using elements from 0.
        // Note: StartMask cannot be negative; it is checked in
        // isReInterleaveMask.
17030 Shuffle = Builder.CreateShuffleVector(
17031 Op0, Op1, createSequentialMask(StartMask, LaneLen, 0));
17032 }
17033
17034 if (UseScalable)
17035 Shuffle = Builder.CreateInsertVector(
17036 STVTy, UndefValue::get(STVTy), Shuffle,
17037 ConstantInt::get(Type::getInt64Ty(STVTy->getContext()), 0));
17038
17039 Ops.push_back(Shuffle);
17040 }
17041
17042 if (UseScalable)
17043 Ops.push_back(PTrue);
17044
    // If we're generating more than one store, compute the base address of
    // subsequent stores as an offset from the previous one.
17047 if (StoreCount > 0)
17048 BaseAddr = Builder.CreateConstGEP1_32(SubVecTy->getElementType(),
17049 BaseAddr, LaneLen * Factor);
17050
17051 Ops.push_back(BaseAddr);
17052 Builder.CreateCall(StNFunc, Ops);
17053 }
17054 return true;
17055 }
17056
bool AArch64TargetLowering::lowerDeinterleaveIntrinsicToLoad(
    IntrinsicInst *DI, LoadInst *LI) const {
17059 // Only deinterleave2 supported at present.
17060 if (DI->getIntrinsicID() != Intrinsic::vector_deinterleave2)
17061 return false;
17062
17063 // Only a factor of 2 supported at present.
17064 const unsigned Factor = 2;
17065
17066 VectorType *VTy = cast<VectorType>(DI->getType()->getContainedType(0));
17067 const DataLayout &DL = DI->getDataLayout();
17068 bool UseScalable;
17069 if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
17070 return false;
17071
17072 // TODO: Add support for using SVE instructions with fixed types later, using
17073 // the code from lowerInterleavedLoad to obtain the correct container type.
17074 if (UseScalable && !VTy->isScalableTy())
17075 return false;
17076
17077 unsigned NumLoads = getNumInterleavedAccesses(VTy, DL, UseScalable);
17078
17079 VectorType *LdTy =
17080 VectorType::get(VTy->getElementType(),
17081 VTy->getElementCount().divideCoefficientBy(NumLoads));
17082
17083 Type *PtrTy = LI->getPointerOperandType();
17084 Function *LdNFunc = getStructuredLoadFunction(DI->getModule(), Factor,
17085 UseScalable, LdTy, PtrTy);
17086
17087 IRBuilder<> Builder(LI);
17088
17089 Value *Pred = nullptr;
17090 if (UseScalable)
17091 Pred =
17092 Builder.CreateVectorSplat(LdTy->getElementCount(), Builder.getTrue());
17093
17094 Value *BaseAddr = LI->getPointerOperand();
17095 Value *Result;
17096 if (NumLoads > 1) {
17097 Value *Left = PoisonValue::get(VTy);
17098 Value *Right = PoisonValue::get(VTy);
17099
17100 for (unsigned I = 0; I < NumLoads; ++I) {
17101 Value *Offset = Builder.getInt64(I * Factor);
17102
17103 Value *Address = Builder.CreateGEP(LdTy, BaseAddr, {Offset});
17104 Value *LdN = nullptr;
17105 if (UseScalable)
17106 LdN = Builder.CreateCall(LdNFunc, {Pred, Address}, "ldN");
17107 else
17108 LdN = Builder.CreateCall(LdNFunc, Address, "ldN");
17109
17110 Value *Idx =
17111 Builder.getInt64(I * LdTy->getElementCount().getKnownMinValue());
17112 Left = Builder.CreateInsertVector(
17113 VTy, Left, Builder.CreateExtractValue(LdN, 0), Idx);
17114 Right = Builder.CreateInsertVector(
17115 VTy, Right, Builder.CreateExtractValue(LdN, 1), Idx);
17116 }
17117
17118 Result = PoisonValue::get(DI->getType());
17119 Result = Builder.CreateInsertValue(Result, Left, 0);
17120 Result = Builder.CreateInsertValue(Result, Right, 1);
17121 } else {
17122 if (UseScalable)
17123 Result = Builder.CreateCall(LdNFunc, {Pred, BaseAddr}, "ldN");
17124 else
17125 Result = Builder.CreateCall(LdNFunc, BaseAddr, "ldN");
17126 }
17127
17128 DI->replaceAllUsesWith(Result);
17129 return true;
17130 }
17131
bool AArch64TargetLowering::lowerInterleaveIntrinsicToStore(
    IntrinsicInst *II, StoreInst *SI) const {
17134 // Only interleave2 supported at present.
17135 if (II->getIntrinsicID() != Intrinsic::vector_interleave2)
17136 return false;
17137
17138 // Only a factor of 2 supported at present.
17139 const unsigned Factor = 2;
17140
17141 VectorType *VTy = cast<VectorType>(II->getOperand(0)->getType());
17142 const DataLayout &DL = II->getDataLayout();
17143 bool UseScalable;
17144 if (!isLegalInterleavedAccessType(VTy, DL, UseScalable))
17145 return false;
17146
17147 // TODO: Add support for using SVE instructions with fixed types later, using
17148 // the code from lowerInterleavedStore to obtain the correct container type.
17149 if (UseScalable && !VTy->isScalableTy())
17150 return false;
17151
17152 unsigned NumStores = getNumInterleavedAccesses(VTy, DL, UseScalable);
17153
17154 VectorType *StTy =
17155 VectorType::get(VTy->getElementType(),
17156 VTy->getElementCount().divideCoefficientBy(NumStores));
17157
17158 Type *PtrTy = SI->getPointerOperandType();
17159 Function *StNFunc = getStructuredStoreFunction(SI->getModule(), Factor,
17160 UseScalable, StTy, PtrTy);
17161
17162 IRBuilder<> Builder(SI);
17163
17164 Value *BaseAddr = SI->getPointerOperand();
17165 Value *Pred = nullptr;
17166
17167 if (UseScalable)
17168 Pred =
17169 Builder.CreateVectorSplat(StTy->getElementCount(), Builder.getTrue());
17170
17171 Value *L = II->getOperand(0);
17172 Value *R = II->getOperand(1);
17173
17174 for (unsigned I = 0; I < NumStores; ++I) {
17175 Value *Address = BaseAddr;
17176 if (NumStores > 1) {
17177 Value *Offset = Builder.getInt64(I * Factor);
17178 Address = Builder.CreateGEP(StTy, BaseAddr, {Offset});
17179
17180 Value *Idx =
17181 Builder.getInt64(I * StTy->getElementCount().getKnownMinValue());
17182 L = Builder.CreateExtractVector(StTy, II->getOperand(0), Idx);
17183 R = Builder.CreateExtractVector(StTy, II->getOperand(1), Idx);
17184 }
17185
17186 if (UseScalable)
17187 Builder.CreateCall(StNFunc, {L, R, Pred, Address});
17188 else
17189 Builder.CreateCall(StNFunc, {L, R, Address});
17190 }
17191
17192 return true;
17193 }
17194
EVT AArch64TargetLowering::getOptimalMemOpType(
    const MemOp &Op, const AttributeList &FuncAttributes) const {
17197 bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
17198 bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
17199 bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
  // Only use AdvSIMD to implement memsets of 32 bytes and above: below that,
  // materializing the v2i64 zero costs an instruction and the vector store has
  // a restrictive addressing mode, so plain i64 stores are better.
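  // For example, assuming the alignment checks below pass and NEON is usable,
  // a 64-byte memset is lowered with MVT::v16i8 stores, while a 16-byte memset
  // falls through to MVT::i64 because it is under the 32-byte threshold.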
17203 bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
17204 auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
17205 if (Op.isAligned(AlignCheck))
17206 return true;
17207 unsigned Fast;
17208 return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
17209 MachineMemOperand::MONone, &Fast) &&
17210 Fast;
17211 };
17212
17213 if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
17214 AlignmentIsAcceptable(MVT::v16i8, Align(16)))
17215 return MVT::v16i8;
17216 if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
17217 return MVT::f128;
17218 if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
17219 return MVT::i64;
17220 if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
17221 return MVT::i32;
17222 return MVT::Other;
17223 }
17224
LLT AArch64TargetLowering::getOptimalMemOpLLT(
    const MemOp &Op, const AttributeList &FuncAttributes) const {
17227 bool CanImplicitFloat = !FuncAttributes.hasFnAttr(Attribute::NoImplicitFloat);
17228 bool CanUseNEON = Subtarget->hasNEON() && CanImplicitFloat;
17229 bool CanUseFP = Subtarget->hasFPARMv8() && CanImplicitFloat;
  // Only use AdvSIMD to implement memsets of 32 bytes and above: below that,
  // materializing the v2i64 zero costs an instruction and the vector store has
  // a restrictive addressing mode, so plain i64 stores are better.
17233 bool IsSmallMemset = Op.isMemset() && Op.size() < 32;
17234 auto AlignmentIsAcceptable = [&](EVT VT, Align AlignCheck) {
17235 if (Op.isAligned(AlignCheck))
17236 return true;
17237 unsigned Fast;
17238 return allowsMisalignedMemoryAccesses(VT, 0, Align(1),
17239 MachineMemOperand::MONone, &Fast) &&
17240 Fast;
17241 };
17242
17243 if (CanUseNEON && Op.isMemset() && !IsSmallMemset &&
17244 AlignmentIsAcceptable(MVT::v2i64, Align(16)))
17245 return LLT::fixed_vector(2, 64);
17246 if (CanUseFP && !IsSmallMemset && AlignmentIsAcceptable(MVT::f128, Align(16)))
17247 return LLT::scalar(128);
17248 if (Op.size() >= 8 && AlignmentIsAcceptable(MVT::i64, Align(8)))
17249 return LLT::scalar(64);
17250 if (Op.size() >= 4 && AlignmentIsAcceptable(MVT::i32, Align(4)))
17251 return LLT::scalar(32);
17252 return LLT();
17253 }
17254
17255 // 12-bit optionally shifted immediates are legal for adds.
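// For example, 0xFFF (an unshifted imm12) and 0xABC000 (an imm12 shifted left
// by 12) are legal, while 0x1001 is not because it needs bits from both
// halves.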
bool AArch64TargetLowering::isLegalAddImmediate(int64_t Immed) const {
17257 if (Immed == std::numeric_limits<int64_t>::min()) {
17258 LLVM_DEBUG(dbgs() << "Illegal add imm " << Immed
17259 << ": avoid UB for INT64_MIN\n");
17260 return false;
17261 }
17262 // Same encoding for add/sub, just flip the sign.
17263 Immed = std::abs(Immed);
17264 bool IsLegal = ((Immed >> 12) == 0 ||
17265 ((Immed & 0xfff) == 0 && Immed >> 24 == 0));
17266 LLVM_DEBUG(dbgs() << "Is " << Immed
17267 << " legal add imm: " << (IsLegal ? "yes" : "no") << "\n");
17268 return IsLegal;
17269 }
17270
bool AArch64TargetLowering::isLegalAddScalableImmediate(int64_t Imm) const {
17272 // We will only emit addvl/inc* instructions for SVE2
17273 if (!Subtarget->hasSVE2())
17274 return false;
17275
17276 // addvl's immediates are in terms of the number of bytes in a register.
17277 // Since there are 16 in the base supported size (128bits), we need to
17278 // divide the immediate by that much to give us a useful immediate to
17279 // multiply by vscale. We can't have a remainder as a result of this.
17280 if (Imm % 16 == 0)
17281 return isInt<6>(Imm / 16);
17282
17283 // Inc[b|h|w|d] instructions take a pattern and a positive immediate
17284 // multiplier. For now, assume a pattern of 'all'. Incb would be a subset
17285 // of addvl as a result, so only take h|w|d into account.
17286 // Dec[h|w|d] will cover subtractions.
17287 // Immediates are in the range [1,16], so we can't do a 2's complement check.
17288 // FIXME: Can we make use of other patterns to cover other immediates?
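  // Worked examples for the checks below: Imm == 32 is 2 * 16 and can be
  // encoded as addvl #2; Imm == 24 is not a multiple of 16 but is 3 * 8, so it
  // can be encoded as an inch with multiplier 3.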
17289
17290 // inch|dech
17291 if (Imm % 8 == 0)
17292 return std::abs(Imm / 8) <= 16;
17293 // incw|decw
17294 if (Imm % 4 == 0)
17295 return std::abs(Imm / 4) <= 16;
17296 // incd|decd
17297 if (Imm % 2 == 0)
17298 return std::abs(Imm / 2) <= 16;
17299
17300 return false;
17301 }
17302
17303 // Return false to prevent folding
17304 // (mul (add x, c1), c2) -> (add (mul x, c2), c2*c1) in DAGCombine,
17305 // if the folding leads to worse code.
bool AArch64TargetLowering::isMulAddWithConstProfitable(
    SDValue AddNode, SDValue ConstNode) const {
17308 // Let the DAGCombiner decide for vector types and large types.
17309 const EVT VT = AddNode.getValueType();
17310 if (VT.isVector() || VT.getScalarSizeInBits() > 64)
17311 return true;
17312
17313 // It is worse if c1 is legal add immediate, while c1*c2 is not
17314 // and has to be composed by at least two instructions.
17315 const ConstantSDNode *C1Node = cast<ConstantSDNode>(AddNode.getOperand(1));
17316 const ConstantSDNode *C2Node = cast<ConstantSDNode>(ConstNode);
17317 const int64_t C1 = C1Node->getSExtValue();
17318 const APInt C1C2 = C1Node->getAPIntValue() * C2Node->getAPIntValue();
17319 if (!isLegalAddImmediate(C1) || isLegalAddImmediate(C1C2.getSExtValue()))
17320 return true;
17321 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
17322 // Adapt to the width of a register.
17323 unsigned BitSize = VT.getSizeInBits() <= 32 ? 32 : 64;
17324 AArch64_IMM::expandMOVImm(C1C2.getZExtValue(), BitSize, Insn);
17325 if (Insn.size() > 1)
17326 return false;
17327
17328 // Default to true and let the DAGCombiner decide.
17329 return true;
17330 }
17331
17332 // Integer comparisons are implemented with ADDS/SUBS, so the range of valid
17333 // immediates is the same as for an add or a sub.
bool AArch64TargetLowering::isLegalICmpImmediate(int64_t Immed) const {
17335 return isLegalAddImmediate(Immed);
17336 }
17337
17338 /// isLegalAddressingMode - Return true if the addressing mode represented
17339 /// by AM is legal for this target, for a load/store of the specified type.
bool AArch64TargetLowering::isLegalAddressingMode(const DataLayout &DL,
                                                  const AddrMode &AMode, Type *Ty,
                                                  unsigned AS, Instruction *I) const {
17343 // AArch64 has five basic addressing modes:
17344 // reg
17345 // reg + 9-bit signed offset
17346 // reg + SIZE_IN_BYTES * 12-bit unsigned offset
17347 // reg1 + reg2
17348 // reg + SIZE_IN_BYTES * reg
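  // For example (illustrative assembly for an 8-byte access):
  //   ldr  x0, [x1]              // reg
  //   ldur x0, [x1, #-256]       // reg + 9-bit signed offset
  //   ldr  x0, [x1, #4088]       // reg + 8 * 12-bit unsigned offset
  //   ldr  x0, [x1, x2]          // reg1 + reg2
  //   ldr  x0, [x1, x2, lsl #3]  // reg + 8 * reg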
17349
17350 // No global is ever allowed as a base.
17351 if (AMode.BaseGV)
17352 return false;
17353
17354 // No reg+reg+imm addressing.
17355 if (AMode.HasBaseReg && AMode.BaseOffs && AMode.Scale)
17356 return false;
17357
17358 // Canonicalise `1*ScaledReg + imm` into `BaseReg + imm` and
17359 // `2*ScaledReg` into `BaseReg + ScaledReg`
17360 AddrMode AM = AMode;
17361 if (AM.Scale && !AM.HasBaseReg) {
17362 if (AM.Scale == 1) {
17363 AM.HasBaseReg = true;
17364 AM.Scale = 0;
17365 } else if (AM.Scale == 2) {
17366 AM.HasBaseReg = true;
17367 AM.Scale = 1;
17368 } else {
17369 return false;
17370 }
17371 }
17372
17373 // A base register is required in all addressing modes.
17374 if (!AM.HasBaseReg)
17375 return false;
17376
17377 if (Ty->isScalableTy()) {
17378 if (isa<ScalableVectorType>(Ty)) {
17379 // See if we have a foldable vscale-based offset, for vector types which
17380 // are either legal or smaller than the minimum; more work will be
17381 // required if we need to consider addressing for types which need
17382 // legalization by splitting.
17383 uint64_t VecNumBytes = DL.getTypeSizeInBits(Ty).getKnownMinValue() / 8;
17384 if (AM.HasBaseReg && !AM.BaseOffs && AM.ScalableOffset && !AM.Scale &&
17385 (AM.ScalableOffset % VecNumBytes == 0) && VecNumBytes <= 16 &&
17386 isPowerOf2_64(VecNumBytes))
17387 return isInt<4>(AM.ScalableOffset / (int64_t)VecNumBytes);
17388
17389 uint64_t VecElemNumBytes =
17390 DL.getTypeSizeInBits(cast<VectorType>(Ty)->getElementType()) / 8;
17391 return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset &&
17392 (AM.Scale == 0 || (uint64_t)AM.Scale == VecElemNumBytes);
17393 }
17394
17395 return AM.HasBaseReg && !AM.BaseOffs && !AM.ScalableOffset && !AM.Scale;
17396 }
17397
17398 // No scalable offsets allowed for non-scalable types.
17399 if (AM.ScalableOffset)
17400 return false;
17401
17402 // check reg + imm case:
17403 // i.e., reg + 0, reg + imm9, reg + SIZE_IN_BYTES * uimm12
17404 uint64_t NumBytes = 0;
17405 if (Ty->isSized()) {
17406 uint64_t NumBits = DL.getTypeSizeInBits(Ty);
17407 NumBytes = NumBits / 8;
17408 if (!isPowerOf2_64(NumBits))
17409 NumBytes = 0;
17410 }
17411
17412 return Subtarget->getInstrInfo()->isLegalAddressingMode(NumBytes, AM.BaseOffs,
17413 AM.Scale);
17414 }
17415
// Check whether the two offsets share all bits above the low 12 (i.e. they lie
// in the same 4KB window) and whether that common high part is itself a legal
// add immediate; if so, it can be materialized once with an ADD and used as
// the rebased offset.
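// For example, MinOffset = 0x1234 and MaxOffset = 0x1FF8 share the high part
// 0x1000, which is itself a legal add immediate, so 0x1000 is returned and the
// remaining low 12 bits stay in the individual memory-access offsets.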
int64_t
AArch64TargetLowering::getPreferredLargeGEPBaseOffset(int64_t MinOffset,
                                                      int64_t MaxOffset) const {
17421 int64_t HighPart = MinOffset & ~0xfffULL;
17422 if (MinOffset >> 12 == MaxOffset >> 12 && isLegalAddImmediate(HighPart)) {
17423 // Rebase the value to an integer multiple of imm12.
17424 return HighPart;
17425 }
17426
17427 return 0;
17428 }
17429
bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
17431 // Consider splitting large offset of struct or array.
17432 return true;
17433 }
17434
bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
    const MachineFunction &MF, EVT VT) const {
17437 VT = VT.getScalarType();
17438
17439 if (!VT.isSimple())
17440 return false;
17441
17442 switch (VT.getSimpleVT().SimpleTy) {
17443 case MVT::f16:
17444 return Subtarget->hasFullFP16();
17445 case MVT::f32:
17446 case MVT::f64:
17447 return true;
17448 default:
17449 break;
17450 }
17451
17452 return false;
17453 }
17454
bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(const Function &F,
                                                       Type *Ty) const {
17457 switch (Ty->getScalarType()->getTypeID()) {
17458 case Type::FloatTyID:
17459 case Type::DoubleTyID:
17460 return true;
17461 default:
17462 return false;
17463 }
17464 }
17465
bool AArch64TargetLowering::generateFMAsInMachineCombiner(
    EVT VT, CodeGenOptLevel OptLevel) const {
17468 return (OptLevel >= CodeGenOptLevel::Aggressive) && !VT.isScalableVector() &&
17469 !useSVEForFixedLengthVectorVT(VT);
17470 }
17471
const MCPhysReg *
AArch64TargetLowering::getScratchRegisters(CallingConv::ID) const {
17474 // LR is a callee-save register, but we must treat it as clobbered by any call
17475 // site. Hence we include LR in the scratch registers, which are in turn added
17476 // as implicit-defs for stackmaps and patchpoints.
17477 static const MCPhysReg ScratchRegs[] = {
17478 AArch64::X16, AArch64::X17, AArch64::LR, 0
17479 };
17480 return ScratchRegs;
17481 }
17482
ArrayRef<MCPhysReg> AArch64TargetLowering::getRoundingControlRegisters() const {
17484 static const MCPhysReg RCRegs[] = {AArch64::FPCR};
17485 return RCRegs;
17486 }
17487
bool
AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N,
                                                     CombineLevel Level) const {
17491 assert((N->getOpcode() == ISD::SHL || N->getOpcode() == ISD::SRA ||
17492 N->getOpcode() == ISD::SRL) &&
17493 "Expected shift op");
17494
17495 SDValue ShiftLHS = N->getOperand(0);
17496 EVT VT = N->getValueType(0);
17497
17498 // If ShiftLHS is unsigned bit extraction: ((x >> C) & mask), then do not
17499 // combine it with shift 'N' to let it be lowered to UBFX except:
17500 // ((x >> C) & mask) << C.
17501 if (ShiftLHS.getOpcode() == ISD::AND && (VT == MVT::i32 || VT == MVT::i64) &&
17502 isa<ConstantSDNode>(ShiftLHS.getOperand(1))) {
17503 uint64_t TruncMask = ShiftLHS.getConstantOperandVal(1);
17504 if (isMask_64(TruncMask)) {
17505 SDValue AndLHS = ShiftLHS.getOperand(0);
17506 if (AndLHS.getOpcode() == ISD::SRL) {
17507 if (auto *SRLC = dyn_cast<ConstantSDNode>(AndLHS.getOperand(1))) {
17508 if (N->getOpcode() == ISD::SHL)
17509 if (auto *SHLC = dyn_cast<ConstantSDNode>(N->getOperand(1)))
17510 return SRLC->getZExtValue() == SHLC->getZExtValue();
17511 return false;
17512 }
17513 }
17514 }
17515 }
17516 return true;
17517 }
17518
bool AArch64TargetLowering::isDesirableToCommuteXorWithShift(
    const SDNode *N) const {
17521 assert(N->getOpcode() == ISD::XOR &&
17522 (N->getOperand(0).getOpcode() == ISD::SHL ||
17523 N->getOperand(0).getOpcode() == ISD::SRL) &&
17524 "Expected XOR(SHIFT) pattern");
17525
17526 // Only commute if the entire NOT mask is a hidden shifted mask.
17527 auto *XorC = dyn_cast<ConstantSDNode>(N->getOperand(1));
17528 auto *ShiftC = dyn_cast<ConstantSDNode>(N->getOperand(0).getOperand(1));
17529 if (XorC && ShiftC) {
17530 unsigned MaskIdx, MaskLen;
17531 if (XorC->getAPIntValue().isShiftedMask(MaskIdx, MaskLen)) {
17532 unsigned ShiftAmt = ShiftC->getZExtValue();
17533 unsigned BitWidth = N->getValueType(0).getScalarSizeInBits();
17534 if (N->getOperand(0).getOpcode() == ISD::SHL)
17535 return MaskIdx == ShiftAmt && MaskLen == (BitWidth - ShiftAmt);
17536 return MaskIdx == 0 && MaskLen == (BitWidth - ShiftAmt);
17537 }
17538 }
17539
17540 return false;
17541 }
17542
bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask(
    const SDNode *N, CombineLevel Level) const {
17545 assert(((N->getOpcode() == ISD::SHL &&
17546 N->getOperand(0).getOpcode() == ISD::SRL) ||
17547 (N->getOpcode() == ISD::SRL &&
17548 N->getOperand(0).getOpcode() == ISD::SHL)) &&
17549 "Expected shift-shift mask");
17550 // Don't allow multiuse shift folding with the same shift amount.
17551 if (!N->getOperand(0)->hasOneUse())
17552 return false;
17553
17554 // Only fold srl(shl(x,c1),c2) iff C1 >= C2 to prevent loss of UBFX patterns.
17555 EVT VT = N->getValueType(0);
17556 if (N->getOpcode() == ISD::SRL && (VT == MVT::i32 || VT == MVT::i64)) {
17557 auto *C1 = dyn_cast<ConstantSDNode>(N->getOperand(0).getOperand(1));
17558 auto *C2 = dyn_cast<ConstantSDNode>(N->getOperand(1));
17559 return (!C1 || !C2 || C1->getZExtValue() >= C2->getZExtValue());
17560 }
17561
17562 return true;
17563 }
17564
bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant(
    unsigned BinOpcode, EVT VT) const {
17567 return VT.isScalableVector() && isTypeLegal(VT);
17568 }
17569
bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm,
                                                              Type *Ty) const {
17572 assert(Ty->isIntegerTy());
17573
17574 unsigned BitSize = Ty->getPrimitiveSizeInBits();
17575 if (BitSize == 0)
17576 return false;
17577
17578 int64_t Val = Imm.getSExtValue();
17579 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, BitSize))
17580 return true;
17581
17582 if ((int64_t)Val < 0)
17583 Val = ~Val;
17584 if (BitSize == 32)
17585 Val &= (1LL << 32) - 1;
17586
17587 unsigned Shift = llvm::Log2_64((uint64_t)Val) / 16;
17588 // MOVZ is free so return true for one or fewer MOVK.
17589 return Shift < 3;
17590 }
17591
bool AArch64TargetLowering::isExtractSubvectorCheap(EVT ResVT, EVT SrcVT,
                                                    unsigned Index) const {
17594 if (!isOperationLegalOrCustom(ISD::EXTRACT_SUBVECTOR, ResVT))
17595 return false;
17596
17597 return (Index == 0 || Index == ResVT.getVectorMinNumElements());
17598 }
17599
17600 /// Turn vector tests of the signbit in the form of:
17601 /// xor (sra X, elt_size(X)-1), -1
17602 /// into:
17603 /// cmge X, X, #0
static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
                                         const AArch64Subtarget *Subtarget) {
17606 EVT VT = N->getValueType(0);
17607 if (!Subtarget->hasNEON() || !VT.isVector())
17608 return SDValue();
17609
17610 // There must be a shift right algebraic before the xor, and the xor must be a
17611 // 'not' operation.
17612 SDValue Shift = N->getOperand(0);
17613 SDValue Ones = N->getOperand(1);
17614 if (Shift.getOpcode() != AArch64ISD::VASHR || !Shift.hasOneUse() ||
17615 !ISD::isBuildVectorAllOnes(Ones.getNode()))
17616 return SDValue();
17617
17618 // The shift should be smearing the sign bit across each vector element.
17619 auto *ShiftAmt = dyn_cast<ConstantSDNode>(Shift.getOperand(1));
17620 EVT ShiftEltTy = Shift.getValueType().getVectorElementType();
17621 if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
17622 return SDValue();
17623
17624 return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
17625 }
17626
// Given a vecreduce_add node, detect the below pattern and convert it to the
// node sequence with UABDL, [S|U]ABD and UADDLP.
17629 //
17630 // i32 vecreduce_add(
17631 // v16i32 abs(
17632 // v16i32 sub(
17633 // v16i32 [sign|zero]_extend(v16i8 a), v16i32 [sign|zero]_extend(v16i8 b))))
17634 // =================>
17635 // i32 vecreduce_add(
17636 // v4i32 UADDLP(
17637 // v8i16 add(
17638 // v8i16 zext(
17639 // v8i8 [S|U]ABD low8:v16i8 a, low8:v16i8 b
17640 // v8i16 zext(
17641 // v8i8 [S|U]ABD high8:v16i8 a, high8:v16i8 b
static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
                                                    SelectionDAG &DAG) {
17644 // Assumed i32 vecreduce_add
17645 if (N->getValueType(0) != MVT::i32)
17646 return SDValue();
17647
17648 SDValue VecReduceOp0 = N->getOperand(0);
17649 unsigned Opcode = VecReduceOp0.getOpcode();
17650 // Assumed v16i32 abs
17651 if (Opcode != ISD::ABS || VecReduceOp0->getValueType(0) != MVT::v16i32)
17652 return SDValue();
17653
17654 SDValue ABS = VecReduceOp0;
17655 // Assumed v16i32 sub
17656 if (ABS->getOperand(0)->getOpcode() != ISD::SUB ||
17657 ABS->getOperand(0)->getValueType(0) != MVT::v16i32)
17658 return SDValue();
17659
17660 SDValue SUB = ABS->getOperand(0);
17661 unsigned Opcode0 = SUB->getOperand(0).getOpcode();
17662 unsigned Opcode1 = SUB->getOperand(1).getOpcode();
17663 // Assumed v16i32 type
17664 if (SUB->getOperand(0)->getValueType(0) != MVT::v16i32 ||
17665 SUB->getOperand(1)->getValueType(0) != MVT::v16i32)
17666 return SDValue();
17667
17668 // Assumed zext or sext
17669 bool IsZExt = false;
17670 if (Opcode0 == ISD::ZERO_EXTEND && Opcode1 == ISD::ZERO_EXTEND) {
17671 IsZExt = true;
17672 } else if (Opcode0 == ISD::SIGN_EXTEND && Opcode1 == ISD::SIGN_EXTEND) {
17673 IsZExt = false;
17674 } else
17675 return SDValue();
17676
17677 SDValue EXT0 = SUB->getOperand(0);
17678 SDValue EXT1 = SUB->getOperand(1);
17679 // Assumed zext's operand has v16i8 type
17680 if (EXT0->getOperand(0)->getValueType(0) != MVT::v16i8 ||
17681 EXT1->getOperand(0)->getValueType(0) != MVT::v16i8)
17682 return SDValue();
17683
  // Pattern is detected. Let's convert it to a sequence of nodes.
17685 SDLoc DL(N);
17686
17687 // First, create the node pattern of UABD/SABD.
17688 SDValue UABDHigh8Op0 =
17689 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
17690 DAG.getConstant(8, DL, MVT::i64));
17691 SDValue UABDHigh8Op1 =
17692 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
17693 DAG.getConstant(8, DL, MVT::i64));
17694 SDValue UABDHigh8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
17695 UABDHigh8Op0, UABDHigh8Op1);
17696 SDValue UABDL = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDHigh8);
17697
17698 // Second, create the node pattern of UABAL.
17699 SDValue UABDLo8Op0 =
17700 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT0->getOperand(0),
17701 DAG.getConstant(0, DL, MVT::i64));
17702 SDValue UABDLo8Op1 =
17703 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, EXT1->getOperand(0),
17704 DAG.getConstant(0, DL, MVT::i64));
17705 SDValue UABDLo8 = DAG.getNode(IsZExt ? ISD::ABDU : ISD::ABDS, DL, MVT::v8i8,
17706 UABDLo8Op0, UABDLo8Op1);
17707 SDValue ZExtUABD = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::v8i16, UABDLo8);
17708 SDValue UABAL = DAG.getNode(ISD::ADD, DL, MVT::v8i16, UABDL, ZExtUABD);
17709
17710 // Third, create the node of UADDLP.
17711 SDValue UADDLP = DAG.getNode(AArch64ISD::UADDLP, DL, MVT::v4i32, UABAL);
17712
17713 // Fourth, create the node of VECREDUCE_ADD.
17714 return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
17715 }
17716
17717 // Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
17718 // vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
17719 // vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
17720 // If we have vectors larger than v16i8 we extract v16i8 vectors,
17721 // Follow the same steps above to get DOT instructions concatenate them
17722 // and generate vecreduce.add(concat_vector(DOT, DOT2, ..)).
static SDValue performVecReduceAddCombine(SDNode *N, SelectionDAG &DAG,
                                          const AArch64Subtarget *ST) {
17725 if (!ST->isNeonAvailable())
17726 return SDValue();
17727
17728 if (!ST->hasDotProd())
17729 return performVecReduceAddCombineWithUADDLP(N, DAG);
17730
17731 SDValue Op0 = N->getOperand(0);
17732 if (N->getValueType(0) != MVT::i32 || Op0.getValueType().isScalableVT() ||
17733 Op0.getValueType().getVectorElementType() != MVT::i32)
17734 return SDValue();
17735
17736 unsigned ExtOpcode = Op0.getOpcode();
17737 SDValue A = Op0;
17738 SDValue B;
17739 if (ExtOpcode == ISD::MUL) {
17740 A = Op0.getOperand(0);
17741 B = Op0.getOperand(1);
17742 if (A.getOpcode() != B.getOpcode() ||
17743 A.getOperand(0).getValueType() != B.getOperand(0).getValueType())
17744 return SDValue();
17745 ExtOpcode = A.getOpcode();
17746 }
17747 if (ExtOpcode != ISD::ZERO_EXTEND && ExtOpcode != ISD::SIGN_EXTEND)
17748 return SDValue();
17749
17750 EVT Op0VT = A.getOperand(0).getValueType();
17751 bool IsValidElementCount = Op0VT.getVectorNumElements() % 8 == 0;
17752 bool IsValidSize = Op0VT.getScalarSizeInBits() == 8;
17753 if (!IsValidElementCount || !IsValidSize)
17754 return SDValue();
17755
17756 SDLoc DL(Op0);
17757 // For non-mla reductions B can be set to 1. For MLA we take the operand of
17758 // the extend B.
17759 if (!B)
17760 B = DAG.getConstant(1, DL, Op0VT);
17761 else
17762 B = B.getOperand(0);
17763
17764 unsigned IsMultipleOf16 = Op0VT.getVectorNumElements() % 16 == 0;
17765 unsigned NumOfVecReduce;
17766 EVT TargetType;
17767 if (IsMultipleOf16) {
17768 NumOfVecReduce = Op0VT.getVectorNumElements() / 16;
17769 TargetType = MVT::v4i32;
17770 } else {
17771 NumOfVecReduce = Op0VT.getVectorNumElements() / 8;
17772 TargetType = MVT::v2i32;
17773 }
17774 auto DotOpcode =
17775 (ExtOpcode == ISD::ZERO_EXTEND) ? AArch64ISD::UDOT : AArch64ISD::SDOT;
17776 // Handle the case where we need to generate only one Dot operation.
17777 if (NumOfVecReduce == 1) {
17778 SDValue Zeros = DAG.getConstant(0, DL, TargetType);
17779 SDValue Dot = DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros,
17780 A.getOperand(0), B);
17781 return DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
17782 }
17783 // Generate Dot instructions that are multiple of 16.
17784 unsigned VecReduce16Num = Op0VT.getVectorNumElements() / 16;
17785 SmallVector<SDValue, 4> SDotVec16;
17786 unsigned I = 0;
17787 for (; I < VecReduce16Num; I += 1) {
17788 SDValue Zeros = DAG.getConstant(0, DL, MVT::v4i32);
17789 SDValue Op0 =
17790 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, A.getOperand(0),
17791 DAG.getConstant(I * 16, DL, MVT::i64));
17792 SDValue Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v16i8, B,
17793 DAG.getConstant(I * 16, DL, MVT::i64));
17794 SDValue Dot =
17795 DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Op0, Op1);
17796 SDotVec16.push_back(Dot);
17797 }
17798 // Concatenate dot operations.
17799 EVT SDot16EVT =
17800 EVT::getVectorVT(*DAG.getContext(), MVT::i32, 4 * VecReduce16Num);
17801 SDValue ConcatSDot16 =
17802 DAG.getNode(ISD::CONCAT_VECTORS, DL, SDot16EVT, SDotVec16);
17803 SDValue VecReduceAdd16 =
17804 DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), ConcatSDot16);
17805 unsigned VecReduce8Num = (Op0VT.getVectorNumElements() % 16) / 8;
17806 if (VecReduce8Num == 0)
17807 return VecReduceAdd16;
17808
17809 // Generate the remainder Dot operation that is multiple of 8.
17810 SmallVector<SDValue, 4> SDotVec8;
17811 SDValue Zeros = DAG.getConstant(0, DL, MVT::v2i32);
17812 SDValue Vec8Op0 =
17813 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, A.getOperand(0),
17814 DAG.getConstant(I * 16, DL, MVT::i64));
17815 SDValue Vec8Op1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v8i8, B,
17816 DAG.getConstant(I * 16, DL, MVT::i64));
17817 SDValue Dot =
17818 DAG.getNode(DotOpcode, DL, Zeros.getValueType(), Zeros, Vec8Op0, Vec8Op1);
  SDValue VecReduceAdd8 =
      DAG.getNode(ISD::VECREDUCE_ADD, DL, N->getValueType(0), Dot);
  return DAG.getNode(ISD::ADD, DL, N->getValueType(0), VecReduceAdd16,
                     VecReduceAdd8);
17823 }
17824
17825 // Given an (integer) vecreduce, we know the order of the inputs does not
17826 // matter. We can convert UADDV(add(zext(extract_lo(x)), zext(extract_hi(x))))
17827 // into UADDV(UADDLP(x)). This can also happen through an extra add, where we
17828 // transform UADDV(add(y, add(zext(extract_lo(x)), zext(extract_hi(x))))).
static SDValue performUADDVAddCombine(SDValue A, SelectionDAG &DAG) {
17830 auto DetectAddExtract = [&](SDValue A) {
17831 // Look for add(zext(extract_lo(x)), zext(extract_hi(x))), returning
17832 // UADDLP(x) if found.
17833 assert(A.getOpcode() == ISD::ADD);
17834 EVT VT = A.getValueType();
17835 SDValue Op0 = A.getOperand(0);
17836 SDValue Op1 = A.getOperand(1);
    if (Op0.getOpcode() != Op1.getOpcode() ||
        (Op0.getOpcode() != ISD::ZERO_EXTEND &&
         Op0.getOpcode() != ISD::SIGN_EXTEND))
17840 return SDValue();
17841 SDValue Ext0 = Op0.getOperand(0);
17842 SDValue Ext1 = Op1.getOperand(0);
17843 if (Ext0.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
17844 Ext1.getOpcode() != ISD::EXTRACT_SUBVECTOR ||
17845 Ext0.getOperand(0) != Ext1.getOperand(0))
17846 return SDValue();
17847 // Check that the type is twice the add types, and the extract are from
17848 // upper/lower parts of the same source.
17849 if (Ext0.getOperand(0).getValueType().getVectorNumElements() !=
17850 VT.getVectorNumElements() * 2)
17851 return SDValue();
17852 if ((Ext0.getConstantOperandVal(1) != 0 ||
17853 Ext1.getConstantOperandVal(1) != VT.getVectorNumElements()) &&
17854 (Ext1.getConstantOperandVal(1) != 0 ||
17855 Ext0.getConstantOperandVal(1) != VT.getVectorNumElements()))
17856 return SDValue();
17857 unsigned Opcode = Op0.getOpcode() == ISD::ZERO_EXTEND ? AArch64ISD::UADDLP
17858 : AArch64ISD::SADDLP;
17859 return DAG.getNode(Opcode, SDLoc(A), VT, Ext0.getOperand(0));
17860 };
17861
17862 if (SDValue R = DetectAddExtract(A))
17863 return R;
17864
17865 if (A.getOperand(0).getOpcode() == ISD::ADD && A.getOperand(0).hasOneUse())
17866 if (SDValue R = performUADDVAddCombine(A.getOperand(0), DAG))
17867 return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
17868 A.getOperand(1));
17869 if (A.getOperand(1).getOpcode() == ISD::ADD && A.getOperand(1).hasOneUse())
17870 if (SDValue R = performUADDVAddCombine(A.getOperand(1), DAG))
17871 return DAG.getNode(ISD::ADD, SDLoc(A), A.getValueType(), R,
17872 A.getOperand(0));
17873 return SDValue();
17874 }
17875
17876 // We can convert a UADDV(add(zext(64-bit source), zext(64-bit source))) into
17877 // UADDLV(concat), where the concat represents the 64-bit zext sources.
static SDValue performUADDVZextCombine(SDValue A, SelectionDAG &DAG) {
17879 // Look for add(zext(64-bit source), zext(64-bit source)), returning
17880 // UADDLV(concat(zext, zext)) if found.
17881 assert(A.getOpcode() == ISD::ADD);
17882 EVT VT = A.getValueType();
17883 if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
17884 return SDValue();
17885 SDValue Op0 = A.getOperand(0);
17886 SDValue Op1 = A.getOperand(1);
17887 if (Op0.getOpcode() != ISD::ZERO_EXTEND || Op0.getOpcode() != Op1.getOpcode())
17888 return SDValue();
17889 SDValue Ext0 = Op0.getOperand(0);
17890 SDValue Ext1 = Op1.getOperand(0);
17891 EVT ExtVT0 = Ext0.getValueType();
17892 EVT ExtVT1 = Ext1.getValueType();
17893 // Check zext VTs are the same and 64-bit length.
17894 if (ExtVT0 != ExtVT1 ||
17895 VT.getScalarSizeInBits() != (2 * ExtVT0.getScalarSizeInBits()))
17896 return SDValue();
17897 // Get VT for concat of zext sources.
17898 EVT PairVT = ExtVT0.getDoubleNumVectorElementsVT(*DAG.getContext());
17899 SDValue Concat =
17900 DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(A), PairVT, Ext0, Ext1);
17901
17902 switch (VT.getSimpleVT().SimpleTy) {
17903 case MVT::v2i64:
17904 case MVT::v4i32:
17905 return DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), VT, Concat);
17906 case MVT::v8i16: {
17907 SDValue Uaddlv =
17908 DAG.getNode(AArch64ISD::UADDLV, SDLoc(A), MVT::v4i32, Concat);
17909 return DAG.getNode(AArch64ISD::NVCAST, SDLoc(A), MVT::v8i16, Uaddlv);
17910 }
17911 default:
17912 llvm_unreachable("Unhandled vector type");
17913 }
17914 }
17915
17916 static SDValue performUADDVCombine(SDNode *N, SelectionDAG &DAG) {
17917 SDValue A = N->getOperand(0);
17918 if (A.getOpcode() == ISD::ADD) {
17919 if (SDValue R = performUADDVAddCombine(A, DAG))
17920 return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), R);
17921 else if (SDValue R = performUADDVZextCombine(A, DAG))
17922 return R;
17923 }
17924 return SDValue();
17925 }
17926
17927 static SDValue performXorCombine(SDNode *N, SelectionDAG &DAG,
17928 TargetLowering::DAGCombinerInfo &DCI,
17929 const AArch64Subtarget *Subtarget) {
17930 if (DCI.isBeforeLegalizeOps())
17931 return SDValue();
17932
17933 return foldVectorXorShiftIntoCmp(N, DAG, Subtarget);
17934 }
17935
17936 SDValue
17937 AArch64TargetLowering::BuildSDIVPow2(SDNode *N, const APInt &Divisor,
17938 SelectionDAG &DAG,
17939 SmallVectorImpl<SDNode *> &Created) const {
17940 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
17941 if (isIntDivCheap(N->getValueType(0), Attr))
17942 return SDValue(N, 0); // Lower SDIV as SDIV
17943
17944 EVT VT = N->getValueType(0);
17945
17946 // For scalable and fixed types, mark them as cheap so we can handle it much
17947 // later. This allows us to handle larger than legal types.
17948 if (VT.isScalableVector() ||
17949 (VT.isFixedLengthVector() && Subtarget->useSVEForFixedLengthVectors()))
17950 return SDValue(N, 0);
17951
17952 // fold (sdiv X, pow2)
17953 if ((VT != MVT::i32 && VT != MVT::i64) ||
17954 !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
17955 return SDValue();
17956
17957   // If the divisor is 2 or -2, the default expansion is better. It will add
17958   // (N->getOperand(0) >> (BitWidth - 1)) to the dividend before shifting right.
17959 if (Divisor == 2 ||
17960 Divisor == APInt(Divisor.getBitWidth(), -2, /*isSigned*/ true))
17961 return SDValue();
17962
17963 return TargetLowering::buildSDIVPow2WithCMov(N, Divisor, DAG, Created);
17964 }
17965
17966 SDValue
17967 AArch64TargetLowering::BuildSREMPow2(SDNode *N, const APInt &Divisor,
17968 SelectionDAG &DAG,
17969 SmallVectorImpl<SDNode *> &Created) const {
17970 AttributeList Attr = DAG.getMachineFunction().getFunction().getAttributes();
17971 if (isIntDivCheap(N->getValueType(0), Attr))
17972 return SDValue(N, 0); // Lower SREM as SREM
17973
17974 EVT VT = N->getValueType(0);
17975
17976 // For scalable and fixed types, mark them as cheap so we can handle it much
17977 // later. This allows us to handle larger than legal types.
17978 if (VT.isScalableVector() || Subtarget->useSVEForFixedLengthVectors())
17979 return SDValue(N, 0);
17980
17981 // fold (srem X, pow2)
17982 if ((VT != MVT::i32 && VT != MVT::i64) ||
17983 !(Divisor.isPowerOf2() || Divisor.isNegatedPowerOf2()))
17984 return SDValue();
17985
17986 unsigned Lg2 = Divisor.countr_zero();
17987 if (Lg2 == 0)
17988 return SDValue();
17989
17990 SDLoc DL(N);
17991 SDValue N0 = N->getOperand(0);
17992 SDValue Pow2MinusOne = DAG.getConstant((1ULL << Lg2) - 1, DL, VT);
17993 SDValue Zero = DAG.getConstant(0, DL, VT);
17994 SDValue CCVal, CSNeg;
17995 if (Lg2 == 1) {
17996 SDValue Cmp = getAArch64Cmp(N0, Zero, ISD::SETGE, CCVal, DAG, DL);
17997 SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
17998 CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, And, And, CCVal, Cmp);
17999
18000 Created.push_back(Cmp.getNode());
18001 Created.push_back(And.getNode());
18002 } else {
18003 SDValue CCVal = DAG.getConstant(AArch64CC::MI, DL, MVT_CC);
18004 SDVTList VTs = DAG.getVTList(VT, MVT::i32);
18005
18006 SDValue Negs = DAG.getNode(AArch64ISD::SUBS, DL, VTs, Zero, N0);
18007 SDValue AndPos = DAG.getNode(ISD::AND, DL, VT, N0, Pow2MinusOne);
18008 SDValue AndNeg = DAG.getNode(ISD::AND, DL, VT, Negs, Pow2MinusOne);
18009 CSNeg = DAG.getNode(AArch64ISD::CSNEG, DL, VT, AndPos, AndNeg, CCVal,
18010 Negs.getValue(1));
18011
18012 Created.push_back(Negs.getNode());
18013 Created.push_back(AndPos.getNode());
18014 Created.push_back(AndNeg.getNode());
18015 }
18016
18017 return CSNeg;
18018 }
18019
18020 static std::optional<unsigned> IsSVECntIntrinsic(SDValue S) {
18021 switch(getIntrinsicID(S.getNode())) {
18022 default:
18023 break;
18024 case Intrinsic::aarch64_sve_cntb:
18025 return 8;
18026 case Intrinsic::aarch64_sve_cnth:
18027 return 16;
18028 case Intrinsic::aarch64_sve_cntw:
18029 return 32;
18030 case Intrinsic::aarch64_sve_cntd:
18031 return 64;
18032 }
18033 return {};
18034 }
18035
18036 /// Calculates what the pre-extend type is, based on the extension
18037 /// operation node provided by \p Extend.
18038 ///
18039 /// In the case that \p Extend is a SIGN_EXTEND or a ZERO_EXTEND, the
18040 /// pre-extend type is pulled directly from the operand, while other extend
18041 /// operations need a bit more inspection to get this information.
18042 ///
18043 /// \param Extend The SDNode from the DAG that represents the extend operation
18044 ///
18045 /// \returns The type representing the \p Extend source type, or \p MVT::Other
18046 /// if no valid type can be determined
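///
/// For example (illustrative): (and X:i32, 0xffff) is treated as an extend
/// from MVT::i16, and (AssertSext X, ValueType:i8) yields MVT::i8.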
18047 static EVT calculatePreExtendType(SDValue Extend) {
18048 switch (Extend.getOpcode()) {
18049 case ISD::SIGN_EXTEND:
18050 case ISD::ZERO_EXTEND:
18051 return Extend.getOperand(0).getValueType();
18052 case ISD::AssertSext:
18053 case ISD::AssertZext:
18054 case ISD::SIGN_EXTEND_INREG: {
18055 VTSDNode *TypeNode = dyn_cast<VTSDNode>(Extend.getOperand(1));
18056 if (!TypeNode)
18057 return MVT::Other;
18058 return TypeNode->getVT();
18059 }
18060 case ISD::AND: {
18061 ConstantSDNode *Constant =
18062 dyn_cast<ConstantSDNode>(Extend.getOperand(1).getNode());
18063 if (!Constant)
18064 return MVT::Other;
18065
18066 uint32_t Mask = Constant->getZExtValue();
18067
18068 if (Mask == UCHAR_MAX)
18069 return MVT::i8;
18070 else if (Mask == USHRT_MAX)
18071 return MVT::i16;
18072 else if (Mask == UINT_MAX)
18073 return MVT::i32;
18074
18075 return MVT::Other;
18076 }
18077 default:
18078 return MVT::Other;
18079 }
18080 }
18081
18082 /// Combines a buildvector(sext/zext) or shuffle(sext/zext, undef) node pattern
18083 /// into sext/zext(buildvector) or sext/zext(shuffle) making use of the vector
18084 /// SExt/ZExt rather than the scalar SExt/ZExt
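///
/// For example (illustrative, with v8i16 as the result type):
///   build_vector(zext(i8 a0), zext(i8 a1), ..., zext(i8 a7)) : v8i16
///   ==> zext(build_vector(a0, a1, ..., a7) : v8i8) to v8i16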
18085 static SDValue performBuildShuffleExtendCombine(SDValue BV, SelectionDAG &DAG) {
18086 EVT VT = BV.getValueType();
18087 if (BV.getOpcode() != ISD::BUILD_VECTOR &&
18088 BV.getOpcode() != ISD::VECTOR_SHUFFLE)
18089 return SDValue();
18090
18091 // Use the first item in the buildvector/shuffle to get the size of the
18092 // extend, and make sure it looks valid.
18093 SDValue Extend = BV->getOperand(0);
18094 unsigned ExtendOpcode = Extend.getOpcode();
18095 bool IsSExt = ExtendOpcode == ISD::SIGN_EXTEND ||
18096 ExtendOpcode == ISD::SIGN_EXTEND_INREG ||
18097 ExtendOpcode == ISD::AssertSext;
18098 if (!IsSExt && ExtendOpcode != ISD::ZERO_EXTEND &&
18099 ExtendOpcode != ISD::AssertZext && ExtendOpcode != ISD::AND)
18100 return SDValue();
18101   // Shuffle inputs are vectors, so limit to SIGN_EXTEND and ZERO_EXTEND to ensure
18102 // calculatePreExtendType will work without issue.
18103 if (BV.getOpcode() == ISD::VECTOR_SHUFFLE &&
18104 ExtendOpcode != ISD::SIGN_EXTEND && ExtendOpcode != ISD::ZERO_EXTEND)
18105 return SDValue();
18106
18107 // Restrict valid pre-extend data type
18108 EVT PreExtendType = calculatePreExtendType(Extend);
18109 if (PreExtendType == MVT::Other ||
18110 PreExtendType.getScalarSizeInBits() != VT.getScalarSizeInBits() / 2)
18111 return SDValue();
18112
18113 // Make sure all other operands are equally extended
18114 for (SDValue Op : drop_begin(BV->ops())) {
18115 if (Op.isUndef())
18116 continue;
18117 unsigned Opc = Op.getOpcode();
18118 bool OpcIsSExt = Opc == ISD::SIGN_EXTEND || Opc == ISD::SIGN_EXTEND_INREG ||
18119 Opc == ISD::AssertSext;
18120 if (OpcIsSExt != IsSExt || calculatePreExtendType(Op) != PreExtendType)
18121 return SDValue();
18122 }
18123
18124 SDValue NBV;
18125 SDLoc DL(BV);
18126 if (BV.getOpcode() == ISD::BUILD_VECTOR) {
18127 EVT PreExtendVT = VT.changeVectorElementType(PreExtendType);
18128 EVT PreExtendLegalType =
18129 PreExtendType.getScalarSizeInBits() < 32 ? MVT::i32 : PreExtendType;
18130 SmallVector<SDValue, 8> NewOps;
18131 for (SDValue Op : BV->ops())
18132 NewOps.push_back(Op.isUndef() ? DAG.getUNDEF(PreExtendLegalType)
18133 : DAG.getAnyExtOrTrunc(Op.getOperand(0), DL,
18134 PreExtendLegalType));
18135 NBV = DAG.getNode(ISD::BUILD_VECTOR, DL, PreExtendVT, NewOps);
18136 } else { // BV.getOpcode() == ISD::VECTOR_SHUFFLE
18137 EVT PreExtendVT = VT.changeVectorElementType(PreExtendType.getScalarType());
18138 NBV = DAG.getVectorShuffle(PreExtendVT, DL, BV.getOperand(0).getOperand(0),
18139 BV.getOperand(1).isUndef()
18140 ? DAG.getUNDEF(PreExtendVT)
18141 : BV.getOperand(1).getOperand(0),
18142 cast<ShuffleVectorSDNode>(BV)->getMask());
18143 }
18144 return DAG.getNode(IsSExt ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT, NBV);
18145 }
18146
18147 /// Combines a mul(dup(sext/zext)) node pattern into mul(sext/zext(dup))
18148 /// making use of the vector SExt/ZExt rather than the scalar SExt/ZExt
18149 static SDValue performMulVectorExtendCombine(SDNode *Mul, SelectionDAG &DAG) {
18150 // If the value type isn't a vector, none of the operands are going to be dups
18151 EVT VT = Mul->getValueType(0);
18152 if (VT != MVT::v8i16 && VT != MVT::v4i32 && VT != MVT::v2i64)
18153 return SDValue();
18154
18155 SDValue Op0 = performBuildShuffleExtendCombine(Mul->getOperand(0), DAG);
18156 SDValue Op1 = performBuildShuffleExtendCombine(Mul->getOperand(1), DAG);
18157
18158   // Neither operand has been changed; don't make any further changes.
18159 if (!Op0 && !Op1)
18160 return SDValue();
18161
18162 SDLoc DL(Mul);
18163 return DAG.getNode(Mul->getOpcode(), DL, VT, Op0 ? Op0 : Mul->getOperand(0),
18164 Op1 ? Op1 : Mul->getOperand(1));
18165 }
18166
18167 // Combine v4i32 Mul(And(Srl(X, 15), 0x10001), 0xffff) -> v8i16 CMLTz
18168 // Same for other types with equivalent constants.
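// Worked example (illustrative), for one i32 lane with X = 0x80000001:
//   Srl(X, 15)        = 0x00010000   ; the sign bits of the two i16 halves of X
//   And(..., 0x10001) = 0x00010000   ; now sit in bits 16 and 0
//   Mul(..., 0xffff)  = 0xffff0000   ; each bit is smeared across its own half
// which matches CMLTz applied to the two i16 halves (0x8000 < 0, 0x0001 >= 0).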
18169 static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
18170 EVT VT = N->getValueType(0);
18171 if (VT != MVT::v2i64 && VT != MVT::v1i64 && VT != MVT::v2i32 &&
18172 VT != MVT::v4i32 && VT != MVT::v4i16 && VT != MVT::v8i16)
18173 return SDValue();
18174 if (N->getOperand(0).getOpcode() != ISD::AND ||
18175 N->getOperand(0).getOperand(0).getOpcode() != ISD::SRL)
18176 return SDValue();
18177
18178 SDValue And = N->getOperand(0);
18179 SDValue Srl = And.getOperand(0);
18180
18181 APInt V1, V2, V3;
18182 if (!ISD::isConstantSplatVector(N->getOperand(1).getNode(), V1) ||
18183 !ISD::isConstantSplatVector(And.getOperand(1).getNode(), V2) ||
18184 !ISD::isConstantSplatVector(Srl.getOperand(1).getNode(), V3))
18185 return SDValue();
18186
18187 unsigned HalfSize = VT.getScalarSizeInBits() / 2;
18188 if (!V1.isMask(HalfSize) || V2 != (1ULL | 1ULL << HalfSize) ||
18189 V3 != (HalfSize - 1))
18190 return SDValue();
18191
18192 EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
18193 EVT::getIntegerVT(*DAG.getContext(), HalfSize),
18194 VT.getVectorElementCount() * 2);
18195
18196 SDLoc DL(N);
18197 SDValue In = DAG.getNode(AArch64ISD::NVCAST, DL, HalfVT, Srl.getOperand(0));
18198 SDValue CM = DAG.getNode(AArch64ISD::CMLTz, DL, HalfVT, In);
18199 return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
18200 }
18201
18202 // Transform vector add(zext i8 to i32, zext i8 to i32)
18203 // into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
18204 // This allows extra uses of saddl/uaddl at the lower vector widths, and fewer
18205 // extends.
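// For example (illustrative, with concrete types):
//   v8i32 add(zext(v8i8 A), zext(v8i8 B))
//   ==> v8i32 sext(v8i16 add(zext(v8i8 A) to v8i16, zext(v8i8 B) to v8i16))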
18206 static SDValue performVectorExtCombine(SDNode *N, SelectionDAG &DAG) {
18207 EVT VT = N->getValueType(0);
18208 if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
18209 (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
18210 N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
18211 (N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
18212 N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
18213 N->getOperand(0).getOperand(0).getValueType() !=
18214 N->getOperand(1).getOperand(0).getValueType())
18215 return SDValue();
18216
18217 if (N->getOpcode() == ISD::MUL &&
18218 N->getOperand(0).getOpcode() != N->getOperand(1).getOpcode())
18219 return SDValue();
18220
18221 SDValue N0 = N->getOperand(0).getOperand(0);
18222 SDValue N1 = N->getOperand(1).getOperand(0);
18223 EVT InVT = N0.getValueType();
18224
18225 EVT S1 = InVT.getScalarType();
18226 EVT S2 = VT.getScalarType();
18227 if ((S2 == MVT::i32 && S1 == MVT::i8) ||
18228 (S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
18229 SDLoc DL(N);
18230 EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
18231 S2.getHalfSizedIntegerVT(*DAG.getContext()),
18232 VT.getVectorElementCount());
18233 SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
18234 SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
18235 SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
18236 return DAG.getNode(N->getOpcode() == ISD::MUL ? N->getOperand(0).getOpcode()
18237 : (unsigned)ISD::SIGN_EXTEND,
18238 DL, VT, NewOp);
18239 }
18240 return SDValue();
18241 }
18242
18243 static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
18244 TargetLowering::DAGCombinerInfo &DCI,
18245 const AArch64Subtarget *Subtarget) {
18246
18247 if (SDValue Ext = performMulVectorExtendCombine(N, DAG))
18248 return Ext;
18249 if (SDValue Ext = performMulVectorCmpZeroCombine(N, DAG))
18250 return Ext;
18251 if (SDValue Ext = performVectorExtCombine(N, DAG))
18252 return Ext;
18253
18254 if (DCI.isBeforeLegalizeOps())
18255 return SDValue();
18256
18257 // Canonicalize X*(Y+1) -> X*Y+X and (X+1)*Y -> X*Y+Y,
18258 // and in MachineCombiner pass, add+mul will be combined into madd.
18259 // Similarly, X*(1-Y) -> X - X*Y and (1-Y)*X -> X - Y*X.
18260 SDLoc DL(N);
18261 EVT VT = N->getValueType(0);
18262 SDValue N0 = N->getOperand(0);
18263 SDValue N1 = N->getOperand(1);
18264 SDValue MulOper;
18265 unsigned AddSubOpc;
18266
18267 auto IsAddSubWith1 = [&](SDValue V) -> bool {
18268 AddSubOpc = V->getOpcode();
18269 if ((AddSubOpc == ISD::ADD || AddSubOpc == ISD::SUB) && V->hasOneUse()) {
18270 SDValue Opnd = V->getOperand(1);
18271 MulOper = V->getOperand(0);
18272 if (AddSubOpc == ISD::SUB)
18273 std::swap(Opnd, MulOper);
18274 if (auto C = dyn_cast<ConstantSDNode>(Opnd))
18275 return C->isOne();
18276 }
18277 return false;
18278 };
18279
18280 if (IsAddSubWith1(N0)) {
18281 SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N1, MulOper);
18282 return DAG.getNode(AddSubOpc, DL, VT, N1, MulVal);
18283 }
18284
18285 if (IsAddSubWith1(N1)) {
18286 SDValue MulVal = DAG.getNode(ISD::MUL, DL, VT, N0, MulOper);
18287 return DAG.getNode(AddSubOpc, DL, VT, N0, MulVal);
18288 }
18289
18290 // The below optimizations require a constant RHS.
18291 if (!isa<ConstantSDNode>(N1))
18292 return SDValue();
18293
18294 ConstantSDNode *C = cast<ConstantSDNode>(N1);
18295 const APInt &ConstValue = C->getAPIntValue();
18296
18297   // Allow the scaling to be folded into the `cnt` instruction by preventing
18298   // the scaling from being obscured here. This makes it easier to pattern match.
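  // e.g. (mul (i64 aarch64_sve_cntd(all)), 4) is left untouched here so that
  // isel can fold the factor into CNTD's immediate multiplier (illustrative).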
18299 if (IsSVECntIntrinsic(N0) ||
18300 (N0->getOpcode() == ISD::TRUNCATE &&
18301 (IsSVECntIntrinsic(N0->getOperand(0)))))
18302 if (ConstValue.sge(1) && ConstValue.sle(16))
18303 return SDValue();
18304
18305 // Multiplication of a power of two plus/minus one can be done more
18306 // cheaply as shift+add/sub. For now, this is true unilaterally. If
18307 // future CPUs have a cheaper MADD instruction, this may need to be
18308 // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
18309 // 64-bit is 5 cycles, so this is always a win.
18310 // More aggressively, some multiplications N0 * C can be lowered to
18311 // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
18312 // e.g. 6=3*2=(2+1)*2, 45=(1+4)*(1+8)
18313 // TODO: lower more cases.
18314
18315 // TrailingZeroes is used to test if the mul can be lowered to
18316 // shift+add+shift.
18317 unsigned TrailingZeroes = ConstValue.countr_zero();
18318 if (TrailingZeroes) {
18319 // Conservatively do not lower to shift+add+shift if the mul might be
18320 // folded into smul or umul.
18321 if (N0->hasOneUse() && (isSignExtended(N0, DAG) ||
18322 isZeroExtended(N0, DAG)))
18323 return SDValue();
18324 // Conservatively do not lower to shift+add+shift if the mul might be
18325 // folded into madd or msub.
18326 if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD ||
18327 N->use_begin()->getOpcode() == ISD::SUB))
18328 return SDValue();
18329 }
18330 // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
18331 // and shift+add+shift.
18332 APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
18333 unsigned ShiftAmt;
18334
18335 auto Shl = [&](SDValue N0, unsigned N1) {
18336 if (!N0.getNode())
18337 return SDValue();
18338 // If shift causes overflow, ignore this combine.
18339 if (N1 >= N0.getValueSizeInBits())
18340 return SDValue();
18341 SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
18342 return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
18343 };
18344 auto Add = [&](SDValue N0, SDValue N1) {
18345 if (!N0.getNode() || !N1.getNode())
18346 return SDValue();
18347 return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
18348 };
18349 auto Sub = [&](SDValue N0, SDValue N1) {
18350 if (!N0.getNode() || !N1.getNode())
18351 return SDValue();
18352 return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
18353 };
18354 auto Negate = [&](SDValue N) {
18355 if (!N0.getNode())
18356 return SDValue();
18357 SDValue Zero = DAG.getConstant(0, DL, VT);
18358 return DAG.getNode(ISD::SUB, DL, VT, Zero, N);
18359 };
18360
18361 // Can the const C be decomposed into (1+2^M1)*(1+2^N1), eg:
18362 // C = 45 is equal to (1+4)*(1+8), we don't decompose it into (1+2)*(16-1) as
18363   // the (2^N - 1) can't be executed via a single instruction.
18364 auto isPowPlusPlusConst = [](APInt C, APInt &M, APInt &N) {
18365 unsigned BitWidth = C.getBitWidth();
18366 for (unsigned i = 1; i < BitWidth / 2; i++) {
18367 APInt Rem;
18368 APInt X(BitWidth, (1 << i) + 1);
18369 APInt::sdivrem(C, X, N, Rem);
18370 APInt NVMinus1 = N - 1;
18371 if (Rem == 0 && NVMinus1.isPowerOf2()) {
18372 M = X;
18373 return true;
18374 }
18375 }
18376 return false;
18377 };
18378
18379   // Can the const C be decomposed into ((2^M + 1) * 2^N + 1), eg:
18380   // C = 11 is equal to (1+4)*2+1, we don't decompose it into (1+2)*4-1 as
18381   // the (2^N - 1) can't be executed via a single instruction.
18382 auto isPowPlusPlusOneConst = [](APInt C, APInt &M, APInt &N) {
18383 APInt CVMinus1 = C - 1;
18384 if (CVMinus1.isNegative())
18385 return false;
18386 unsigned TrailingZeroes = CVMinus1.countr_zero();
18387 APInt SCVMinus1 = CVMinus1.ashr(TrailingZeroes) - 1;
18388 if (SCVMinus1.isPowerOf2()) {
18389 unsigned BitWidth = SCVMinus1.getBitWidth();
18390 M = APInt(BitWidth, SCVMinus1.logBase2());
18391 N = APInt(BitWidth, TrailingZeroes);
18392 return true;
18393 }
18394 return false;
18395 };
18396
18397 // Can the const C be decomposed into (1 - (1 - 2^M) * 2^N), eg:
18398 // C = 29 is equal to 1 - (1 - 2^3) * 2^2.
18399 auto isPowMinusMinusOneConst = [](APInt C, APInt &M, APInt &N) {
18400 APInt CVMinus1 = C - 1;
18401 if (CVMinus1.isNegative())
18402 return false;
18403 unsigned TrailingZeroes = CVMinus1.countr_zero();
18404 APInt CVPlus1 = CVMinus1.ashr(TrailingZeroes) + 1;
18405 if (CVPlus1.isPowerOf2()) {
18406 unsigned BitWidth = CVPlus1.getBitWidth();
18407 M = APInt(BitWidth, CVPlus1.logBase2());
18408 N = APInt(BitWidth, TrailingZeroes);
18409 return true;
18410 }
18411 return false;
18412 };
18413
18414 if (ConstValue.isNonNegative()) {
18415 // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
18416 // (mul x, 2^N - 1) => (sub (shl x, N), x)
18417 // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
18418 // (mul x, (2^M + 1) * (2^N + 1))
18419 // => MV = (add (shl x, M), x); (add (shl MV, N), MV)
18420     // (mul x, (2^M + 1) * 2^N + 1)
18421     //     =>  MV = (add (shl x, M), x); (add (shl MV, N), x)
18422     // (mul x, 1 - (1 - 2^M) * 2^N)
18423     //     =>  MV = (sub x, (shl x, M)); (sub x, (shl MV, N))
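    // Worked example (illustrative): for C = 40 = (2^2 + 1) * 2^3, the first
    // form gives (shl (add (shl x, 2), x), 3), i.e. ((x*4 + x) << 3) == x*40,
    // using only shifts and an add.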
18424 APInt SCVMinus1 = ShiftedConstValue - 1;
18425 APInt SCVPlus1 = ShiftedConstValue + 1;
18426 APInt CVPlus1 = ConstValue + 1;
18427 APInt CVM, CVN;
18428 if (SCVMinus1.isPowerOf2()) {
18429 ShiftAmt = SCVMinus1.logBase2();
18430 return Shl(Add(Shl(N0, ShiftAmt), N0), TrailingZeroes);
18431 } else if (CVPlus1.isPowerOf2()) {
18432 ShiftAmt = CVPlus1.logBase2();
18433 return Sub(Shl(N0, ShiftAmt), N0);
18434 } else if (SCVPlus1.isPowerOf2()) {
18435 ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
18436 return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes));
18437 }
18438 if (Subtarget->hasALULSLFast() &&
18439 isPowPlusPlusConst(ConstValue, CVM, CVN)) {
18440 APInt CVMMinus1 = CVM - 1;
18441 APInt CVNMinus1 = CVN - 1;
18442 unsigned ShiftM1 = CVMMinus1.logBase2();
18443 unsigned ShiftN1 = CVNMinus1.logBase2();
18444       // ALULSLFast implies that shifts of up to 4 places are fast
18445 if (ShiftM1 <= 4 && ShiftN1 <= 4) {
18446 SDValue MVal = Add(Shl(N0, ShiftM1), N0);
18447 return Add(Shl(MVal, ShiftN1), MVal);
18448 }
18449 }
18450 if (Subtarget->hasALULSLFast() &&
18451 isPowPlusPlusOneConst(ConstValue, CVM, CVN)) {
18452 unsigned ShiftM = CVM.getZExtValue();
18453 unsigned ShiftN = CVN.getZExtValue();
18454       // ALULSLFast implies that shifts of up to 4 places are fast
18455 if (ShiftM <= 4 && ShiftN <= 4) {
18456 SDValue MVal = Add(Shl(N0, CVM.getZExtValue()), N0);
18457 return Add(Shl(MVal, CVN.getZExtValue()), N0);
18458 }
18459 }
18460
18461 if (Subtarget->hasALULSLFast() &&
18462 isPowMinusMinusOneConst(ConstValue, CVM, CVN)) {
18463 unsigned ShiftM = CVM.getZExtValue();
18464 unsigned ShiftN = CVN.getZExtValue();
18465       // ALULSLFast implies that shifts of up to 4 places are fast
18466 if (ShiftM <= 4 && ShiftN <= 4) {
18467 SDValue MVal = Sub(N0, Shl(N0, CVM.getZExtValue()));
18468 return Sub(N0, Shl(MVal, CVN.getZExtValue()));
18469 }
18470 }
18471 } else {
18472 // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
18473 // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
18474 // (mul x, -(2^(N-M) - 1) * 2^M) => (sub (shl x, M), (shl x, N))
18475 APInt SCVPlus1 = -ShiftedConstValue + 1;
18476 APInt CVNegPlus1 = -ConstValue + 1;
18477 APInt CVNegMinus1 = -ConstValue - 1;
18478 if (CVNegPlus1.isPowerOf2()) {
18479 ShiftAmt = CVNegPlus1.logBase2();
18480 return Sub(N0, Shl(N0, ShiftAmt));
18481 } else if (CVNegMinus1.isPowerOf2()) {
18482 ShiftAmt = CVNegMinus1.logBase2();
18483 return Negate(Add(Shl(N0, ShiftAmt), N0));
18484 } else if (SCVPlus1.isPowerOf2()) {
18485 ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
18486 return Sub(Shl(N0, TrailingZeroes), Shl(N0, ShiftAmt));
18487 }
18488 }
18489
18490 return SDValue();
18491 }
18492
18493 static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
18494 SelectionDAG &DAG) {
18495 // Take advantage of vector comparisons producing 0 or -1 in each lane to
18496 // optimize away operation when it's from a constant.
18497 //
18498 // The general transformation is:
18499 // UNARYOP(AND(VECTOR_CMP(x,y), constant)) -->
18500 // AND(VECTOR_CMP(x,y), constant2)
18501 // constant2 = UNARYOP(constant)
18502
18503 // Early exit if this isn't a vector operation, the operand of the
18504 // unary operation isn't a bitwise AND, or if the sizes of the operations
18505 // aren't the same.
18506 EVT VT = N->getValueType(0);
18507 if (!VT.isVector() || N->getOperand(0)->getOpcode() != ISD::AND ||
18508 N->getOperand(0)->getOperand(0)->getOpcode() != ISD::SETCC ||
18509 VT.getSizeInBits() != N->getOperand(0)->getValueType(0).getSizeInBits())
18510 return SDValue();
18511
18512 // Now check that the other operand of the AND is a constant. We could
18513 // make the transformation for non-constant splats as well, but it's unclear
18514 // that would be a benefit as it would not eliminate any operations, just
18515 // perform one more step in scalar code before moving to the vector unit.
18516 if (BuildVectorSDNode *BV =
18517 dyn_cast<BuildVectorSDNode>(N->getOperand(0)->getOperand(1))) {
18518 // Bail out if the vector isn't a constant.
18519 if (!BV->isConstant())
18520 return SDValue();
18521
18522 // Everything checks out. Build up the new and improved node.
18523 SDLoc DL(N);
18524 EVT IntVT = BV->getValueType(0);
18525 // Create a new constant of the appropriate type for the transformed
18526 // DAG.
18527 SDValue SourceConst = DAG.getNode(N->getOpcode(), DL, VT, SDValue(BV, 0));
18528 // The AND node needs bitcasts to/from an integer vector type around it.
18529 SDValue MaskConst = DAG.getNode(ISD::BITCAST, DL, IntVT, SourceConst);
18530 SDValue NewAnd = DAG.getNode(ISD::AND, DL, IntVT,
18531 N->getOperand(0)->getOperand(0), MaskConst);
18532 SDValue Res = DAG.getNode(ISD::BITCAST, DL, VT, NewAnd);
18533 return Res;
18534 }
18535
18536 return SDValue();
18537 }
18538
18539 static SDValue performIntToFpCombine(SDNode *N, SelectionDAG &DAG,
18540 const AArch64Subtarget *Subtarget) {
18541 // First try to optimize away the conversion when it's conditionally from
18542 // a constant. Vectors only.
18543 if (SDValue Res = performVectorCompareAndMaskUnaryOpCombine(N, DAG))
18544 return Res;
18545
18546 EVT VT = N->getValueType(0);
18547 if (VT != MVT::f32 && VT != MVT::f64)
18548 return SDValue();
18549
18550 // Only optimize when the source and destination types have the same width.
18551 if (VT.getSizeInBits() != N->getOperand(0).getValueSizeInBits())
18552 return SDValue();
18553
18554 // If the result of an integer load is only used by an integer-to-float
18555   // conversion, use an FP load and an AdvSIMD scalar {S|U}CVTF instead.
18556 // This eliminates an "integer-to-vector-move" UOP and improves throughput.
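  // For example (illustrative): (f32 (sint_to_fp (i32 (load addr)))) becomes
  // (f32 (SITOF (f32 (load addr)))), avoiding the GPR-to-FPR move.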
18557 SDValue N0 = N->getOperand(0);
18558 if (Subtarget->isNeonAvailable() && ISD::isNormalLoad(N0.getNode()) &&
18559 N0.hasOneUse() &&
18560 // Do not change the width of a volatile load.
18561 !cast<LoadSDNode>(N0)->isVolatile()) {
18562 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
18563 SDValue Load = DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
18564 LN0->getPointerInfo(), LN0->getAlign(),
18565 LN0->getMemOperand()->getFlags());
18566
18567 // Make sure successors of the original load stay after it by updating them
18568 // to use the new Chain.
18569 DAG.ReplaceAllUsesOfValueWith(SDValue(LN0, 1), Load.getValue(1));
18570
18571 unsigned Opcode =
18572 (N->getOpcode() == ISD::SINT_TO_FP) ? AArch64ISD::SITOF : AArch64ISD::UITOF;
18573 return DAG.getNode(Opcode, SDLoc(N), VT, Load);
18574 }
18575
18576 return SDValue();
18577 }
18578
18579 /// Fold a floating-point multiply by power of two into floating-point to
18580 /// fixed-point conversion.
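///
/// For example (illustrative): (v4i32 fp_to_sint(fmul(v4f32 X, splat(16.0))))
/// becomes the aarch64.neon.vcvtfp2fxs intrinsic with 4 fractional bits,
/// i.e. an FCVTZS ..., #4.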
18581 static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
18582 TargetLowering::DAGCombinerInfo &DCI,
18583 const AArch64Subtarget *Subtarget) {
18584 if (!Subtarget->isNeonAvailable())
18585 return SDValue();
18586
18587 if (!N->getValueType(0).isSimple())
18588 return SDValue();
18589
18590 SDValue Op = N->getOperand(0);
18591 if (!Op.getValueType().isSimple() || Op.getOpcode() != ISD::FMUL)
18592 return SDValue();
18593
18594 if (!Op.getValueType().is64BitVector() && !Op.getValueType().is128BitVector())
18595 return SDValue();
18596
18597 SDValue ConstVec = Op->getOperand(1);
18598 if (!isa<BuildVectorSDNode>(ConstVec))
18599 return SDValue();
18600
18601 MVT FloatTy = Op.getSimpleValueType().getVectorElementType();
18602 uint32_t FloatBits = FloatTy.getSizeInBits();
18603 if (FloatBits != 32 && FloatBits != 64 &&
18604 (FloatBits != 16 || !Subtarget->hasFullFP16()))
18605 return SDValue();
18606
18607 MVT IntTy = N->getSimpleValueType(0).getVectorElementType();
18608 uint32_t IntBits = IntTy.getSizeInBits();
18609 if (IntBits != 16 && IntBits != 32 && IntBits != 64)
18610 return SDValue();
18611
18612 // Avoid conversions where iN is larger than the float (e.g., float -> i64).
18613 if (IntBits > FloatBits)
18614 return SDValue();
18615
18616 BitVector UndefElements;
18617 BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
18618 int32_t Bits = IntBits == 64 ? 64 : 32;
18619 int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, Bits + 1);
18620 if (C == -1 || C == 0 || C > Bits)
18621 return SDValue();
18622
18623 EVT ResTy = Op.getValueType().changeVectorElementTypeToInteger();
18624 if (!DAG.getTargetLoweringInfo().isTypeLegal(ResTy))
18625 return SDValue();
18626
18627 if (N->getOpcode() == ISD::FP_TO_SINT_SAT ||
18628 N->getOpcode() == ISD::FP_TO_UINT_SAT) {
18629 EVT SatVT = cast<VTSDNode>(N->getOperand(1))->getVT();
18630 if (SatVT.getScalarSizeInBits() != IntBits || IntBits != FloatBits)
18631 return SDValue();
18632 }
18633
18634 SDLoc DL(N);
18635 bool IsSigned = (N->getOpcode() == ISD::FP_TO_SINT ||
18636 N->getOpcode() == ISD::FP_TO_SINT_SAT);
18637 unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfp2fxs
18638 : Intrinsic::aarch64_neon_vcvtfp2fxu;
18639 SDValue FixConv =
18640 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy,
18641 DAG.getConstant(IntrinsicOpcode, DL, MVT::i32),
18642 Op->getOperand(0), DAG.getConstant(C, DL, MVT::i32));
18643 // We can handle smaller integers by generating an extra trunc.
18644 if (IntBits < FloatBits)
18645 FixConv = DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), FixConv);
18646
18647 return FixConv;
18648 }
18649
18650 static SDValue tryCombineToBSL(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
18651 const AArch64TargetLowering &TLI) {
18652 EVT VT = N->getValueType(0);
18653 SelectionDAG &DAG = DCI.DAG;
18654 SDLoc DL(N);
18655 const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
18656
18657 if (!VT.isVector())
18658 return SDValue();
18659
18660 if (VT.isScalableVector() && !Subtarget.hasSVE2())
18661 return SDValue();
18662
18663 if (VT.isFixedLengthVector() &&
18664 (!Subtarget.isNeonAvailable() || TLI.useSVEForFixedLengthVectorVT(VT)))
18665 return SDValue();
18666
18667 SDValue N0 = N->getOperand(0);
18668 if (N0.getOpcode() != ISD::AND)
18669 return SDValue();
18670
18671 SDValue N1 = N->getOperand(1);
18672 if (N1.getOpcode() != ISD::AND)
18673 return SDValue();
18674
18675 // InstCombine does (not (neg a)) => (add a -1).
18676 // Try: (or (and (neg a) b) (and (add a -1) c)) => (bsl (neg a) b c)
18677 // Loop over all combinations of AND operands.
18678 for (int i = 1; i >= 0; --i) {
18679 for (int j = 1; j >= 0; --j) {
18680 SDValue O0 = N0->getOperand(i);
18681 SDValue O1 = N1->getOperand(j);
18682 SDValue Sub, Add, SubSibling, AddSibling;
18683
18684 // Find a SUB and an ADD operand, one from each AND.
18685 if (O0.getOpcode() == ISD::SUB && O1.getOpcode() == ISD::ADD) {
18686 Sub = O0;
18687 Add = O1;
18688 SubSibling = N0->getOperand(1 - i);
18689 AddSibling = N1->getOperand(1 - j);
18690 } else if (O0.getOpcode() == ISD::ADD && O1.getOpcode() == ISD::SUB) {
18691 Add = O0;
18692 Sub = O1;
18693 AddSibling = N0->getOperand(1 - i);
18694 SubSibling = N1->getOperand(1 - j);
18695 } else
18696 continue;
18697
18698 if (!ISD::isConstantSplatVectorAllZeros(Sub.getOperand(0).getNode()))
18699 continue;
18700
18701       // The all-ones constant is always the right-hand operand of the ADD.
18702 if (!ISD::isConstantSplatVectorAllOnes(Add.getOperand(1).getNode()))
18703 continue;
18704
18705 if (Sub.getOperand(1) != Add.getOperand(0))
18706 continue;
18707
18708 return DAG.getNode(AArch64ISD::BSP, DL, VT, Sub, SubSibling, AddSibling);
18709 }
18710 }
18711
18712 // (or (and a b) (and (not a) c)) => (bsl a b c)
18713 // We only have to look for constant vectors here since the general, variable
18714 // case can be handled in TableGen.
18715 unsigned Bits = VT.getScalarSizeInBits();
18716 uint64_t BitMask = Bits == 64 ? -1ULL : ((1ULL << Bits) - 1);
18717 for (int i = 1; i >= 0; --i)
18718 for (int j = 1; j >= 0; --j) {
18719 APInt Val1, Val2;
18720
18721 if (ISD::isConstantSplatVector(N0->getOperand(i).getNode(), Val1) &&
18722 ISD::isConstantSplatVector(N1->getOperand(j).getNode(), Val2) &&
18723 (BitMask & ~Val1.getZExtValue()) == Val2.getZExtValue()) {
18724 return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
18725 N0->getOperand(1 - i), N1->getOperand(1 - j));
18726 }
18727 BuildVectorSDNode *BVN0 = dyn_cast<BuildVectorSDNode>(N0->getOperand(i));
18728 BuildVectorSDNode *BVN1 = dyn_cast<BuildVectorSDNode>(N1->getOperand(j));
18729 if (!BVN0 || !BVN1)
18730 continue;
18731
18732 bool FoundMatch = true;
18733 for (unsigned k = 0; k < VT.getVectorNumElements(); ++k) {
18734 ConstantSDNode *CN0 = dyn_cast<ConstantSDNode>(BVN0->getOperand(k));
18735 ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(BVN1->getOperand(k));
18736 if (!CN0 || !CN1 ||
18737 CN0->getZExtValue() != (BitMask & ~CN1->getZExtValue())) {
18738 FoundMatch = false;
18739 break;
18740 }
18741 }
18742 if (FoundMatch)
18743 return DAG.getNode(AArch64ISD::BSP, DL, VT, N0->getOperand(i),
18744 N0->getOperand(1 - i), N1->getOperand(1 - j));
18745 }
18746
18747 return SDValue();
18748 }
18749
18750 // Given a tree of and/or(csel(0, 1, cc0), csel(0, 1, cc1)), we may be able to
18751 // convert to csel(ccmp(.., cc0)), depending on cc1:
18752
18753 // (AND (CSET cc0 cmp0) (CSET cc1 (CMP x1 y1)))
18754 // =>
18755 // (CSET cc1 (CCMP x1 y1 !cc1 cc0 cmp0))
18756 //
18757 // (OR (CSET cc0 cmp0) (CSET cc1 (CMP x1 y1)))
18758 // =>
18759 // (CSET cc1 (CCMP x1 y1 cc1 !cc0 cmp0))
18760 static SDValue performANDORCSELCombine(SDNode *N, SelectionDAG &DAG) {
18761 EVT VT = N->getValueType(0);
18762 SDValue CSel0 = N->getOperand(0);
18763 SDValue CSel1 = N->getOperand(1);
18764
18765 if (CSel0.getOpcode() != AArch64ISD::CSEL ||
18766 CSel1.getOpcode() != AArch64ISD::CSEL)
18767 return SDValue();
18768
18769 if (!CSel0->hasOneUse() || !CSel1->hasOneUse())
18770 return SDValue();
18771
18772 if (!isNullConstant(CSel0.getOperand(0)) ||
18773 !isOneConstant(CSel0.getOperand(1)) ||
18774 !isNullConstant(CSel1.getOperand(0)) ||
18775 !isOneConstant(CSel1.getOperand(1)))
18776 return SDValue();
18777
18778 SDValue Cmp0 = CSel0.getOperand(3);
18779 SDValue Cmp1 = CSel1.getOperand(3);
18780 AArch64CC::CondCode CC0 = (AArch64CC::CondCode)CSel0.getConstantOperandVal(2);
18781 AArch64CC::CondCode CC1 = (AArch64CC::CondCode)CSel1.getConstantOperandVal(2);
18782 if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse())
18783 return SDValue();
18784 if (Cmp1.getOpcode() != AArch64ISD::SUBS &&
18785 Cmp0.getOpcode() == AArch64ISD::SUBS) {
18786 std::swap(Cmp0, Cmp1);
18787 std::swap(CC0, CC1);
18788 }
18789
18790 if (Cmp1.getOpcode() != AArch64ISD::SUBS)
18791 return SDValue();
18792
18793 SDLoc DL(N);
18794 SDValue CCmp, Condition;
18795 unsigned NZCV;
18796
18797 if (N->getOpcode() == ISD::AND) {
18798 AArch64CC::CondCode InvCC0 = AArch64CC::getInvertedCondCode(CC0);
18799 Condition = DAG.getConstant(InvCC0, DL, MVT_CC);
18800 NZCV = AArch64CC::getNZCVToSatisfyCondCode(CC1);
18801 } else {
18802 AArch64CC::CondCode InvCC1 = AArch64CC::getInvertedCondCode(CC1);
18803 Condition = DAG.getConstant(CC0, DL, MVT_CC);
18804 NZCV = AArch64CC::getNZCVToSatisfyCondCode(InvCC1);
18805 }
18806
18807 SDValue NZCVOp = DAG.getConstant(NZCV, DL, MVT::i32);
18808
18809 auto *Op1 = dyn_cast<ConstantSDNode>(Cmp1.getOperand(1));
18810 if (Op1 && Op1->getAPIntValue().isNegative() &&
18811 Op1->getAPIntValue().sgt(-32)) {
18812     // CCMP accepts a constant in the range [0, 31],
18813     // so if Op1 is a constant in the range [-31, -1] we
18814     // can select CCMN instead to avoid the extra mov.
18815 SDValue AbsOp1 =
18816 DAG.getConstant(Op1->getAPIntValue().abs(), DL, Op1->getValueType(0));
18817 CCmp = DAG.getNode(AArch64ISD::CCMN, DL, MVT_CC, Cmp1.getOperand(0), AbsOp1,
18818 NZCVOp, Condition, Cmp0);
18819 } else {
18820 CCmp = DAG.getNode(AArch64ISD::CCMP, DL, MVT_CC, Cmp1.getOperand(0),
18821 Cmp1.getOperand(1), NZCVOp, Condition, Cmp0);
18822 }
18823 return DAG.getNode(AArch64ISD::CSEL, DL, VT, CSel0.getOperand(0),
18824 CSel0.getOperand(1), DAG.getConstant(CC1, DL, MVT::i32),
18825 CCmp);
18826 }
18827
18828 static SDValue performORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
18829 const AArch64Subtarget *Subtarget,
18830 const AArch64TargetLowering &TLI) {
18831 SelectionDAG &DAG = DCI.DAG;
18832 EVT VT = N->getValueType(0);
18833
18834 if (SDValue R = performANDORCSELCombine(N, DAG))
18835 return R;
18836
18837 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
18838 return SDValue();
18839
18840 if (SDValue Res = tryCombineToBSL(N, DCI, TLI))
18841 return Res;
18842
18843 return SDValue();
18844 }
18845
18846 static bool isConstantSplatVectorMaskForType(SDNode *N, EVT MemVT) {
18847 if (!MemVT.getVectorElementType().isSimple())
18848 return false;
18849
18850 uint64_t MaskForTy = 0ull;
18851 switch (MemVT.getVectorElementType().getSimpleVT().SimpleTy) {
18852 case MVT::i8:
18853 MaskForTy = 0xffull;
18854 break;
18855 case MVT::i16:
18856 MaskForTy = 0xffffull;
18857 break;
18858 case MVT::i32:
18859 MaskForTy = 0xffffffffull;
18860 break;
18861 default:
18862 return false;
18863 break;
18864 }
18865
18866 if (N->getOpcode() == AArch64ISD::DUP || N->getOpcode() == ISD::SPLAT_VECTOR)
18867 if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0)))
18868 return Op0->getAPIntValue().getLimitedValue() == MaskForTy;
18869
18870 return false;
18871 }
18872
18873 static SDValue performReinterpretCastCombine(SDNode *N) {
18874 SDValue LeafOp = SDValue(N, 0);
18875 SDValue Op = N->getOperand(0);
18876 while (Op.getOpcode() == AArch64ISD::REINTERPRET_CAST &&
18877 LeafOp.getValueType() != Op.getValueType())
18878 Op = Op->getOperand(0);
18879 if (LeafOp.getValueType() == Op.getValueType())
18880 return Op;
18881 return SDValue();
18882 }
18883
18884 static SDValue performSVEAndCombine(SDNode *N,
18885 TargetLowering::DAGCombinerInfo &DCI) {
18886 SelectionDAG &DAG = DCI.DAG;
18887 SDValue Src = N->getOperand(0);
18888 unsigned Opc = Src->getOpcode();
18889
18890 // Zero/any extend of an unsigned unpack
18891 if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
18892 SDValue UnpkOp = Src->getOperand(0);
18893 SDValue Dup = N->getOperand(1);
18894
18895 if (Dup.getOpcode() != ISD::SPLAT_VECTOR)
18896 return SDValue();
18897
18898 SDLoc DL(N);
18899 ConstantSDNode *C = dyn_cast<ConstantSDNode>(Dup->getOperand(0));
18900 if (!C)
18901 return SDValue();
18902
18903 uint64_t ExtVal = C->getZExtValue();
18904
18905 auto MaskAndTypeMatch = [ExtVal](EVT VT) -> bool {
18906 return ((ExtVal == 0xFF && VT == MVT::i8) ||
18907 (ExtVal == 0xFFFF && VT == MVT::i16) ||
18908 (ExtVal == 0xFFFFFFFF && VT == MVT::i32));
18909 };
18910
18911 // If the mask is fully covered by the unpack, we don't need to push
18912 // a new AND onto the operand
18913 EVT EltTy = UnpkOp->getValueType(0).getVectorElementType();
18914 if (MaskAndTypeMatch(EltTy))
18915 return Src;
18916
18917 // If this is 'and (uunpklo/hi (extload MemTy -> ExtTy)), mask', then check
18918 // to see if the mask is all-ones of size MemTy.
18919 auto MaskedLoadOp = dyn_cast<MaskedLoadSDNode>(UnpkOp);
18920 if (MaskedLoadOp && (MaskedLoadOp->getExtensionType() == ISD::ZEXTLOAD ||
18921 MaskedLoadOp->getExtensionType() == ISD::EXTLOAD)) {
18922 EVT EltTy = MaskedLoadOp->getMemoryVT().getVectorElementType();
18923 if (MaskAndTypeMatch(EltTy))
18924 return Src;
18925 }
18926
18927     // Truncate to prevent a DUP with an over-wide constant
18928 APInt Mask = C->getAPIntValue().trunc(EltTy.getSizeInBits());
18929
18930 // Otherwise, make sure we propagate the AND to the operand
18931 // of the unpack
18932 Dup = DAG.getNode(ISD::SPLAT_VECTOR, DL, UnpkOp->getValueType(0),
18933 DAG.getConstant(Mask.zextOrTrunc(32), DL, MVT::i32));
18934
18935 SDValue And = DAG.getNode(ISD::AND, DL,
18936 UnpkOp->getValueType(0), UnpkOp, Dup);
18937
18938 return DAG.getNode(Opc, DL, N->getValueType(0), And);
18939 }
18940
18941 if (DCI.isBeforeLegalizeOps())
18942 return SDValue();
18943
18944 // If both sides of AND operations are i1 splat_vectors then
18945 // we can produce just i1 splat_vector as the result.
18946 if (isAllActivePredicate(DAG, N->getOperand(0)))
18947 return N->getOperand(1);
18948 if (isAllActivePredicate(DAG, N->getOperand(1)))
18949 return N->getOperand(0);
18950
18951 if (!EnableCombineMGatherIntrinsics)
18952 return SDValue();
18953
18954 SDValue Mask = N->getOperand(1);
18955
18956 if (!Src.hasOneUse())
18957 return SDValue();
18958
18959 EVT MemVT;
18960
18961 // SVE load instructions perform an implicit zero-extend, which makes them
18962 // perfect candidates for combining.
18963 switch (Opc) {
18964 case AArch64ISD::LD1_MERGE_ZERO:
18965 case AArch64ISD::LDNF1_MERGE_ZERO:
18966 case AArch64ISD::LDFF1_MERGE_ZERO:
18967 MemVT = cast<VTSDNode>(Src->getOperand(3))->getVT();
18968 break;
18969 case AArch64ISD::GLD1_MERGE_ZERO:
18970 case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
18971 case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
18972 case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
18973 case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
18974 case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
18975 case AArch64ISD::GLD1_IMM_MERGE_ZERO:
18976 case AArch64ISD::GLDFF1_MERGE_ZERO:
18977 case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
18978 case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
18979 case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
18980 case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
18981 case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
18982 case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
18983 case AArch64ISD::GLDNT1_MERGE_ZERO:
18984 MemVT = cast<VTSDNode>(Src->getOperand(4))->getVT();
18985 break;
18986 default:
18987 return SDValue();
18988 }
18989
18990 if (isConstantSplatVectorMaskForType(Mask.getNode(), MemVT))
18991 return Src;
18992
18993 return SDValue();
18994 }
18995
18996 // Transform and(fcmp(a, b), fcmp(c, d)) into fccmp(fcmp(a, b), c, d)
18997 static SDValue performANDSETCCCombine(SDNode *N,
18998 TargetLowering::DAGCombinerInfo &DCI) {
18999
19000 // This function performs an optimization on a specific pattern involving
19001 // an AND operation and SETCC (Set Condition Code) node.
19002
19003 SDValue SetCC = N->getOperand(0);
19004 EVT VT = N->getValueType(0);
19005 SelectionDAG &DAG = DCI.DAG;
19006
19007   // If the current node (N) is used by any SELECT instruction, return an empty
19008   // SDValue and skip the optimization, as applying it in that case would
19009   // produce incorrect results.
19010 for (auto U : N->uses())
19011 if (U->getOpcode() == ISD::SELECT)
19012 return SDValue();
19013
19014 // Check if the operand is a SETCC node with floating-point comparison
19015 if (SetCC.getOpcode() == ISD::SETCC &&
19016 SetCC.getOperand(0).getValueType() == MVT::f32) {
19017
19018 SDValue Cmp;
19019 AArch64CC::CondCode CC;
19020
19021 // Check if the DAG is after legalization and if we can emit the conjunction
19022 if (!DCI.isBeforeLegalize() &&
19023 (Cmp = emitConjunction(DAG, SDValue(N, 0), CC))) {
19024
19025 AArch64CC::CondCode InvertedCC = AArch64CC::getInvertedCondCode(CC);
19026
19027 SDLoc DL(N);
19028 return DAG.getNode(AArch64ISD::CSINC, DL, VT, DAG.getConstant(0, DL, VT),
19029 DAG.getConstant(0, DL, VT),
19030 DAG.getConstant(InvertedCC, DL, MVT::i32), Cmp);
19031 }
19032 }
19033 return SDValue();
19034 }
19035
19036 static SDValue performANDCombine(SDNode *N,
19037 TargetLowering::DAGCombinerInfo &DCI) {
19038 SelectionDAG &DAG = DCI.DAG;
19039 SDValue LHS = N->getOperand(0);
19040 SDValue RHS = N->getOperand(1);
19041 EVT VT = N->getValueType(0);
19042
19043 if (SDValue R = performANDORCSELCombine(N, DAG))
19044 return R;
19045
19046   if (SDValue R = performANDSETCCCombine(N, DCI))
19047 return R;
19048
19049 if (!DAG.getTargetLoweringInfo().isTypeLegal(VT))
19050 return SDValue();
19051
19052 if (VT.isScalableVector())
19053 return performSVEAndCombine(N, DCI);
19054
19055 // The combining code below works only for NEON vectors. In particular, it
19056 // does not work for SVE when dealing with vectors wider than 128 bits.
19057 if (!VT.is64BitVector() && !VT.is128BitVector())
19058 return SDValue();
19059
19060 BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(RHS.getNode());
19061 if (!BVN)
19062 return SDValue();
19063
19064 // AND does not accept an immediate, so check if we can use a BIC immediate
19065 // instruction instead. We do this here instead of using a (and x, (mvni imm))
19066 // pattern in isel, because some immediates may be lowered to the preferred
19067 // (and x, (movi imm)) form, even though an mvni representation also exists.
19068 APInt DefBits(VT.getSizeInBits(), 0);
19069 APInt UndefBits(VT.getSizeInBits(), 0);
19070 if (resolveBuildVector(BVN, DefBits, UndefBits)) {
19071 SDValue NewOp;
19072
19073 // Any bits known to already be 0 need not be cleared again, which can help
19074 // reduce the size of the immediate to one supported by the instruction.
19075 KnownBits Known = DAG.computeKnownBits(LHS);
19076 APInt ZeroSplat(VT.getSizeInBits(), 0);
19077 for (unsigned I = 0; I < VT.getSizeInBits() / Known.Zero.getBitWidth(); I++)
19078 ZeroSplat |= Known.Zero.zext(VT.getSizeInBits())
19079 << (Known.Zero.getBitWidth() * I);
19080
19081 DefBits = ~(DefBits | ZeroSplat);
19082 if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
19083 DefBits, &LHS)) ||
19084 (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
19085 DefBits, &LHS)))
19086 return NewOp;
19087
19088 UndefBits = ~(UndefBits | ZeroSplat);
19089 if ((NewOp = tryAdvSIMDModImm32(AArch64ISD::BICi, SDValue(N, 0), DAG,
19090 UndefBits, &LHS)) ||
19091 (NewOp = tryAdvSIMDModImm16(AArch64ISD::BICi, SDValue(N, 0), DAG,
19092 UndefBits, &LHS)))
19093 return NewOp;
19094 }
19095
19096 return SDValue();
19097 }
19098
19099 static SDValue performFADDCombine(SDNode *N,
19100 TargetLowering::DAGCombinerInfo &DCI) {
19101 SelectionDAG &DAG = DCI.DAG;
19102 SDValue LHS = N->getOperand(0);
19103 SDValue RHS = N->getOperand(1);
19104 EVT VT = N->getValueType(0);
19105 SDLoc DL(N);
19106
19107 if (!N->getFlags().hasAllowReassociation())
19108 return SDValue();
19109
19110   // Combine fadd(a, vcmla(b, c, d)) -> vcmla(fadd(a, b), c, d)
19111 auto ReassocComplex = [&](SDValue A, SDValue B) {
19112 if (A.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
19113 return SDValue();
19114 unsigned Opc = A.getConstantOperandVal(0);
19115 if (Opc != Intrinsic::aarch64_neon_vcmla_rot0 &&
19116 Opc != Intrinsic::aarch64_neon_vcmla_rot90 &&
19117 Opc != Intrinsic::aarch64_neon_vcmla_rot180 &&
19118 Opc != Intrinsic::aarch64_neon_vcmla_rot270)
19119 return SDValue();
19120 SDValue VCMLA = DAG.getNode(
19121 ISD::INTRINSIC_WO_CHAIN, DL, VT, A.getOperand(0),
19122 DAG.getNode(ISD::FADD, DL, VT, A.getOperand(1), B, N->getFlags()),
19123 A.getOperand(2), A.getOperand(3));
19124 VCMLA->setFlags(A->getFlags());
19125 return VCMLA;
19126 };
19127 if (SDValue R = ReassocComplex(LHS, RHS))
19128 return R;
19129 if (SDValue R = ReassocComplex(RHS, LHS))
19130 return R;
19131
19132 return SDValue();
19133 }
19134
19135 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
19136 switch (Opcode) {
19137 case ISD::STRICT_FADD:
19138 case ISD::FADD:
19139 return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
19140 case ISD::ADD:
19141 return VT == MVT::i64;
19142 default:
19143 return false;
19144 }
19145 }
19146
19147 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
19148 AArch64CC::CondCode Cond);
19149
19150 static bool isPredicateCCSettingOp(SDValue N) {
19151 if ((N.getOpcode() == ISD::SETCC) ||
19152 (N.getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
19153 (N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilege ||
19154 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilegt ||
19155 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehi ||
19156 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilehs ||
19157 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilele ||
19158 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelo ||
19159 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilels ||
19160 N.getConstantOperandVal(0) == Intrinsic::aarch64_sve_whilelt ||
19161 // get_active_lane_mask is lowered to a whilelo instruction.
19162 N.getConstantOperandVal(0) == Intrinsic::get_active_lane_mask)))
19163 return true;
19164
19165 return false;
19166 }
19167
19168 // Materialize : i1 = extract_vector_elt t37, Constant:i64<0>
19169 // ... into: "ptrue p, all" + PTEST
19170 static SDValue
19171 performFirstTrueTestVectorCombine(SDNode *N,
19172 TargetLowering::DAGCombinerInfo &DCI,
19173 const AArch64Subtarget *Subtarget) {
19174 assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19175 // Make sure PTEST can be legalised with illegal types.
19176 if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
19177 return SDValue();
19178
19179 SDValue N0 = N->getOperand(0);
19180 EVT VT = N0.getValueType();
19181
19182 if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1 ||
19183 !isNullConstant(N->getOperand(1)))
19184 return SDValue();
19185
19186   // Restrict the DAG combine to only cases where we're extracting from a
19187 // flag-setting operation.
19188 if (!isPredicateCCSettingOp(N0))
19189 return SDValue();
19190
19191 // Extracts of lane 0 for SVE can be expressed as PTEST(Op, FIRST) ? 1 : 0
19192 SelectionDAG &DAG = DCI.DAG;
19193 SDValue Pg = getPTrue(DAG, SDLoc(N), VT, AArch64SVEPredPattern::all);
19194 return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::FIRST_ACTIVE);
19195 }
19196
19197 // Materialize : Idx = (add (mul vscale, NumEls), -1)
19198 // i1 = extract_vector_elt t37, Constant:i64<Idx>
19199 // ... into: "ptrue p, all" + PTEST
19200 static SDValue
19201 performLastTrueTestVectorCombine(SDNode *N,
19202 TargetLowering::DAGCombinerInfo &DCI,
19203 const AArch64Subtarget *Subtarget) {
19204 assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19205   // Make sure PTEST can be legalised with illegal types.
19206 if (!Subtarget->hasSVE() || DCI.isBeforeLegalize())
19207 return SDValue();
19208
19209 SDValue N0 = N->getOperand(0);
19210 EVT OpVT = N0.getValueType();
19211
19212 if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
19213 return SDValue();
19214
19215 // Idx == (add (mul vscale, NumEls), -1)
19216 SDValue Idx = N->getOperand(1);
19217 if (Idx.getOpcode() != ISD::ADD || !isAllOnesConstant(Idx.getOperand(1)))
19218 return SDValue();
19219
19220 SDValue VS = Idx.getOperand(0);
19221 if (VS.getOpcode() != ISD::VSCALE)
19222 return SDValue();
19223
19224 unsigned NumEls = OpVT.getVectorElementCount().getKnownMinValue();
19225 if (VS.getConstantOperandVal(0) != NumEls)
19226 return SDValue();
19227
19228 // Extracts of lane EC-1 for SVE can be expressed as PTEST(Op, LAST) ? 1 : 0
19229 SelectionDAG &DAG = DCI.DAG;
19230 SDValue Pg = getPTrue(DAG, SDLoc(N), OpVT, AArch64SVEPredPattern::all);
19231 return getPTest(DAG, N->getValueType(0), Pg, N0, AArch64CC::LAST_ACTIVE);
19232 }
19233
19234 static SDValue
19235 performExtractVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19236 const AArch64Subtarget *Subtarget) {
19237 assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT);
19238 if (SDValue Res = performFirstTrueTestVectorCombine(N, DCI, Subtarget))
19239 return Res;
19240 if (SDValue Res = performLastTrueTestVectorCombine(N, DCI, Subtarget))
19241 return Res;
19242
19243 SelectionDAG &DAG = DCI.DAG;
19244 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
19245
19246 EVT VT = N->getValueType(0);
19247 const bool FullFP16 = DAG.getSubtarget<AArch64Subtarget>().hasFullFP16();
19248 bool IsStrict = N0->isStrictFPOpcode();
19249
19250 // extract(dup x) -> x
19251 if (N0.getOpcode() == AArch64ISD::DUP)
19252 return VT.isInteger() ? DAG.getZExtOrTrunc(N0.getOperand(0), SDLoc(N), VT)
19253 : N0.getOperand(0);
19254
19255 // Rewrite for pairwise fadd pattern
19256 // (f32 (extract_vector_elt
19257 // (fadd (vXf32 Other)
19258 // (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
19259 // ->
19260 // (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
19261 // (extract_vector_elt (vXf32 Other) 1))
19262 // For strict_fadd we need to make sure the old strict_fadd can be deleted, so
19263 // we can only do this when it's used only by the extract_vector_elt.
19264 if (isNullConstant(N1) && hasPairwiseAdd(N0->getOpcode(), VT, FullFP16) &&
19265 (!IsStrict || N0.hasOneUse())) {
19266 SDLoc DL(N0);
19267 SDValue N00 = N0->getOperand(IsStrict ? 1 : 0);
19268 SDValue N01 = N0->getOperand(IsStrict ? 2 : 1);
19269
19270 ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
19271 SDValue Other = N00;
19272
19273 // And handle the commutative case.
19274 if (!Shuffle) {
19275 Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
19276 Other = N01;
19277 }
19278
19279 if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
19280 Other == Shuffle->getOperand(0)) {
19281 SDValue Extract1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
19282 DAG.getConstant(0, DL, MVT::i64));
19283 SDValue Extract2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
19284 DAG.getConstant(1, DL, MVT::i64));
19285 if (!IsStrict)
19286 return DAG.getNode(N0->getOpcode(), DL, VT, Extract1, Extract2);
19287
19288 // For strict_fadd we need uses of the final extract_vector to be replaced
19289 // with the strict_fadd, but we also need uses of the chain output of the
19290 // original strict_fadd to use the chain output of the new strict_fadd as
19291 // otherwise it may not be deleted.
19292 SDValue Ret = DAG.getNode(N0->getOpcode(), DL,
19293 {VT, MVT::Other},
19294 {N0->getOperand(0), Extract1, Extract2});
19295 DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Ret);
19296 DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Ret.getValue(1));
19297 return SDValue(N, 0);
19298 }
19299 }
19300
19301 return SDValue();
19302 }
19303
19304 static SDValue performConcatVectorsCombine(SDNode *N,
19305 TargetLowering::DAGCombinerInfo &DCI,
19306 SelectionDAG &DAG) {
19307 SDLoc dl(N);
19308 EVT VT = N->getValueType(0);
19309 SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
19310 unsigned N0Opc = N0->getOpcode(), N1Opc = N1->getOpcode();
19311
19312 if (VT.isScalableVector())
19313 return SDValue();
19314
19315 // Optimize concat_vectors of truncated vectors, where the intermediate
19316 // type is illegal, to avoid said illegality, e.g.,
19317 // (v4i16 (concat_vectors (v2i16 (truncate (v2i64))),
19318 // (v2i16 (truncate (v2i64)))))
19319 // ->
19320 // (v4i16 (truncate (vector_shuffle (v4i32 (bitcast (v2i64))),
19321 // (v4i32 (bitcast (v2i64))),
19322 // <0, 2, 4, 6>)))
19323 // This isn't really target-specific, but ISD::TRUNCATE legality isn't keyed
19324 // on both input and result type, so we might generate worse code.
19325 // On AArch64 we know it's fine for v2i64->v4i16 and v4i32->v8i8.
19326 if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
19327 N1Opc == ISD::TRUNCATE) {
19328 SDValue N00 = N0->getOperand(0);
19329 SDValue N10 = N1->getOperand(0);
19330 EVT N00VT = N00.getValueType();
19331
19332 if (N00VT == N10.getValueType() &&
19333 (N00VT == MVT::v2i64 || N00VT == MVT::v4i32) &&
19334 N00VT.getScalarSizeInBits() == 4 * VT.getScalarSizeInBits()) {
19335 MVT MidVT = (N00VT == MVT::v2i64 ? MVT::v4i32 : MVT::v8i16);
19336 SmallVector<int, 8> Mask(MidVT.getVectorNumElements());
19337 for (size_t i = 0; i < Mask.size(); ++i)
19338 Mask[i] = i * 2;
19339 return DAG.getNode(ISD::TRUNCATE, dl, VT,
19340 DAG.getVectorShuffle(
19341 MidVT, dl,
19342 DAG.getNode(ISD::BITCAST, dl, MidVT, N00),
19343 DAG.getNode(ISD::BITCAST, dl, MidVT, N10), Mask));
19344 }
19345 }
19346
19347 if (N->getOperand(0).getValueType() == MVT::v4i8 ||
19348 N->getOperand(0).getValueType() == MVT::v2i16 ||
19349 N->getOperand(0).getValueType() == MVT::v2i8) {
19350 EVT SrcVT = N->getOperand(0).getValueType();
19351 // If we have a concat of v4i8 loads, convert them to a buildvector of f32
19352 // loads to prevent having to go through the v4i8 load legalization that
19353 // needs to extend each element into a larger type.
19354 if (N->getNumOperands() % 2 == 0 &&
19355 all_of(N->op_values(), [SrcVT](SDValue V) {
19356 if (V.getValueType() != SrcVT)
19357 return false;
19358 if (V.isUndef())
19359 return true;
19360 LoadSDNode *LD = dyn_cast<LoadSDNode>(V);
19361 return LD && V.hasOneUse() && LD->isSimple() && !LD->isIndexed() &&
19362 LD->getExtensionType() == ISD::NON_EXTLOAD;
19363 })) {
19364 EVT FVT = SrcVT == MVT::v2i8 ? MVT::f16 : MVT::f32;
19365 EVT NVT = EVT::getVectorVT(*DAG.getContext(), FVT, N->getNumOperands());
19366 SmallVector<SDValue> Ops;
19367
19368 for (unsigned i = 0; i < N->getNumOperands(); i++) {
19369 SDValue V = N->getOperand(i);
19370 if (V.isUndef())
19371 Ops.push_back(DAG.getUNDEF(FVT));
19372 else {
19373 LoadSDNode *LD = cast<LoadSDNode>(V);
19374 SDValue NewLoad = DAG.getLoad(FVT, dl, LD->getChain(),
19375 LD->getBasePtr(), LD->getMemOperand());
19376 DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewLoad.getValue(1));
19377 Ops.push_back(NewLoad);
19378 }
19379 }
19380 return DAG.getBitcast(N->getValueType(0),
19381 DAG.getBuildVector(NVT, dl, Ops));
19382 }
19383 }
19384
19385 // Canonicalise concat_vectors to replace concatenations of truncated nots
19386 // with nots of concatenated truncates. This in some cases allows for multiple
19387 // redundant negations to be eliminated.
19388 // (concat_vectors (v4i16 (truncate (not (v4i32)))),
19389 // (v4i16 (truncate (not (v4i32)))))
19390 // ->
19391 // (not (concat_vectors (v4i16 (truncate (v4i32))),
19392 // (v4i16 (truncate (v4i32)))))
19393 if (N->getNumOperands() == 2 && N0Opc == ISD::TRUNCATE &&
19394 N1Opc == ISD::TRUNCATE && N->isOnlyUserOf(N0.getNode()) &&
19395 N->isOnlyUserOf(N1.getNode())) {
19396 auto isBitwiseVectorNegate = [](SDValue V) {
19397 return V->getOpcode() == ISD::XOR &&
19398 ISD::isConstantSplatVectorAllOnes(V.getOperand(1).getNode());
19399 };
19400 SDValue N00 = N0->getOperand(0);
19401 SDValue N10 = N1->getOperand(0);
19402 if (isBitwiseVectorNegate(N00) && N0->isOnlyUserOf(N00.getNode()) &&
19403 isBitwiseVectorNegate(N10) && N1->isOnlyUserOf(N10.getNode())) {
19404 return DAG.getNOT(
19405 dl,
19406 DAG.getNode(ISD::CONCAT_VECTORS, dl, VT,
19407 DAG.getNode(ISD::TRUNCATE, dl, N0.getValueType(),
19408 N00->getOperand(0)),
19409 DAG.getNode(ISD::TRUNCATE, dl, N1.getValueType(),
19410 N10->getOperand(0))),
19411 VT);
19412 }
19413 }
19414
19415 // Wait till after everything is legalized to try this. That way we have
19416 // legal vector types and such.
19417 if (DCI.isBeforeLegalizeOps())
19418 return SDValue();
19419
19420 // Optimise concat_vectors of two identical binops with a 128-bit destination
19421 // size, combining into a binop of two concats of the source vectors. eg:
19422 // concat(uhadd(a,b), uhadd(c, d)) -> uhadd(concat(a, c), concat(b, d))
19423 if (N->getNumOperands() == 2 && N0Opc == N1Opc && VT.is128BitVector() &&
19424 DAG.getTargetLoweringInfo().isBinOp(N0Opc) && N0->hasOneUse() &&
19425 N1->hasOneUse()) {
19426 SDValue N00 = N0->getOperand(0);
19427 SDValue N01 = N0->getOperand(1);
19428 SDValue N10 = N1->getOperand(0);
19429 SDValue N11 = N1->getOperand(1);
19430
19431 if (!N00.isUndef() && !N01.isUndef() && !N10.isUndef() && !N11.isUndef()) {
19432 SDValue Concat0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N00, N10);
19433 SDValue Concat1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N01, N11);
19434 return DAG.getNode(N0Opc, dl, VT, Concat0, Concat1);
19435 }
19436 }
19437
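  // Matches the expansion of a rounding shift right narrow: a VLSHR of an ADD
  // whose second operand is the rounding constant 1 << (shift - 1), where that
  // constant may be materialized either via MOVIshift or via a constant DUP.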
19438 auto IsRSHRN = [](SDValue Shr) {
19439 if (Shr.getOpcode() != AArch64ISD::VLSHR)
19440 return false;
19441 SDValue Op = Shr.getOperand(0);
19442 EVT VT = Op.getValueType();
19443 unsigned ShtAmt = Shr.getConstantOperandVal(1);
19444 if (ShtAmt > VT.getScalarSizeInBits() / 2 || Op.getOpcode() != ISD::ADD)
19445 return false;
19446
19447 APInt Imm;
19448 if (Op.getOperand(1).getOpcode() == AArch64ISD::MOVIshift)
19449 Imm = APInt(VT.getScalarSizeInBits(),
19450 Op.getOperand(1).getConstantOperandVal(0)
19451 << Op.getOperand(1).getConstantOperandVal(1));
19452 else if (Op.getOperand(1).getOpcode() == AArch64ISD::DUP &&
19453 isa<ConstantSDNode>(Op.getOperand(1).getOperand(0)))
19454 Imm = APInt(VT.getScalarSizeInBits(),
19455 Op.getOperand(1).getConstantOperandVal(0));
19456 else
19457 return false;
19458
19459 if (Imm != 1ULL << (ShtAmt - 1))
19460 return false;
19461 return true;
19462 };
19463
19464 // concat(rshrn(x), rshrn(y)) -> rshrn(concat(x, y))
19465 if (N->getNumOperands() == 2 && IsRSHRN(N0) &&
19466 ((IsRSHRN(N1) &&
19467 N0.getConstantOperandVal(1) == N1.getConstantOperandVal(1)) ||
19468 N1.isUndef())) {
19469 SDValue X = N0.getOperand(0).getOperand(0);
19470 SDValue Y = N1.isUndef() ? DAG.getUNDEF(X.getValueType())
19471 : N1.getOperand(0).getOperand(0);
19472 EVT BVT =
19473 X.getValueType().getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
19474 SDValue CC = DAG.getNode(ISD::CONCAT_VECTORS, dl, BVT, X, Y);
19475 SDValue Add = DAG.getNode(
19476 ISD::ADD, dl, BVT, CC,
19477 DAG.getConstant(1ULL << (N0.getConstantOperandVal(1) - 1), dl, BVT));
19478 SDValue Shr =
19479 DAG.getNode(AArch64ISD::VLSHR, dl, BVT, Add, N0.getOperand(1));
19480 return Shr;
19481 }
19482
19483 // concat(zip1(a, b), zip2(a, b)) is zip1(a, b)
19484 if (N->getNumOperands() == 2 && N0Opc == AArch64ISD::ZIP1 &&
19485 N1Opc == AArch64ISD::ZIP2 && N0.getOperand(0) == N1.getOperand(0) &&
19486 N0.getOperand(1) == N1.getOperand(1)) {
19487 SDValue E0 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N0.getOperand(0),
19488 DAG.getUNDEF(N0.getValueType()));
19489 SDValue E1 = DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, N0.getOperand(1),
19490 DAG.getUNDEF(N0.getValueType()));
19491 return DAG.getNode(AArch64ISD::ZIP1, dl, VT, E0, E1);
19492 }
19493
19494 // If we see a (concat_vectors (v1x64 A), (v1x64 A)) it's really a vector
19495 // splat. The indexed instructions are going to be expecting a DUPLANE64, so
19496 // canonicalise to that.
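  // For example, (v2i64 (concat_vectors (v1i64 X), (v1i64 X))) becomes
  // (DUPLANE64 (v2i64 (widened X)), 0).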
19497 if (N->getNumOperands() == 2 && N0 == N1 && VT.getVectorNumElements() == 2) {
19498 assert(VT.getScalarSizeInBits() == 64);
19499 return DAG.getNode(AArch64ISD::DUPLANE64, dl, VT, WidenVector(N0, DAG),
19500 DAG.getConstant(0, dl, MVT::i64));
19501 }
19502
19503 // Canonicalise concat_vectors so that the right-hand vector has as few
19504 // bit-casts as possible before its real operation. The primary matching
19505 // destination for these operations will be the narrowing "2" instructions,
19506 // which depend on the operation being performed on this right-hand vector.
19507 // For example,
19508 // (concat_vectors LHS, (v1i64 (bitconvert (v4i16 RHS))))
19509 // becomes
19510 // (bitconvert (concat_vectors (v4i16 (bitconvert LHS)), RHS))
19511
19512 if (N->getNumOperands() != 2 || N1Opc != ISD::BITCAST)
19513 return SDValue();
19514 SDValue RHS = N1->getOperand(0);
19515 MVT RHSTy = RHS.getValueType().getSimpleVT();
19516 // If the RHS is not a vector, this is not the pattern we're looking for.
19517 if (!RHSTy.isVector())
19518 return SDValue();
19519
19520 LLVM_DEBUG(
19521 dbgs() << "aarch64-lower: concat_vectors bitcast simplification\n");
19522
19523 MVT ConcatTy = MVT::getVectorVT(RHSTy.getVectorElementType(),
19524 RHSTy.getVectorNumElements() * 2);
19525 return DAG.getNode(ISD::BITCAST, dl, VT,
19526 DAG.getNode(ISD::CONCAT_VECTORS, dl, ConcatTy,
19527 DAG.getNode(ISD::BITCAST, dl, RHSTy, N0),
19528 RHS));
19529 }
19530
19531 static SDValue
19532 performExtractSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19533 SelectionDAG &DAG) {
19534 if (DCI.isBeforeLegalizeOps())
19535 return SDValue();
19536
19537 EVT VT = N->getValueType(0);
19538 if (!VT.isScalableVector() || VT.getVectorElementType() != MVT::i1)
19539 return SDValue();
19540
19541 SDValue V = N->getOperand(0);
19542
19543 // NOTE: This combine exists in DAGCombiner, but that version's legality check
19544 // blocks this combine because the non-const case requires custom lowering.
19545 //
19546 // ty1 extract_vector(ty2 splat(const))) -> ty1 splat(const)
19547 if (V.getOpcode() == ISD::SPLAT_VECTOR)
19548 if (isa<ConstantSDNode>(V.getOperand(0)))
19549 return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), VT, V.getOperand(0));
19550
19551 return SDValue();
19552 }
19553
19554 static SDValue
19555 performInsertSubvectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
19556 SelectionDAG &DAG) {
19557 SDLoc DL(N);
19558 SDValue Vec = N->getOperand(0);
19559 SDValue SubVec = N->getOperand(1);
19560 uint64_t IdxVal = N->getConstantOperandVal(2);
19561 EVT VecVT = Vec.getValueType();
19562 EVT SubVT = SubVec.getValueType();
19563
19564 // Only do this for legal fixed vector types.
19565 if (!VecVT.isFixedLengthVector() ||
19566 !DAG.getTargetLoweringInfo().isTypeLegal(VecVT) ||
19567 !DAG.getTargetLoweringInfo().isTypeLegal(SubVT))
19568 return SDValue();
19569
19570 // Ignore widening patterns.
19571 if (IdxVal == 0 && Vec.isUndef())
19572 return SDValue();
19573
19574 // Subvector must be half the width and an "aligned" insertion.
19575 unsigned NumSubElts = SubVT.getVectorNumElements();
19576 if ((SubVT.getSizeInBits() * 2) != VecVT.getSizeInBits() ||
19577 (IdxVal != 0 && IdxVal != NumSubElts))
19578 return SDValue();
19579
19580 // Fold insert_subvector -> concat_vectors
19581 // insert_subvector(Vec,Sub,lo) -> concat_vectors(Sub,extract(Vec,hi))
19582 // insert_subvector(Vec,Sub,hi) -> concat_vectors(extract(Vec,lo),Sub)
19583 SDValue Lo, Hi;
19584 if (IdxVal == 0) {
19585 Lo = SubVec;
19586 Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
19587 DAG.getVectorIdxConstant(NumSubElts, DL));
19588 } else {
19589 Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVT, Vec,
19590 DAG.getVectorIdxConstant(0, DL));
19591 Hi = SubVec;
19592 }
19593 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VecVT, Lo, Hi);
19594 }
19595
19596 static SDValue tryCombineFixedPointConvert(SDNode *N,
19597 TargetLowering::DAGCombinerInfo &DCI,
19598 SelectionDAG &DAG) {
19599 // Wait until after everything is legalized to try this. That way we have
19600 // legal vector types and such.
19601 if (DCI.isBeforeLegalizeOps())
19602 return SDValue();
19603 // Transform a scalar conversion of a value from a lane extract into a
19604 // lane extract of a vector conversion. E.g., from foo1 to foo2:
19605 // double foo1(int64x2_t a) { return vcvtd_n_f64_s64(a[1], 9); }
19606 // double foo2(int64x2_t a) { return vcvtq_n_f64_s64(a, 9)[1]; }
19607 //
19608 // The second form interacts better with instruction selection and the
19609 // register allocator to avoid cross-class register copies that aren't
19610 // coalescable due to a lane reference.
19611
19612 // Check the operand and see if it originates from a lane extract.
19613 SDValue Op1 = N->getOperand(1);
19614 if (Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
19615 return SDValue();
19616
19617 // Yep, no additional predication needed. Perform the transform.
19618 SDValue IID = N->getOperand(0);
19619 SDValue Shift = N->getOperand(2);
19620 SDValue Vec = Op1.getOperand(0);
19621 SDValue Lane = Op1.getOperand(1);
19622 EVT ResTy = N->getValueType(0);
19623 EVT VecResTy;
19624 SDLoc DL(N);
19625
19626 // The vector width should be 128 bits by the time we get here, even
19627 // if it started as 64 bits (the extract_vector handling will have
19628 // done so). Bail if it is not.
19629 if (Vec.getValueSizeInBits() != 128)
19630 return SDValue();
19631
19632 if (Vec.getValueType() == MVT::v4i32)
19633 VecResTy = MVT::v4f32;
19634 else if (Vec.getValueType() == MVT::v2i64)
19635 VecResTy = MVT::v2f64;
19636 else
19637 return SDValue();
19638
19639 SDValue Convert =
19640 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, VecResTy, IID, Vec, Shift);
19641 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResTy, Convert, Lane);
19642 }
19643
19644 // AArch64 high-vector "long" operations are formed by performing the non-high
19645 // version on an extract_subvector of each operand which gets the high half:
19646 //
19647 // (longop2 LHS, RHS) == (longop (extract_high LHS), (extract_high RHS))
19648 //
19649 // However, there are cases which don't have an extract_high explicitly, but
19650 // have another operation that can be made compatible with one for free. For
19651 // example:
19652 //
19653 // (dupv64 scalar) --> (extract_high (dup128 scalar))
19654 //
19655 // This routine does the actual conversion of such DUPs, once outer routines
19656 // have determined that everything else is in order.
19657 // It also supports immediate DUP-like nodes (MOVI/MVNi), which we can fold
19658 // similarly here.
19659 static SDValue tryExtendDUPToExtractHigh(SDValue N, SelectionDAG &DAG) {
19660 MVT VT = N.getSimpleValueType();
19661 if (N.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
19662 N.getConstantOperandVal(1) == 0)
19663 N = N.getOperand(0);
19664
19665 switch (N.getOpcode()) {
19666 case AArch64ISD::DUP:
19667 case AArch64ISD::DUPLANE8:
19668 case AArch64ISD::DUPLANE16:
19669 case AArch64ISD::DUPLANE32:
19670 case AArch64ISD::DUPLANE64:
19671 case AArch64ISD::MOVI:
19672 case AArch64ISD::MOVIshift:
19673 case AArch64ISD::MOVIedit:
19674 case AArch64ISD::MOVImsl:
19675 case AArch64ISD::MVNIshift:
19676 case AArch64ISD::MVNImsl:
19677 break;
19678 default:
19679 // FMOV could be supported, but isn't very useful, as it would only occur
19680 // if you passed a bitcast'd floating point immediate to an eligible long
19681 // integer op (addl, smull, ...).
19682 return SDValue();
19683 }
19684
19685 if (!VT.is64BitVector())
19686 return SDValue();
19687
19688 SDLoc DL(N);
19689 unsigned NumElems = VT.getVectorNumElements();
19690 if (N.getValueType().is64BitVector()) {
19691 MVT ElementTy = VT.getVectorElementType();
19692 MVT NewVT = MVT::getVectorVT(ElementTy, NumElems * 2);
19693 N = DAG.getNode(N->getOpcode(), DL, NewVT, N->ops());
19694 }
19695
19696 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, N,
19697 DAG.getConstant(NumElems, DL, MVT::i64));
19698 }
19699
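// Returns true if N is (possibly a bitcast of) an extract_subvector taking the
// high half of a fixed-length vector, e.g. (extract_subvector (v4i32 X), 2).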
19700 static bool isEssentiallyExtractHighSubvector(SDValue N) {
19701 if (N.getOpcode() == ISD::BITCAST)
19702 N = N.getOperand(0);
19703 if (N.getOpcode() != ISD::EXTRACT_SUBVECTOR)
19704 return false;
19705 if (N.getOperand(0).getValueType().isScalableVector())
19706 return false;
19707 return N.getConstantOperandAPInt(1) ==
19708 N.getOperand(0).getValueType().getVectorNumElements() / 2;
19709 }
19710
19711 /// Helper structure to keep track of ISD::SET_CC operands.
19712 struct GenericSetCCInfo {
19713 const SDValue *Opnd0;
19714 const SDValue *Opnd1;
19715 ISD::CondCode CC;
19716 };
19717
19718 /// Helper structure to keep track of a SET_CC lowered into AArch64 code.
19719 struct AArch64SetCCInfo {
19720 const SDValue *Cmp;
19721 AArch64CC::CondCode CC;
19722 };
19723
19724 /// Helper structure to keep track of SetCC information.
19725 union SetCCInfo {
19726 GenericSetCCInfo Generic;
19727 AArch64SetCCInfo AArch64;
19728 };
19729
19730 /// Helper structure to be able to read SetCC information. If set to
19731 /// true, IsAArch64 field, Info is a AArch64SetCCInfo, otherwise Info is a
19732 /// GenericSetCCInfo.
19733 struct SetCCInfoAndKind {
19734 SetCCInfo Info;
19735 bool IsAArch64;
19736 };
19737
19738 /// Check whether or not \p Op is a SET_CC operation, either a generic or an
19739 /// AArch64 lowered one.
19741 /// \p SetCCInfo is filled accordingly.
19742 /// \post SetCCInfo is meaningful only when this function returns true.
19743 /// \return True when Op is a kind of SET_CC operation.
19744 static bool isSetCC(SDValue Op, SetCCInfoAndKind &SetCCInfo) {
19745 // If this is a setcc, this is straightforward.
19746 if (Op.getOpcode() == ISD::SETCC) {
19747 SetCCInfo.Info.Generic.Opnd0 = &Op.getOperand(0);
19748 SetCCInfo.Info.Generic.Opnd1 = &Op.getOperand(1);
19749 SetCCInfo.Info.Generic.CC = cast<CondCodeSDNode>(Op.getOperand(2))->get();
19750 SetCCInfo.IsAArch64 = false;
19751 return true;
19752 }
19753 // Otherwise, check if this is a matching csel instruction.
19754 // In other words:
19755 // - csel 1, 0, cc
19756 // - csel 0, 1, !cc
19757 if (Op.getOpcode() != AArch64ISD::CSEL)
19758 return false;
19759 // Set the information about the operands.
19760 // TODO: we want the operands of the Cmp not the csel
19761 SetCCInfo.Info.AArch64.Cmp = &Op.getOperand(3);
19762 SetCCInfo.IsAArch64 = true;
19763 SetCCInfo.Info.AArch64.CC =
19764 static_cast<AArch64CC::CondCode>(Op.getConstantOperandVal(2));
19765
19766 // Check that the operands match the constraints:
19767 // (1) Both operands must be constants.
19768 // (2) One must be 1 and the other must be 0.
19769 ConstantSDNode *TValue = dyn_cast<ConstantSDNode>(Op.getOperand(0));
19770 ConstantSDNode *FValue = dyn_cast<ConstantSDNode>(Op.getOperand(1));
19771
19772 // Check (1).
19773 if (!TValue || !FValue)
19774 return false;
19775
19776 // Check (2).
19777 if (!TValue->isOne()) {
19778 // Update the comparison when we are interested in !cc.
19779 std::swap(TValue, FValue);
19780 SetCCInfo.Info.AArch64.CC =
19781 AArch64CC::getInvertedCondCode(SetCCInfo.Info.AArch64.CC);
19782 }
19783 return TValue->isOne() && FValue->isZero();
19784 }
19785
19786 // Returns true if Op is setcc or zext of setcc.
19787 static bool isSetCCOrZExtSetCC(const SDValue& Op, SetCCInfoAndKind &Info) {
19788 if (isSetCC(Op, Info))
19789 return true;
19790 return ((Op.getOpcode() == ISD::ZERO_EXTEND) &&
19791 isSetCC(Op->getOperand(0), Info));
19792 }
19793
19794 // The folding we want to perform is:
19795 // (add x, [zext] (setcc cc ...) )
19796 // -->
19797 // (csel x, (add x, 1), !cc ...)
19798 //
19799 // The latter will get matched to a CSINC instruction.
19800 static SDValue performSetccAddFolding(SDNode *Op, SelectionDAG &DAG) {
19801 assert(Op && Op->getOpcode() == ISD::ADD && "Unexpected operation!");
19802 SDValue LHS = Op->getOperand(0);
19803 SDValue RHS = Op->getOperand(1);
19804 SetCCInfoAndKind InfoAndKind;
19805
19806 // If both operands are a SET_CC, then we don't want to perform this
19807 // folding and create another csel as this results in more instructions
19808 // (and higher register usage).
19809 if (isSetCCOrZExtSetCC(LHS, InfoAndKind) &&
19810 isSetCCOrZExtSetCC(RHS, InfoAndKind))
19811 return SDValue();
19812
19813 // If neither operand is a SET_CC, give up.
19814 if (!isSetCCOrZExtSetCC(LHS, InfoAndKind)) {
19815 std::swap(LHS, RHS);
19816 if (!isSetCCOrZExtSetCC(LHS, InfoAndKind))
19817 return SDValue();
19818 }
19819
19820 // FIXME: This could be generalized to work for FP comparisons.
19821 EVT CmpVT = InfoAndKind.IsAArch64
19822 ? InfoAndKind.Info.AArch64.Cmp->getOperand(0).getValueType()
19823 : InfoAndKind.Info.Generic.Opnd0->getValueType();
19824 if (CmpVT != MVT::i32 && CmpVT != MVT::i64)
19825 return SDValue();
19826
19827 SDValue CCVal;
19828 SDValue Cmp;
19829 SDLoc dl(Op);
19830 if (InfoAndKind.IsAArch64) {
19831 CCVal = DAG.getConstant(
19832 AArch64CC::getInvertedCondCode(InfoAndKind.Info.AArch64.CC), dl,
19833 MVT::i32);
19834 Cmp = *InfoAndKind.Info.AArch64.Cmp;
19835 } else
19836 Cmp = getAArch64Cmp(
19837 *InfoAndKind.Info.Generic.Opnd0, *InfoAndKind.Info.Generic.Opnd1,
19838 ISD::getSetCCInverse(InfoAndKind.Info.Generic.CC, CmpVT), CCVal, DAG,
19839 dl);
19840
19841 EVT VT = Op->getValueType(0);
19842 LHS = DAG.getNode(ISD::ADD, dl, VT, RHS, DAG.getConstant(1, dl, VT));
19843 return DAG.getNode(AArch64ISD::CSEL, dl, VT, RHS, LHS, CCVal, Cmp);
19844 }
19845
19846 // ADD(UADDV a, UADDV b) --> UADDV(ADD a, b)
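// More precisely, this matches the scalar form seen after legalization, e.g.
//   add(extract_elt(UADDV(a), 0), extract_elt(UADDV(b), 0))
//     --> extract_elt(UADDV(add(a, b)), 0)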
19847 static SDValue performAddUADDVCombine(SDNode *N, SelectionDAG &DAG) {
19848 EVT VT = N->getValueType(0);
19849 // Only handle scalar integer result types.
19850 if (N->getOpcode() != ISD::ADD || !VT.isScalarInteger())
19851 return SDValue();
19852
19853 SDValue LHS = N->getOperand(0);
19854 SDValue RHS = N->getOperand(1);
19855 if (LHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
19856 RHS.getOpcode() != ISD::EXTRACT_VECTOR_ELT || LHS.getValueType() != VT)
19857 return SDValue();
19858
19859 auto *LHSN1 = dyn_cast<ConstantSDNode>(LHS->getOperand(1));
19860 auto *RHSN1 = dyn_cast<ConstantSDNode>(RHS->getOperand(1));
19861 if (!LHSN1 || LHSN1 != RHSN1 || !RHSN1->isZero())
19862 return SDValue();
19863
19864 SDValue Op1 = LHS->getOperand(0);
19865 SDValue Op2 = RHS->getOperand(0);
19866 EVT OpVT1 = Op1.getValueType();
19867 EVT OpVT2 = Op2.getValueType();
19868 if (Op1.getOpcode() != AArch64ISD::UADDV || OpVT1 != OpVT2 ||
19869 Op2.getOpcode() != AArch64ISD::UADDV ||
19870 OpVT1.getVectorElementType() != VT)
19871 return SDValue();
19872
19873 SDValue Val1 = Op1.getOperand(0);
19874 SDValue Val2 = Op2.getOperand(0);
19875 EVT ValVT = Val1->getValueType(0);
19876 SDLoc DL(N);
19877 SDValue AddVal = DAG.getNode(ISD::ADD, DL, ValVT, Val1, Val2);
19878 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT,
19879 DAG.getNode(AArch64ISD::UADDV, DL, ValVT, AddVal),
19880 DAG.getConstant(0, DL, MVT::i64));
19881 }
19882
19883 /// Perform the scalar expression combine in the form of:
19884 /// CSEL(c, 1, cc) + b => CSINC(b+c, b, cc)
19885 /// CSNEG(c, -1, cc) + b => CSINC(b+c, b, cc)
19886 static SDValue performAddCSelIntoCSinc(SDNode *N, SelectionDAG &DAG) {
19887 EVT VT = N->getValueType(0);
19888 if (!VT.isScalarInteger() || N->getOpcode() != ISD::ADD)
19889 return SDValue();
19890
19891 SDValue LHS = N->getOperand(0);
19892 SDValue RHS = N->getOperand(1);
19893
19894 // Handle commutativity.
19895 if (LHS.getOpcode() != AArch64ISD::CSEL &&
19896 LHS.getOpcode() != AArch64ISD::CSNEG) {
19897 std::swap(LHS, RHS);
19898 if (LHS.getOpcode() != AArch64ISD::CSEL &&
19899 LHS.getOpcode() != AArch64ISD::CSNEG) {
19900 return SDValue();
19901 }
19902 }
19903
19904 if (!LHS.hasOneUse())
19905 return SDValue();
19906
19907 AArch64CC::CondCode AArch64CC =
19908 static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
19909
19910 // The CSEL should include a constant one operand, and the CSNEG should
19911 // include a one or negative-one operand.
19912 ConstantSDNode *CTVal = dyn_cast<ConstantSDNode>(LHS.getOperand(0));
19913 ConstantSDNode *CFVal = dyn_cast<ConstantSDNode>(LHS.getOperand(1));
19914 if (!CTVal || !CFVal)
19915 return SDValue();
19916
19917 if (!(LHS.getOpcode() == AArch64ISD::CSEL &&
19918 (CTVal->isOne() || CFVal->isOne())) &&
19919 !(LHS.getOpcode() == AArch64ISD::CSNEG &&
19920 (CTVal->isOne() || CFVal->isAllOnes())))
19921 return SDValue();
19922
19923 // Switch CSEL(1, c, cc) to CSEL(c, 1, !cc)
19924 if (LHS.getOpcode() == AArch64ISD::CSEL && CTVal->isOne() &&
19925 !CFVal->isOne()) {
19926 std::swap(CTVal, CFVal);
19927 AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
19928 }
19929
19930 SDLoc DL(N);
19931 // Switch CSNEG(1, c, cc) to CSNEG(-c, -1, !cc)
19932 if (LHS.getOpcode() == AArch64ISD::CSNEG && CTVal->isOne() &&
19933 !CFVal->isAllOnes()) {
19934 APInt C = -1 * CFVal->getAPIntValue();
19935 CTVal = cast<ConstantSDNode>(DAG.getConstant(C, DL, VT));
19936 CFVal = cast<ConstantSDNode>(DAG.getAllOnesConstant(DL, VT));
19937 AArch64CC = AArch64CC::getInvertedCondCode(AArch64CC);
19938 }
19939
19940 // It might be neutral for larger constants, as the immediate needs to be
19941 // materialized in a register.
19942 APInt ADDC = CTVal->getAPIntValue();
19943 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
19944 if (!TLI.isLegalAddImmediate(ADDC.getSExtValue()))
19945 return SDValue();
19946
19947 assert(((LHS.getOpcode() == AArch64ISD::CSEL && CFVal->isOne()) ||
19948 (LHS.getOpcode() == AArch64ISD::CSNEG && CFVal->isAllOnes())) &&
19949 "Unexpected constant value");
19950
19951 SDValue NewNode = DAG.getNode(ISD::ADD, DL, VT, RHS, SDValue(CTVal, 0));
19952 SDValue CCVal = DAG.getConstant(AArch64CC, DL, MVT::i32);
19953 SDValue Cmp = LHS.getOperand(3);
19954
19955 return DAG.getNode(AArch64ISD::CSINC, DL, VT, NewNode, RHS, CCVal, Cmp);
19956 }
19957
19958 // ADD(UDOT(zero, x, y), A) --> UDOT(A, x, y)
19959 static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
19960 EVT VT = N->getValueType(0);
19961 if (N->getOpcode() != ISD::ADD)
19962 return SDValue();
19963
19964 SDValue Dot = N->getOperand(0);
19965 SDValue A = N->getOperand(1);
19966 // Handle commutativity.
19967 auto isZeroDot = [](SDValue Dot) {
19968 return (Dot.getOpcode() == AArch64ISD::UDOT ||
19969 Dot.getOpcode() == AArch64ISD::SDOT) &&
19970 isZerosVector(Dot.getOperand(0).getNode());
19971 };
19972 if (!isZeroDot(Dot))
19973 std::swap(Dot, A);
19974 if (!isZeroDot(Dot))
19975 return SDValue();
19976
19977 return DAG.getNode(Dot.getOpcode(), SDLoc(N), VT, A, Dot.getOperand(1),
19978 Dot.getOperand(2));
19979 }
19980
19981 static bool isNegatedInteger(SDValue Op) {
19982 return Op.getOpcode() == ISD::SUB && isNullConstant(Op.getOperand(0));
19983 }
19984
19985 static SDValue getNegatedInteger(SDValue Op, SelectionDAG &DAG) {
19986 SDLoc DL(Op);
19987 EVT VT = Op.getValueType();
19988 SDValue Zero = DAG.getConstant(0, DL, VT);
19989 return DAG.getNode(ISD::SUB, DL, VT, Zero, Op);
19990 }
19991
19992 // Try to fold
19993 //
19994 // (neg (csel X, Y)) -> (csel (neg X), (neg Y))
19995 //
19996 // The folding helps csel to be matched with csneg without generating
19997 // redundant neg instruction, which includes negation of the csel expansion
19998 // of abs node lowered by lowerABS.
19999 static SDValue performNegCSelCombine(SDNode *N, SelectionDAG &DAG) {
20000 if (!isNegatedInteger(SDValue(N, 0)))
20001 return SDValue();
20002
20003 SDValue CSel = N->getOperand(1);
20004 if (CSel.getOpcode() != AArch64ISD::CSEL || !CSel->hasOneUse())
20005 return SDValue();
20006
20007 SDValue N0 = CSel.getOperand(0);
20008 SDValue N1 = CSel.getOperand(1);
20009
20010 // If neither of them is a negation, the folding is not worthwhile, as it
20011 // introduces two additional negations while removing only one.
20012 if (!isNegatedInteger(N0) && !isNegatedInteger(N1))
20013 return SDValue();
20014
20015 SDValue N0N = getNegatedInteger(N0, DAG);
20016 SDValue N1N = getNegatedInteger(N1, DAG);
20017
20018 SDLoc DL(N);
20019 EVT VT = CSel.getValueType();
20020 return DAG.getNode(AArch64ISD::CSEL, DL, VT, N0N, N1N, CSel.getOperand(2),
20021 CSel.getOperand(3));
20022 }
20023
20024 // The basic add/sub long vector instructions have variants with "2" on the end
20025 // which act on the high-half of their inputs. They are normally matched by
20026 // patterns like:
20027 //
20028 // (add (zeroext (extract_high LHS)),
20029 // (zeroext (extract_high RHS)))
20030 // -> uaddl2 vD, vN, vM
20031 //
20032 // However, if one of the extracts is something like a duplicate, this
20033 // instruction can still be used profitably. This function puts the DAG into a
20034 // more appropriate form for those patterns to trigger.
20035 static SDValue performAddSubLongCombine(SDNode *N,
20036 TargetLowering::DAGCombinerInfo &DCI) {
20037 SelectionDAG &DAG = DCI.DAG;
20038 if (DCI.isBeforeLegalizeOps())
20039 return SDValue();
20040
20041 MVT VT = N->getSimpleValueType(0);
20042 if (!VT.is128BitVector()) {
20043 if (N->getOpcode() == ISD::ADD)
20044 return performSetccAddFolding(N, DAG);
20045 return SDValue();
20046 }
20047
20048 // Make sure both branches are extended in the same way.
20049 SDValue LHS = N->getOperand(0);
20050 SDValue RHS = N->getOperand(1);
20051 if ((LHS.getOpcode() != ISD::ZERO_EXTEND &&
20052 LHS.getOpcode() != ISD::SIGN_EXTEND) ||
20053 LHS.getOpcode() != RHS.getOpcode())
20054 return SDValue();
20055
20056 unsigned ExtType = LHS.getOpcode();
20057
20058 // It's not worth doing if at least one of the inputs isn't already an
20059 // extract, but we don't know which it'll be so we have to try both.
20060 if (isEssentiallyExtractHighSubvector(LHS.getOperand(0))) {
20061 RHS = tryExtendDUPToExtractHigh(RHS.getOperand(0), DAG);
20062 if (!RHS.getNode())
20063 return SDValue();
20064
20065 RHS = DAG.getNode(ExtType, SDLoc(N), VT, RHS);
20066 } else if (isEssentiallyExtractHighSubvector(RHS.getOperand(0))) {
20067 LHS = tryExtendDUPToExtractHigh(LHS.getOperand(0), DAG);
20068 if (!LHS.getNode())
20069 return SDValue();
20070
20071 LHS = DAG.getNode(ExtType, SDLoc(N), VT, LHS);
20072 }
20073
20074 return DAG.getNode(N->getOpcode(), SDLoc(N), VT, LHS, RHS);
20075 }
20076
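// Returns true if Op is a SUBS node whose value result is unused, i.e. it is
// only used as a compare.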
20077 static bool isCMP(SDValue Op) {
20078 return Op.getOpcode() == AArch64ISD::SUBS &&
20079 !Op.getNode()->hasAnyUseOfValue(0);
20080 }
20081
20082 // (CSEL 1 0 CC Cond) => CC
20083 // (CSEL 0 1 CC Cond) => !CC
20084 static std::optional<AArch64CC::CondCode> getCSETCondCode(SDValue Op) {
20085 if (Op.getOpcode() != AArch64ISD::CSEL)
20086 return std::nullopt;
20087 auto CC = static_cast<AArch64CC::CondCode>(Op.getConstantOperandVal(2));
20088 if (CC == AArch64CC::AL || CC == AArch64CC::NV)
20089 return std::nullopt;
20090 SDValue OpLHS = Op.getOperand(0);
20091 SDValue OpRHS = Op.getOperand(1);
20092 if (isOneConstant(OpLHS) && isNullConstant(OpRHS))
20093 return CC;
20094 if (isNullConstant(OpLHS) && isOneConstant(OpRHS))
20095 return getInvertedCondCode(CC);
20096
20097 return std::nullopt;
20098 }
20099
20100 // (ADC{S} l r (CMP (CSET HS carry) 1)) => (ADC{S} l r carry)
20101 // (SBC{S} l r (CMP 0 (CSET LO carry))) => (SBC{S} l r carry)
20102 static SDValue foldOverflowCheck(SDNode *Op, SelectionDAG &DAG, bool IsAdd) {
20103 SDValue CmpOp = Op->getOperand(2);
20104 if (!isCMP(CmpOp))
20105 return SDValue();
20106
20107 if (IsAdd) {
20108 if (!isOneConstant(CmpOp.getOperand(1)))
20109 return SDValue();
20110 } else {
20111 if (!isNullConstant(CmpOp.getOperand(0)))
20112 return SDValue();
20113 }
20114
20115 SDValue CsetOp = CmpOp->getOperand(IsAdd ? 0 : 1);
20116 auto CC = getCSETCondCode(CsetOp);
20117 if (CC != (IsAdd ? AArch64CC::HS : AArch64CC::LO))
20118 return SDValue();
20119
20120 return DAG.getNode(Op->getOpcode(), SDLoc(Op), Op->getVTList(),
20121 Op->getOperand(0), Op->getOperand(1),
20122 CsetOp.getOperand(3));
20123 }
20124
20125 // (ADC x 0 cond) => (CINC x HS cond)
20126 static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) {
20127 SDValue LHS = N->getOperand(0);
20128 SDValue RHS = N->getOperand(1);
20129 SDValue Cond = N->getOperand(2);
20130
20131 if (!isNullConstant(RHS))
20132 return SDValue();
20133
20134 EVT VT = N->getValueType(0);
20135 SDLoc DL(N);
20136
20137 // (CINC x cc cond) <=> (CSINC x x !cc cond)
20138 SDValue CC = DAG.getConstant(AArch64CC::LO, DL, MVT::i32);
20139 return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond);
20140 }
20141
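// Combine build_vectors whose operands are per-lane conversions or extracts of
// adjacent lanes of a single source vector into whole-vector operations, e.g.
// a v4f16/v4bf16 build_vector of fp_rounds from a v2f64 source becomes FCVTXN
// plus one fp_round, and a v2i32 build_vector of two contiguous extracts
// becomes an extract_subvector of an any-extended vector.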
20142 static SDValue performBuildVectorCombine(SDNode *N,
20143 TargetLowering::DAGCombinerInfo &DCI,
20144 SelectionDAG &DAG) {
20145 SDLoc DL(N);
20146 EVT VT = N->getValueType(0);
20147
20148 if (DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable() &&
20149 (VT == MVT::v4f16 || VT == MVT::v4bf16)) {
20150 SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1),
20151 Elt2 = N->getOperand(2), Elt3 = N->getOperand(3);
20152 if (Elt0->getOpcode() == ISD::FP_ROUND &&
20153 Elt1->getOpcode() == ISD::FP_ROUND &&
20154 isa<ConstantSDNode>(Elt0->getOperand(1)) &&
20155 isa<ConstantSDNode>(Elt1->getOperand(1)) &&
20156 Elt0->getConstantOperandVal(1) == Elt1->getConstantOperandVal(1) &&
20157 Elt0->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20158 Elt1->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20159 // Constant index.
20160 isa<ConstantSDNode>(Elt0->getOperand(0)->getOperand(1)) &&
20161 isa<ConstantSDNode>(Elt1->getOperand(0)->getOperand(1)) &&
20162 Elt0->getOperand(0)->getOperand(0) ==
20163 Elt1->getOperand(0)->getOperand(0) &&
20164 Elt0->getOperand(0)->getConstantOperandVal(1) == 0 &&
20165 Elt1->getOperand(0)->getConstantOperandVal(1) == 1) {
20166 SDValue LowLanesSrcVec = Elt0->getOperand(0)->getOperand(0);
20167 if (LowLanesSrcVec.getValueType() == MVT::v2f64) {
20168 SDValue HighLanes;
20169 if (Elt2->getOpcode() == ISD::UNDEF &&
20170 Elt3->getOpcode() == ISD::UNDEF) {
20171 HighLanes = DAG.getUNDEF(MVT::v2f32);
20172 } else if (Elt2->getOpcode() == ISD::FP_ROUND &&
20173 Elt3->getOpcode() == ISD::FP_ROUND &&
20174 isa<ConstantSDNode>(Elt2->getOperand(1)) &&
20175 isa<ConstantSDNode>(Elt3->getOperand(1)) &&
20176 Elt2->getConstantOperandVal(1) ==
20177 Elt3->getConstantOperandVal(1) &&
20178 Elt2->getOperand(0)->getOpcode() ==
20179 ISD::EXTRACT_VECTOR_ELT &&
20180 Elt3->getOperand(0)->getOpcode() ==
20181 ISD::EXTRACT_VECTOR_ELT &&
20182 // Constant index.
20183 isa<ConstantSDNode>(Elt2->getOperand(0)->getOperand(1)) &&
20184 isa<ConstantSDNode>(Elt3->getOperand(0)->getOperand(1)) &&
20185 Elt2->getOperand(0)->getOperand(0) ==
20186 Elt3->getOperand(0)->getOperand(0) &&
20187 Elt2->getOperand(0)->getConstantOperandVal(1) == 0 &&
20188 Elt3->getOperand(0)->getConstantOperandVal(1) == 1) {
20189 SDValue HighLanesSrcVec = Elt2->getOperand(0)->getOperand(0);
20190 HighLanes =
20191 DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, HighLanesSrcVec);
20192 }
20193 if (HighLanes) {
20194 SDValue DoubleToSingleSticky =
20195 DAG.getNode(AArch64ISD::FCVTXN, DL, MVT::v2f32, LowLanesSrcVec);
20196 SDValue Concat = DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32,
20197 DoubleToSingleSticky, HighLanes);
20198 return DAG.getNode(ISD::FP_ROUND, DL, VT, Concat,
20199 Elt0->getOperand(1));
20200 }
20201 }
20202 }
20203 }
20204
20205 if (VT == MVT::v2f64) {
20206 SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1);
20207 if (Elt0->getOpcode() == ISD::FP_EXTEND &&
20208 Elt1->getOpcode() == ISD::FP_EXTEND &&
20209 Elt0->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20210 Elt1->getOperand(0)->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20211 Elt0->getOperand(0)->getOperand(0) ==
20212 Elt1->getOperand(0)->getOperand(0) &&
20213 // Constant index.
20214 isa<ConstantSDNode>(Elt0->getOperand(0)->getOperand(1)) &&
20215 isa<ConstantSDNode>(Elt1->getOperand(0)->getOperand(1)) &&
20216 Elt0->getOperand(0)->getConstantOperandVal(1) + 1 ==
20217 Elt1->getOperand(0)->getConstantOperandVal(1) &&
20218 // EXTRACT_SUBVECTOR requires that Idx be a constant multiple of
20219 // ResultType's known minimum vector length.
20220 Elt0->getOperand(0)->getConstantOperandVal(1) %
20221 VT.getVectorMinNumElements() ==
20222 0) {
20223 SDValue SrcVec = Elt0->getOperand(0)->getOperand(0);
20224 if (SrcVec.getValueType() == MVT::v4f16 ||
20225 SrcVec.getValueType() == MVT::v4bf16) {
20226 SDValue HalfToSingle =
20227 DAG.getNode(ISD::FP_EXTEND, DL, MVT::v4f32, SrcVec);
20228 SDValue SubvectorIdx = Elt0->getOperand(0)->getOperand(1);
20229 SDValue Extract = DAG.getNode(
20230 ISD::EXTRACT_SUBVECTOR, DL, VT.changeVectorElementType(MVT::f32),
20231 HalfToSingle, SubvectorIdx);
20232 return DAG.getNode(ISD::FP_EXTEND, DL, VT, Extract);
20233 }
20234 }
20235 }
20236
20237 // A build vector of two extracted elements is equivalent to an
20238 // extract subvector where the inner vector is any-extended to the
20239 // extract_vector_elt VT.
20240 // (build_vector (extract_elt_iXX_to_i32 vec Idx+0)
20241 // (extract_elt_iXX_to_i32 vec Idx+1))
20242 // => (extract_subvector (anyext_iXX_to_i32 vec) Idx)
20243
20244 // For now, only consider the v2i32 case, which arises as a result of
20245 // legalization.
20246 if (VT != MVT::v2i32)
20247 return SDValue();
20248
20249 SDValue Elt0 = N->getOperand(0), Elt1 = N->getOperand(1);
20250 // Reminder, EXTRACT_VECTOR_ELT has the effect of any-extending to its VT.
20251 if (Elt0->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20252 Elt1->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20253 // Constant index.
20254 isa<ConstantSDNode>(Elt0->getOperand(1)) &&
20255 isa<ConstantSDNode>(Elt1->getOperand(1)) &&
20256 // Both EXTRACT_VECTOR_ELT from same vector...
20257 Elt0->getOperand(0) == Elt1->getOperand(0) &&
20258 // ... and contiguous. First element's index +1 == second element's index.
20259 Elt0->getConstantOperandVal(1) + 1 == Elt1->getConstantOperandVal(1) &&
20260 // EXTRACT_SUBVECTOR requires that Idx be a constant multiple of
20261 // ResultType's known minimum vector length.
20262 Elt0->getConstantOperandVal(1) % VT.getVectorMinNumElements() == 0) {
20263 SDValue VecToExtend = Elt0->getOperand(0);
20264 EVT ExtVT = VecToExtend.getValueType().changeVectorElementType(MVT::i32);
20265 if (!DAG.getTargetLoweringInfo().isTypeLegal(ExtVT))
20266 return SDValue();
20267
20268 SDValue SubvectorIdx = DAG.getVectorIdxConstant(Elt0->getConstantOperandVal(1), DL);
20269
20270 SDValue Ext = DAG.getNode(ISD::ANY_EXTEND, DL, ExtVT, VecToExtend);
20271 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, Ext,
20272 SubvectorIdx);
20273 }
20274
20275 return SDValue();
20276 }
20277
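// trunc(DUP(x)) -> DUP(trunc(x)) for 64-bit fixed-length results; the scalar
// operand is explicitly truncated to i32 when it was an i64.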
20278 static SDValue performTruncateCombine(SDNode *N,
20279 SelectionDAG &DAG) {
20280 EVT VT = N->getValueType(0);
20281 SDValue N0 = N->getOperand(0);
20282 if (VT.isFixedLengthVector() && VT.is64BitVector() && N0.hasOneUse() &&
20283 N0.getOpcode() == AArch64ISD::DUP) {
20284 SDValue Op = N0.getOperand(0);
20285 if (VT.getScalarType() == MVT::i32 &&
20286 N0.getOperand(0).getValueType().getScalarType() == MVT::i64)
20287 Op = DAG.getNode(ISD::TRUNCATE, SDLoc(N), MVT::i32, Op);
20288 return DAG.getNode(N0.getOpcode(), SDLoc(N), VT, Op);
20289 }
20290
20291 return SDValue();
20292 }
20293
20294 // Check whether a node is an extend or shift operand.
20295 static bool isExtendOrShiftOperand(SDValue N) {
20296 unsigned Opcode = N.getOpcode();
20297 if (ISD::isExtOpcode(Opcode) || Opcode == ISD::SIGN_EXTEND_INREG) {
20298 EVT SrcVT;
20299 if (Opcode == ISD::SIGN_EXTEND_INREG)
20300 SrcVT = cast<VTSDNode>(N.getOperand(1))->getVT();
20301 else
20302 SrcVT = N.getOperand(0).getValueType();
20303
20304 return SrcVT == MVT::i32 || SrcVT == MVT::i16 || SrcVT == MVT::i8;
20305 } else if (Opcode == ISD::AND) {
20306 ConstantSDNode *CSD = dyn_cast<ConstantSDNode>(N.getOperand(1));
20307 if (!CSD)
20308 return false;
20309 uint64_t AndMask = CSD->getZExtValue();
20310 return AndMask == 0xff || AndMask == 0xffff || AndMask == 0xffffffff;
20311 } else if (Opcode == ISD::SHL || Opcode == ISD::SRL || Opcode == ISD::SRA) {
20312 return isa<ConstantSDNode>(N.getOperand(1));
20313 }
20314
20315 return false;
20316 }
20317
20318 // (N - Y) + Z --> (Z - Y) + N
20319 // when N is an extend or shift operand
20320 static SDValue performAddCombineSubShift(SDNode *N, SDValue SUB, SDValue Z,
20321 SelectionDAG &DAG) {
20322 auto IsOneUseExtend = [](SDValue N) {
20323 return N.hasOneUse() && isExtendOrShiftOperand(N);
20324 };
20325
20326 // DAGCombiner will revert the combination when Z is constant, causing an
20327 // infinite loop, so don't enable the combination when Z is constant.
20328 // If Z is a one-use shift by a constant, we also can't do the optimization,
20329 // as it would likewise fall into an infinite loop.
20330 if (isa<ConstantSDNode>(Z) || IsOneUseExtend(Z))
20331 return SDValue();
20332
20333 if (SUB.getOpcode() != ISD::SUB || !SUB.hasOneUse())
20334 return SDValue();
20335
20336 SDValue Shift = SUB.getOperand(0);
20337 if (!IsOneUseExtend(Shift))
20338 return SDValue();
20339
20340 SDLoc DL(N);
20341 EVT VT = N->getValueType(0);
20342
20343 SDValue Y = SUB.getOperand(1);
20344 SDValue NewSub = DAG.getNode(ISD::SUB, DL, VT, Z, Y);
20345 return DAG.getNode(ISD::ADD, DL, VT, NewSub, Shift);
20346 }
20347
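// For ADD nodes, try the (N - Y) + Z --> (Z - Y) + N rewrite with the operands
// in either order, and otherwise prefer placing the operand shifted by a small
// constant (<= 4) on the RHS, where the ADD with shifted register is cheaper.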
20348 static SDValue performAddCombineForShiftedOperands(SDNode *N,
20349 SelectionDAG &DAG) {
20350 // NOTE: Swapping LHS and RHS is not done for SUB, since SUB is not
20351 // commutative.
20352 if (N->getOpcode() != ISD::ADD)
20353 return SDValue();
20354
20355 // Bail out when value type is not one of {i32, i64}, since AArch64 ADD with
20356 // shifted register is only available for i32 and i64.
20357 EVT VT = N->getValueType(0);
20358 if (VT != MVT::i32 && VT != MVT::i64)
20359 return SDValue();
20360
20361 SDLoc DL(N);
20362 SDValue LHS = N->getOperand(0);
20363 SDValue RHS = N->getOperand(1);
20364
20365 if (SDValue Val = performAddCombineSubShift(N, LHS, RHS, DAG))
20366 return Val;
20367 if (SDValue Val = performAddCombineSubShift(N, RHS, LHS, DAG))
20368 return Val;
20369
20370 uint64_t LHSImm = 0, RHSImm = 0;
20371 // If both operands are shifted by an immediate and the shift amount is not
20372 // greater than 4 for one operand, swap LHS and RHS to put the operand with
20373 // the smaller shift amount on the RHS.
20374 //
20375 // On many AArch64 processors (Cortex A78, Neoverse N1/N2/V1, etc), ADD with
20376 // LSL shift (shift <= 4) has smaller latency and larger throughput than ADD
20377 // with LSL (shift > 4). For the remaining processors, this is a no-op for
20378 // both performance and correctness.
20379 if (isOpcWithIntImmediate(LHS.getNode(), ISD::SHL, LHSImm) &&
20380 isOpcWithIntImmediate(RHS.getNode(), ISD::SHL, RHSImm) && LHSImm <= 4 &&
20381 RHSImm > 4 && LHS.hasOneUse())
20382 return DAG.getNode(ISD::ADD, DL, VT, RHS, LHS);
20383
20384 return SDValue();
20385 }
20386
20387 // The mid end will reassociate sub(sub(x, m1), m2) to sub(x, add(m1, m2))
20388 // This reassociates it back to allow the creation of more mls instructions.
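// i.e. sub(x, add(m1, m2)) with m1/m2 being mul/[su]mull nodes is rewritten to
// sub(sub(x, m1), m2), which can then be selected as a pair of mls/mlsl.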
20389 static SDValue performSubAddMULCombine(SDNode *N, SelectionDAG &DAG) {
20390 if (N->getOpcode() != ISD::SUB)
20391 return SDValue();
20392
20393 SDValue Add = N->getOperand(1);
20394 SDValue X = N->getOperand(0);
20395 if (Add.getOpcode() != ISD::ADD)
20396 return SDValue();
20397
20398 if (!Add.hasOneUse())
20399 return SDValue();
20400 if (DAG.isConstantIntBuildVectorOrConstantInt(peekThroughBitcasts(X)))
20401 return SDValue();
20402
20403 SDValue M1 = Add.getOperand(0);
20404 SDValue M2 = Add.getOperand(1);
20405 if (M1.getOpcode() != ISD::MUL && M1.getOpcode() != AArch64ISD::SMULL &&
20406 M1.getOpcode() != AArch64ISD::UMULL)
20407 return SDValue();
20408 if (M2.getOpcode() != ISD::MUL && M2.getOpcode() != AArch64ISD::SMULL &&
20409 M2.getOpcode() != AArch64ISD::UMULL)
20410 return SDValue();
20411
20412 EVT VT = N->getValueType(0);
20413 SDValue Sub = DAG.getNode(ISD::SUB, SDLoc(N), VT, X, M1);
20414 return DAG.getNode(ISD::SUB, SDLoc(N), VT, Sub, M2);
20415 }
20416
20417 // Combine into mla/mls.
20418 // This works on the patterns of:
20419 // add v1, (mul v2, v3)
20420 // sub v1, (mul v2, v3)
20421 // for vectors of type <1 x i64> and <2 x i64> when SVE is available.
20422 // It will transform the add/sub to a scalable version, so that we can
20423 // make use of SVE's MLA/MLS that will be generated for that pattern.
20424 static SDValue
20425 performSVEMulAddSubCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
20426 SelectionDAG &DAG = DCI.DAG;
20427 // Make sure that the types are legal
20428 if (!DCI.isAfterLegalizeDAG())
20429 return SDValue();
20430 // Before using SVE's features, check first if it's available.
20431 if (!DAG.getSubtarget<AArch64Subtarget>().hasSVE())
20432 return SDValue();
20433
20434 if (N->getOpcode() != ISD::ADD && N->getOpcode() != ISD::SUB)
20435 return SDValue();
20436
20437 if (!N->getValueType(0).isFixedLengthVector())
20438 return SDValue();
20439
20440 auto performOpt = [&DAG, &N](SDValue Op0, SDValue Op1) -> SDValue {
20441 if (Op1.getOpcode() != ISD::EXTRACT_SUBVECTOR)
20442 return SDValue();
20443
20444 if (!cast<ConstantSDNode>(Op1->getOperand(1))->isZero())
20445 return SDValue();
20446
20447 SDValue MulValue = Op1->getOperand(0);
20448 if (MulValue.getOpcode() != AArch64ISD::MUL_PRED)
20449 return SDValue();
20450
20451 if (!Op1.hasOneUse() || !MulValue.hasOneUse())
20452 return SDValue();
20453
20454 EVT ScalableVT = MulValue.getValueType();
20455 if (!ScalableVT.isScalableVector())
20456 return SDValue();
20457
20458 SDValue ScaledOp = convertToScalableVector(DAG, ScalableVT, Op0);
20459 SDValue NewValue =
20460 DAG.getNode(N->getOpcode(), SDLoc(N), ScalableVT, {ScaledOp, MulValue});
20461 return convertFromScalableVector(DAG, N->getValueType(0), NewValue);
20462 };
20463
20464 if (SDValue res = performOpt(N->getOperand(0), N->getOperand(1)))
20465 return res;
20466 else if (N->getOpcode() == ISD::ADD)
20467 return performOpt(N->getOperand(1), N->getOperand(0));
20468
20469 return SDValue();
20470 }
20471
20472 // Given a i64 add from a v1i64 extract, convert to a neon v1i64 add. This can
20473 // help, for example, to produce ssra from sshr+add.
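// For example, add(extract_elt(v1i64 X, 0), i64 load) becomes
// extract_elt(add(v1i64 X, scalar_to_vector(i64 load)), 0).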
20474 static SDValue performAddSubIntoVectorOp(SDNode *N, SelectionDAG &DAG) {
20475 EVT VT = N->getValueType(0);
20476 if (VT != MVT::i64 ||
20477 DAG.getTargetLoweringInfo().isOperationExpand(N->getOpcode(), MVT::v1i64))
20478 return SDValue();
20479 SDValue Op0 = N->getOperand(0);
20480 SDValue Op1 = N->getOperand(1);
20481
20482 // At least one of the operands should be an extract, and the other should be
20483 // something that is easy to convert to v1i64 type (in this case a load).
20484 if (Op0.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20485 Op0.getOpcode() != ISD::LOAD)
20486 return SDValue();
20487 if (Op1.getOpcode() != ISD::EXTRACT_VECTOR_ELT &&
20488 Op1.getOpcode() != ISD::LOAD)
20489 return SDValue();
20490
20491 SDLoc DL(N);
20492 if (Op0.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20493 Op0.getOperand(0).getValueType() == MVT::v1i64) {
20494 Op0 = Op0.getOperand(0);
20495 Op1 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i64, Op1);
20496 } else if (Op1.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
20497 Op1.getOperand(0).getValueType() == MVT::v1i64) {
20498 Op0 = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, MVT::v1i64, Op0);
20499 Op1 = Op1.getOperand(0);
20500 } else
20501 return SDValue();
20502
20503 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i64,
20504 DAG.getNode(N->getOpcode(), DL, MVT::v1i64, Op0, Op1),
20505 DAG.getConstant(0, DL, MVT::i64));
20506 }
20507
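// Collect the simple load(s) feeding B into Loads: either a single one-use
// load, a BUILD_VECTOR/CONCAT_VECTORS of one-use simple loads, or the specific
// shuffle-of-concats tree produced when lowering IR shuffles of loads.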
20508 static bool isLoadOrMultipleLoads(SDValue B, SmallVector<LoadSDNode *> &Loads) {
20509 SDValue BV = peekThroughOneUseBitcasts(B);
20510 if (!BV->hasOneUse())
20511 return false;
20512 if (auto *Ld = dyn_cast<LoadSDNode>(BV)) {
20513 if (!Ld || !Ld->isSimple())
20514 return false;
20515 Loads.push_back(Ld);
20516 return true;
20517 } else if (BV.getOpcode() == ISD::BUILD_VECTOR ||
20518 BV.getOpcode() == ISD::CONCAT_VECTORS) {
20519 for (unsigned Op = 0; Op < BV.getNumOperands(); Op++) {
20520 auto *Ld = dyn_cast<LoadSDNode>(BV.getOperand(Op));
20521 if (!Ld || !Ld->isSimple() || !BV.getOperand(Op).hasOneUse())
20522 return false;
20523 Loads.push_back(Ld);
20524 }
20525 return true;
20526 } else if (B.getOpcode() == ISD::VECTOR_SHUFFLE) {
20527 // Try to find a tree of shuffles and concats from how IR shuffles of loads
20528 // are lowered. Note that this only comes up because we do not always visit
20529 // operands before uses. After that is fixed this can be removed and in the
20530 // meantime this is fairly specific to the lowering we expect from IR.
20531 // t46: v16i8 = vector_shuffle<0,1,2,3,4,5,6,7,8,9,10,11,16,17,18,19> t44, t45
20532 // t44: v16i8 = vector_shuffle<0,1,2,3,4,5,6,7,16,17,18,19,u,u,u,u> t42, t43
20533 // t42: v16i8 = concat_vectors t40, t36, undef:v4i8, undef:v4i8
20534 // t40: v4i8,ch = load<(load (s32) from %ir.17)> t0, t22, undef:i64
20535 // t36: v4i8,ch = load<(load (s32) from %ir.13)> t0, t18, undef:i64
20536 // t43: v16i8 = concat_vectors t32, undef:v4i8, undef:v4i8, undef:v4i8
20537 // t32: v4i8,ch = load<(load (s32) from %ir.9)> t0, t14, undef:i64
20538 // t45: v16i8 = concat_vectors t28, undef:v4i8, undef:v4i8, undef:v4i8
20539 // t28: v4i8,ch = load<(load (s32) from %ir.0)> t0, t2, undef:i64
20540 if (B.getOperand(0).getOpcode() != ISD::VECTOR_SHUFFLE ||
20541 B.getOperand(0).getOperand(0).getOpcode() != ISD::CONCAT_VECTORS ||
20542 B.getOperand(0).getOperand(1).getOpcode() != ISD::CONCAT_VECTORS ||
20543 B.getOperand(1).getOpcode() != ISD::CONCAT_VECTORS ||
20544 B.getOperand(1).getNumOperands() != 4)
20545 return false;
20546 auto SV1 = cast<ShuffleVectorSDNode>(B);
20547 auto SV2 = cast<ShuffleVectorSDNode>(B.getOperand(0));
20548 int NumElts = B.getValueType().getVectorNumElements();
20549 int NumSubElts = NumElts / 4;
20550 for (int I = 0; I < NumSubElts; I++) {
20551 // <0,1,2,3,4,5,6,7,8,9,10,11,16,17,18,19>
20552 if (SV1->getMaskElt(I) != I ||
20553 SV1->getMaskElt(I + NumSubElts) != I + NumSubElts ||
20554 SV1->getMaskElt(I + NumSubElts * 2) != I + NumSubElts * 2 ||
20555 SV1->getMaskElt(I + NumSubElts * 3) != I + NumElts)
20556 return false;
20557 // <0,1,2,3,4,5,6,7,16,17,18,19,u,u,u,u>
20558 if (SV2->getMaskElt(I) != I ||
20559 SV2->getMaskElt(I + NumSubElts) != I + NumSubElts ||
20560 SV2->getMaskElt(I + NumSubElts * 2) != I + NumElts)
20561 return false;
20562 }
20563 auto *Ld0 = dyn_cast<LoadSDNode>(SV2->getOperand(0).getOperand(0));
20564 auto *Ld1 = dyn_cast<LoadSDNode>(SV2->getOperand(0).getOperand(1));
20565 auto *Ld2 = dyn_cast<LoadSDNode>(SV2->getOperand(1).getOperand(0));
20566 auto *Ld3 = dyn_cast<LoadSDNode>(B.getOperand(1).getOperand(0));
20567 if (!Ld0 || !Ld1 || !Ld2 || !Ld3 || !Ld0->isSimple() || !Ld1->isSimple() ||
20568 !Ld2->isSimple() || !Ld3->isSimple())
20569 return false;
20570 Loads.push_back(Ld0);
20571 Loads.push_back(Ld1);
20572 Loads.push_back(Ld2);
20573 Loads.push_back(Ld3);
20574 return true;
20575 }
20576 return false;
20577 }
20578
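// Returns true if Op0 and Op1 are identical trees of add/sub/ext operations
// whose leaves are loads, with each of Op1's loads placed in memory directly
// after the corresponding load of Op0. NumSubLoads is set to the number of
// loads feeding each leaf.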
20579 static bool areLoadedOffsetButOtherwiseSame(SDValue Op0, SDValue Op1,
20580 SelectionDAG &DAG,
20581 unsigned &NumSubLoads) {
20582 if (!Op0.hasOneUse() || !Op1.hasOneUse())
20583 return false;
20584
20585 SmallVector<LoadSDNode *> Loads0, Loads1;
20586 if (isLoadOrMultipleLoads(Op0, Loads0) &&
20587 isLoadOrMultipleLoads(Op1, Loads1)) {
20588 if (NumSubLoads && Loads0.size() != NumSubLoads)
20589 return false;
20590 NumSubLoads = Loads0.size();
20591 return Loads0.size() == Loads1.size() &&
20592 all_of(zip(Loads0, Loads1), [&DAG](auto L) {
20593 unsigned Size = get<0>(L)->getValueType(0).getSizeInBits();
20594 return Size == get<1>(L)->getValueType(0).getSizeInBits() &&
20595 DAG.areNonVolatileConsecutiveLoads(get<1>(L), get<0>(L),
20596 Size / 8, 1);
20597 });
20598 }
20599
20600 if (Op0.getOpcode() != Op1.getOpcode())
20601 return false;
20602
20603 switch (Op0.getOpcode()) {
20604 case ISD::ADD:
20605 case ISD::SUB:
20606 return areLoadedOffsetButOtherwiseSame(Op0.getOperand(0), Op1.getOperand(0),
20607 DAG, NumSubLoads) &&
20608 areLoadedOffsetButOtherwiseSame(Op0.getOperand(1), Op1.getOperand(1),
20609 DAG, NumSubLoads);
20610 case ISD::SIGN_EXTEND:
20611 case ISD::ANY_EXTEND:
20612 case ISD::ZERO_EXTEND:
20613 EVT XVT = Op0.getOperand(0).getValueType();
20614 if (XVT.getScalarSizeInBits() != 8 && XVT.getScalarSizeInBits() != 16 &&
20615 XVT.getScalarSizeInBits() != 32)
20616 return false;
20617 return areLoadedOffsetButOtherwiseSame(Op0.getOperand(0), Op1.getOperand(0),
20618 DAG, NumSubLoads);
20619 }
20620 return false;
20621 }
20622
20623 // This method attempts to fold trees of add(ext(load p), shl(ext(load p+4)))
20624 // into a single load of twice the size, from which we extract the bottom and
20625 // top parts so that the shl can use a shll2 instruction. The two loads in that
20626 // example can also be larger trees of instructions, which are identical except
20627 // for the leaves which are all loads offset from the LHS, including
20628 // buildvectors of multiple loads. For example the RHS tree could be
20629 // sub(zext(buildvec(load p+4, load q+4)), zext(buildvec(load r+4, load s+4)))
20630 // Whilst it can be common for the larger loads to replace LDP instructions
20631 // (which doesn't gain anything on its own), the larger loads can help create
20632 // more efficient code, and in buildvectors prevent the need for ld1 lane
20633 // inserts which can be slower than normal loads.
20634 static SDValue performExtBinopLoadFold(SDNode *N, SelectionDAG &DAG) {
20635 EVT VT = N->getValueType(0);
20636 if (!VT.isFixedLengthVector() ||
20637 (VT.getScalarSizeInBits() != 16 && VT.getScalarSizeInBits() != 32 &&
20638 VT.getScalarSizeInBits() != 64))
20639 return SDValue();
20640
20641 SDValue Other = N->getOperand(0);
20642 SDValue Shift = N->getOperand(1);
20643 if (Shift.getOpcode() != ISD::SHL && N->getOpcode() != ISD::SUB)
20644 std::swap(Shift, Other);
20645 APInt ShiftAmt;
20646 if (Shift.getOpcode() != ISD::SHL || !Shift.hasOneUse() ||
20647 !ISD::isConstantSplatVector(Shift.getOperand(1).getNode(), ShiftAmt))
20648 return SDValue();
20649
20650 if (!ISD::isExtOpcode(Shift.getOperand(0).getOpcode()) ||
20651 !ISD::isExtOpcode(Other.getOpcode()) ||
20652 Shift.getOperand(0).getOperand(0).getValueType() !=
20653 Other.getOperand(0).getValueType() ||
20654 !Other.hasOneUse() || !Shift.getOperand(0).hasOneUse())
20655 return SDValue();
20656
20657 SDValue Op0 = Other.getOperand(0);
20658 SDValue Op1 = Shift.getOperand(0).getOperand(0);
20659
20660 unsigned NumSubLoads = 0;
20661 if (!areLoadedOffsetButOtherwiseSame(Op0, Op1, DAG, NumSubLoads))
20662 return SDValue();
20663
20664 // Attempt to rule out some unprofitable cases using heuristics (some working
20665 // around suboptimal code generation), notably if the extend would not be able
20666 // to use ushll2 instructions as the types are not large enough. Otherwise
20667 // zips will need to be created, which can increase the instruction count.
20668 unsigned NumElts = Op0.getValueType().getVectorNumElements();
20669 unsigned NumSubElts = NumElts / NumSubLoads;
20670 if (NumSubElts * VT.getScalarSizeInBits() < 128 ||
20671 (Other.getOpcode() != Shift.getOperand(0).getOpcode() &&
20672 Op0.getValueType().getSizeInBits() < 128 &&
20673 !DAG.getTargetLoweringInfo().isTypeLegal(Op0.getValueType())))
20674 return SDValue();
20675
20676 // Recreate the tree with the new combined loads.
20677 std::function<SDValue(SDValue, SDValue, SelectionDAG &)> GenCombinedTree =
20678 [&GenCombinedTree](SDValue Op0, SDValue Op1, SelectionDAG &DAG) {
20679 EVT DVT =
20680 Op0.getValueType().getDoubleNumVectorElementsVT(*DAG.getContext());
20681
20682 SmallVector<LoadSDNode *> Loads0, Loads1;
20683 if (isLoadOrMultipleLoads(Op0, Loads0) &&
20684 isLoadOrMultipleLoads(Op1, Loads1)) {
20685 EVT LoadVT = EVT::getVectorVT(
20686 *DAG.getContext(), Op0.getValueType().getScalarType(),
20687 Op0.getValueType().getVectorNumElements() / Loads0.size());
20688 EVT DLoadVT = LoadVT.getDoubleNumVectorElementsVT(*DAG.getContext());
20689
20690 SmallVector<SDValue> NewLoads;
20691 for (const auto &[L0, L1] : zip(Loads0, Loads1)) {
20692 SDValue Load = DAG.getLoad(DLoadVT, SDLoc(L0), L0->getChain(),
20693 L0->getBasePtr(), L0->getPointerInfo(),
20694 L0->getOriginalAlign());
20695 DAG.makeEquivalentMemoryOrdering(L0, Load.getValue(1));
20696 DAG.makeEquivalentMemoryOrdering(L1, Load.getValue(1));
20697 NewLoads.push_back(Load);
20698 }
20699 return DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(Op0), DVT, NewLoads);
20700 }
20701
20702 SmallVector<SDValue> Ops;
20703 for (const auto &[O0, O1] : zip(Op0->op_values(), Op1->op_values()))
20704 Ops.push_back(GenCombinedTree(O0, O1, DAG));
20705 return DAG.getNode(Op0.getOpcode(), SDLoc(Op0), DVT, Ops);
20706 };
20707 SDValue NewOp = GenCombinedTree(Op0, Op1, DAG);
20708
20709 SmallVector<int> LowMask(NumElts, 0), HighMask(NumElts, 0);
20710 int Hi = NumSubElts, Lo = 0;
20711 for (unsigned i = 0; i < NumSubLoads; i++) {
20712 for (unsigned j = 0; j < NumSubElts; j++) {
20713 LowMask[i * NumSubElts + j] = Lo++;
20714 HighMask[i * NumSubElts + j] = Hi++;
20715 }
20716 Lo += NumSubElts;
20717 Hi += NumSubElts;
20718 }
20719 SDLoc DL(N);
20720 SDValue Ext0, Ext1;
20721   // Extract the top and bottom lanes, then extend the result. Alternatively,
20722   // if the two extend opcodes match, extend the result first and then extract
20723   // the lanes, as that produces slightly smaller code.
20724 if (Other.getOpcode() != Shift.getOperand(0).getOpcode()) {
20725 SDValue SubL = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Op0.getValueType(),
20726 NewOp, DAG.getConstant(0, DL, MVT::i64));
20727 SDValue SubH =
20728 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, Op0.getValueType(), NewOp,
20729 DAG.getConstant(NumSubElts * NumSubLoads, DL, MVT::i64));
20730 SDValue Extr0 =
20731 DAG.getVectorShuffle(Op0.getValueType(), DL, SubL, SubH, LowMask);
20732 SDValue Extr1 =
20733 DAG.getVectorShuffle(Op0.getValueType(), DL, SubL, SubH, HighMask);
20734 Ext0 = DAG.getNode(Other.getOpcode(), DL, VT, Extr0);
20735 Ext1 = DAG.getNode(Shift.getOperand(0).getOpcode(), DL, VT, Extr1);
20736 } else {
20737 EVT DVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
20738 SDValue Ext = DAG.getNode(Other.getOpcode(), DL, DVT, NewOp);
20739 SDValue SubL = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Ext,
20740 DAG.getConstant(0, DL, MVT::i64));
20741 SDValue SubH =
20742 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, Ext,
20743 DAG.getConstant(NumSubElts * NumSubLoads, DL, MVT::i64));
20744 Ext0 = DAG.getVectorShuffle(VT, DL, SubL, SubH, LowMask);
20745 Ext1 = DAG.getVectorShuffle(VT, DL, SubL, SubH, HighMask);
20746 }
20747 SDValue NShift =
20748 DAG.getNode(Shift.getOpcode(), DL, VT, Ext1, Shift.getOperand(1));
20749 return DAG.getNode(N->getOpcode(), DL, VT, Ext0, NShift);
20750 }
20751
20752 static SDValue performAddSubCombine(SDNode *N,
20753 TargetLowering::DAGCombinerInfo &DCI) {
20754 // Try to change sum of two reductions.
20755 if (SDValue Val = performAddUADDVCombine(N, DCI.DAG))
20756 return Val;
20757 if (SDValue Val = performAddDotCombine(N, DCI.DAG))
20758 return Val;
20759 if (SDValue Val = performAddCSelIntoCSinc(N, DCI.DAG))
20760 return Val;
20761 if (SDValue Val = performNegCSelCombine(N, DCI.DAG))
20762 return Val;
20763 if (SDValue Val = performVectorExtCombine(N, DCI.DAG))
20764 return Val;
20765 if (SDValue Val = performAddCombineForShiftedOperands(N, DCI.DAG))
20766 return Val;
20767 if (SDValue Val = performSubAddMULCombine(N, DCI.DAG))
20768 return Val;
20769 if (SDValue Val = performSVEMulAddSubCombine(N, DCI))
20770 return Val;
20771 if (SDValue Val = performAddSubIntoVectorOp(N, DCI.DAG))
20772 return Val;
20773
20774 if (SDValue Val = performExtBinopLoadFold(N, DCI.DAG))
20775 return Val;
20776
20777 return performAddSubLongCombine(N, DCI);
20778 }
20779
20780 // Massage DAGs which we can use the high-half "long" operations on into
20781 // something isel will recognize better. E.g.
20782 //
20783 // (aarch64_neon_umull (extract_high vec) (dupv64 scalar)) -->
20784 //   (aarch64_neon_umull (extract_high (v2i64 vec))
20785 //                       (extract_high (v2i64 (dup128 scalar))))
20786 //
20787 static SDValue tryCombineLongOpWithDup(unsigned IID, SDNode *N,
20788 TargetLowering::DAGCombinerInfo &DCI,
20789 SelectionDAG &DAG) {
20790 if (DCI.isBeforeLegalizeOps())
20791 return SDValue();
20792
20793 SDValue LHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 0 : 1);
20794 SDValue RHS = N->getOperand((IID == Intrinsic::not_intrinsic) ? 1 : 2);
20795 assert(LHS.getValueType().is64BitVector() &&
20796 RHS.getValueType().is64BitVector() &&
20797 "unexpected shape for long operation");
20798
20799 // Either node could be a DUP, but it's not worth doing both of them (you'd
20800 // just as well use the non-high version) so look for a corresponding extract
20801 // operation on the other "wing".
20802 if (isEssentiallyExtractHighSubvector(LHS)) {
20803 RHS = tryExtendDUPToExtractHigh(RHS, DAG);
20804 if (!RHS.getNode())
20805 return SDValue();
20806 } else if (isEssentiallyExtractHighSubvector(RHS)) {
20807 LHS = tryExtendDUPToExtractHigh(LHS, DAG);
20808 if (!LHS.getNode())
20809 return SDValue();
20810 } else
20811 return SDValue();
20812
20813 if (IID == Intrinsic::not_intrinsic)
20814 return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), LHS, RHS);
20815
20816 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), N->getValueType(0),
20817 N->getOperand(0), LHS, RHS);
20818 }
20819
20820 static SDValue tryCombineShiftImm(unsigned IID, SDNode *N, SelectionDAG &DAG) {
20821 MVT ElemTy = N->getSimpleValueType(0).getScalarType();
20822 unsigned ElemBits = ElemTy.getSizeInBits();
20823
20824 int64_t ShiftAmount;
20825 if (BuildVectorSDNode *BVN = dyn_cast<BuildVectorSDNode>(N->getOperand(2))) {
20826 APInt SplatValue, SplatUndef;
20827 unsigned SplatBitSize;
20828 bool HasAnyUndefs;
20829 if (!BVN->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
20830 HasAnyUndefs, ElemBits) ||
20831 SplatBitSize != ElemBits)
20832 return SDValue();
20833
20834 ShiftAmount = SplatValue.getSExtValue();
20835 } else if (ConstantSDNode *CVN = dyn_cast<ConstantSDNode>(N->getOperand(2))) {
20836 ShiftAmount = CVN->getSExtValue();
20837 } else
20838 return SDValue();
20839
20840 // If the shift amount is zero, remove the shift intrinsic.
20841 if (ShiftAmount == 0 && IID != Intrinsic::aarch64_neon_sqshlu)
20842 return N->getOperand(1);
20843
20844 unsigned Opcode;
20845 bool IsRightShift;
20846 switch (IID) {
20847 default:
20848 llvm_unreachable("Unknown shift intrinsic");
20849 case Intrinsic::aarch64_neon_sqshl:
20850 Opcode = AArch64ISD::SQSHL_I;
20851 IsRightShift = false;
20852 break;
20853 case Intrinsic::aarch64_neon_uqshl:
20854 Opcode = AArch64ISD::UQSHL_I;
20855 IsRightShift = false;
20856 break;
20857 case Intrinsic::aarch64_neon_srshl:
20858 Opcode = AArch64ISD::SRSHR_I;
20859 IsRightShift = true;
20860 break;
20861 case Intrinsic::aarch64_neon_urshl:
20862 Opcode = AArch64ISD::URSHR_I;
20863 IsRightShift = true;
20864 break;
20865 case Intrinsic::aarch64_neon_sqshlu:
20866 Opcode = AArch64ISD::SQSHLU_I;
20867 IsRightShift = false;
20868 break;
20869 case Intrinsic::aarch64_neon_sshl:
20870 case Intrinsic::aarch64_neon_ushl:
20871     // For positive shift amounts we can use SHL, as ushl/sshl perform a regular
20872     // left shift for positive shift amounts. For negative shifts we can use a
20873     // VASHR/VLSHR as appropriate.
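    // For example (illustrative): ushl(x, splat(-3)) can be emitted as
    // ushr x, #3, while ushl(x, splat(3)) can be emitted as shl x, #3.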
20874 if (ShiftAmount < 0) {
20875 Opcode = IID == Intrinsic::aarch64_neon_sshl ? AArch64ISD::VASHR
20876 : AArch64ISD::VLSHR;
20877 ShiftAmount = -ShiftAmount;
20878 } else
20879 Opcode = AArch64ISD::VSHL;
20880 IsRightShift = false;
20881 break;
20882 }
20883
20884 EVT VT = N->getValueType(0);
20885 SDValue Op = N->getOperand(1);
20886 SDLoc dl(N);
20887 if (VT == MVT::i64) {
20888 Op = DAG.getNode(ISD::SCALAR_TO_VECTOR, dl, MVT::v1i64, Op);
20889 VT = MVT::v1i64;
20890 }
20891
20892 if (IsRightShift && ShiftAmount <= -1 && ShiftAmount >= -(int)ElemBits) {
20893 Op = DAG.getNode(Opcode, dl, VT, Op,
20894 DAG.getConstant(-ShiftAmount, dl, MVT::i32));
20895 if (N->getValueType(0) == MVT::i64)
20896 Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Op,
20897 DAG.getConstant(0, dl, MVT::i64));
20898 return Op;
20899 } else if (!IsRightShift && ShiftAmount >= 0 && ShiftAmount < ElemBits) {
20900 Op = DAG.getNode(Opcode, dl, VT, Op,
20901 DAG.getConstant(ShiftAmount, dl, MVT::i32));
20902 if (N->getValueType(0) == MVT::i64)
20903 Op = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, MVT::i64, Op,
20904 DAG.getConstant(0, dl, MVT::i64));
20905 return Op;
20906 }
20907
20908 return SDValue();
20909 }
20910
20911 // The CRC32[BH] instructions ignore the high bits of their data operand. Since
20912 // the intrinsics must be legal and take an i32, this means there's almost
20913 // certainly going to be a zext in the DAG which we can eliminate.
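//
// A sketch of the fold this performs (illustrative):
//   (int_aarch64_crc32b crc, (and data, 0xff)) --> (int_aarch64_crc32b crc, data)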
20914 static SDValue tryCombineCRC32(unsigned Mask, SDNode *N, SelectionDAG &DAG) {
20915 SDValue AndN = N->getOperand(2);
20916 if (AndN.getOpcode() != ISD::AND)
20917 return SDValue();
20918
20919 ConstantSDNode *CMask = dyn_cast<ConstantSDNode>(AndN.getOperand(1));
20920 if (!CMask || CMask->getZExtValue() != Mask)
20921 return SDValue();
20922
20923 return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, SDLoc(N), MVT::i32,
20924 N->getOperand(0), N->getOperand(1), AndN.getOperand(0));
20925 }
20926
20927 static SDValue combineAcrossLanesIntrinsic(unsigned Opc, SDNode *N,
20928 SelectionDAG &DAG) {
20929 SDLoc dl(N);
20930 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, N->getValueType(0),
20931 DAG.getNode(Opc, dl,
20932 N->getOperand(1).getSimpleValueType(),
20933 N->getOperand(1)),
20934 DAG.getConstant(0, dl, MVT::i64));
20935 }
20936
20937 static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
20938 SDLoc DL(N);
20939 SDValue Op1 = N->getOperand(1);
20940 SDValue Op2 = N->getOperand(2);
20941 EVT ScalarTy = Op2.getValueType();
20942 if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
20943 ScalarTy = MVT::i32;
20944
20945   // Lower index_vector(base, step) to mul(step, step_vector(1)) + splat(base).
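  // For example (illustrative), for an nxv4i32 index(2, 3):
  //   step_vector(1) = <0, 1, 2, 3, ...>
  //   result = <0, 1, 2, 3, ...> * splat(3) + splat(2) = <2, 5, 8, 11, ...>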
20946 SDValue StepVector = DAG.getStepVector(DL, N->getValueType(0));
20947 SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2);
20948 SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step);
20949 SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1);
20950 return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base);
20951 }
20952
20953 static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) {
20954 SDLoc dl(N);
20955 SDValue Scalar = N->getOperand(3);
20956 EVT ScalarTy = Scalar.getValueType();
20957
20958 if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
20959 Scalar = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i32, Scalar);
20960
20961 SDValue Passthru = N->getOperand(1);
20962 SDValue Pred = N->getOperand(2);
20963 return DAG.getNode(AArch64ISD::DUP_MERGE_PASSTHRU, dl, N->getValueType(0),
20964 Pred, Scalar, Passthru);
20965 }
20966
20967 static SDValue LowerSVEIntrinsicEXT(SDNode *N, SelectionDAG &DAG) {
20968 SDLoc dl(N);
20969 LLVMContext &Ctx = *DAG.getContext();
20970 EVT VT = N->getValueType(0);
20971
20972 assert(VT.isScalableVector() && "Expected a scalable vector.");
20973
20974 // Current lowering only supports the SVE-ACLE types.
20975 if (VT.getSizeInBits().getKnownMinValue() != AArch64::SVEBitsPerBlock)
20976 return SDValue();
20977
20978 unsigned ElemSize = VT.getVectorElementType().getSizeInBits() / 8;
20979 unsigned ByteSize = VT.getSizeInBits().getKnownMinValue() / 8;
20980 EVT ByteVT =
20981 EVT::getVectorVT(Ctx, MVT::i8, ElementCount::getScalable(ByteSize));
20982
20983   // Convert everything to the domain of EXT (i.e. bytes).
20984 SDValue Op0 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(1));
20985 SDValue Op1 = DAG.getNode(ISD::BITCAST, dl, ByteVT, N->getOperand(2));
20986 SDValue Op2 = DAG.getNode(ISD::MUL, dl, MVT::i32, N->getOperand(3),
20987 DAG.getConstant(ElemSize, dl, MVT::i32));
20988
20989 SDValue EXT = DAG.getNode(AArch64ISD::EXT, dl, ByteVT, Op0, Op1, Op2);
20990 return DAG.getNode(ISD::BITCAST, dl, VT, EXT);
20991 }
20992
20993 static SDValue tryConvertSVEWideCompare(SDNode *N, ISD::CondCode CC,
20994 TargetLowering::DAGCombinerInfo &DCI,
20995 SelectionDAG &DAG) {
20996 if (DCI.isBeforeLegalize())
20997 return SDValue();
20998
20999 SDValue Comparator = N->getOperand(3);
21000 if (Comparator.getOpcode() == AArch64ISD::DUP ||
21001 Comparator.getOpcode() == ISD::SPLAT_VECTOR) {
21002 unsigned IID = getIntrinsicID(N);
21003 EVT VT = N->getValueType(0);
21004 EVT CmpVT = N->getOperand(2).getValueType();
21005 SDValue Pred = N->getOperand(1);
21006 SDValue Imm;
21007 SDLoc DL(N);
21008
21009 switch (IID) {
21010 default:
21011 llvm_unreachable("Called with wrong intrinsic!");
21012 break;
21013
21014 // Signed comparisons
21015 case Intrinsic::aarch64_sve_cmpeq_wide:
21016 case Intrinsic::aarch64_sve_cmpne_wide:
21017 case Intrinsic::aarch64_sve_cmpge_wide:
21018 case Intrinsic::aarch64_sve_cmpgt_wide:
21019 case Intrinsic::aarch64_sve_cmplt_wide:
21020 case Intrinsic::aarch64_sve_cmple_wide: {
21021 if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
21022 int64_t ImmVal = CN->getSExtValue();
21023 if (ImmVal >= -16 && ImmVal <= 15)
21024 Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
21025 else
21026 return SDValue();
21027 }
21028 break;
21029 }
21030 // Unsigned comparisons
21031 case Intrinsic::aarch64_sve_cmphs_wide:
21032 case Intrinsic::aarch64_sve_cmphi_wide:
21033 case Intrinsic::aarch64_sve_cmplo_wide:
21034 case Intrinsic::aarch64_sve_cmpls_wide: {
21035 if (auto *CN = dyn_cast<ConstantSDNode>(Comparator.getOperand(0))) {
21036 uint64_t ImmVal = CN->getZExtValue();
21037 if (ImmVal <= 127)
21038 Imm = DAG.getConstant(ImmVal, DL, MVT::i32);
21039 else
21040 return SDValue();
21041 }
21042 break;
21043 }
21044 }
21045
21046 if (!Imm)
21047 return SDValue();
21048
21049 SDValue Splat = DAG.getNode(ISD::SPLAT_VECTOR, DL, CmpVT, Imm);
21050 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, VT, Pred,
21051 N->getOperand(2), Splat, DAG.getCondCode(CC));
21052 }
21053
21054 return SDValue();
21055 }
21056
21057 static SDValue getPTest(SelectionDAG &DAG, EVT VT, SDValue Pg, SDValue Op,
21058 AArch64CC::CondCode Cond) {
21059 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
21060
21061 SDLoc DL(Op);
21062 assert(Op.getValueType().isScalableVector() &&
21063 TLI.isTypeLegal(Op.getValueType()) &&
21064 "Expected legal scalable vector type!");
21065 assert(Op.getValueType() == Pg.getValueType() &&
21066 "Expected same type for PTEST operands");
21067
21068 // Ensure target specific opcodes are using legal type.
21069 EVT OutVT = TLI.getTypeToTransformTo(*DAG.getContext(), VT);
21070 SDValue TVal = DAG.getConstant(1, DL, OutVT);
21071 SDValue FVal = DAG.getConstant(0, DL, OutVT);
21072
21073 // Ensure operands have type nxv16i1.
21074 if (Op.getValueType() != MVT::nxv16i1) {
21075 if ((Cond == AArch64CC::ANY_ACTIVE || Cond == AArch64CC::NONE_ACTIVE) &&
21076 isZeroingInactiveLanes(Op))
21077 Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Pg);
21078 else
21079 Pg = getSVEPredicateBitCast(MVT::nxv16i1, Pg, DAG);
21080 Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv16i1, Op);
21081 }
21082
21083 // Set condition code (CC) flags.
21084 SDValue Test = DAG.getNode(
21085 Cond == AArch64CC::ANY_ACTIVE ? AArch64ISD::PTEST_ANY : AArch64ISD::PTEST,
21086 DL, MVT::Other, Pg, Op);
21087
21088 // Convert CC to integer based on requested condition.
21089 // NOTE: Cond is inverted to promote CSEL's removal when it feeds a compare.
21090 SDValue CC = DAG.getConstant(getInvertedCondCode(Cond), DL, MVT::i32);
21091 SDValue Res = DAG.getNode(AArch64ISD::CSEL, DL, OutVT, FVal, TVal, CC, Test);
21092 return DAG.getZExtOrTrunc(Res, DL, VT);
21093 }
21094
21095 static SDValue combineSVEReductionInt(SDNode *N, unsigned Opc,
21096 SelectionDAG &DAG) {
21097 SDLoc DL(N);
21098
21099 SDValue Pred = N->getOperand(1);
21100 SDValue VecToReduce = N->getOperand(2);
21101
21102 // NOTE: The integer reduction's result type is not always linked to the
21103 // operand's element type so we construct it from the intrinsic's result type.
21104 EVT ReduceVT = getPackedSVEVectorVT(N->getValueType(0));
21105 SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
21106
21107 // SVE reductions set the whole vector register with the first element
21108 // containing the reduction result, which we'll now extract.
21109 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21110 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21111 Zero);
21112 }
21113
21114 static SDValue combineSVEReductionFP(SDNode *N, unsigned Opc,
21115 SelectionDAG &DAG) {
21116 SDLoc DL(N);
21117
21118 SDValue Pred = N->getOperand(1);
21119 SDValue VecToReduce = N->getOperand(2);
21120
21121 EVT ReduceVT = VecToReduce.getValueType();
21122 SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, VecToReduce);
21123
21124 // SVE reductions set the whole vector register with the first element
21125 // containing the reduction result, which we'll now extract.
21126 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21127 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21128 Zero);
21129 }
21130
21131 static SDValue combineSVEReductionOrderedFP(SDNode *N, unsigned Opc,
21132 SelectionDAG &DAG) {
21133 SDLoc DL(N);
21134
21135 SDValue Pred = N->getOperand(1);
21136 SDValue InitVal = N->getOperand(2);
21137 SDValue VecToReduce = N->getOperand(3);
21138 EVT ReduceVT = VecToReduce.getValueType();
21139
21140 // Ordered reductions use the first lane of the result vector as the
21141 // reduction's initial value.
21142 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
21143 InitVal = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ReduceVT,
21144 DAG.getUNDEF(ReduceVT), InitVal, Zero);
21145
21146 SDValue Reduce = DAG.getNode(Opc, DL, ReduceVT, Pred, InitVal, VecToReduce);
21147
21148 // SVE reductions set the whole vector register with the first element
21149 // containing the reduction result, which we'll now extract.
21150 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, N->getValueType(0), Reduce,
21151 Zero);
21152 }
21153
21154 // If a merged operation has no inactive lanes we can relax it to a predicated
21155 // or unpredicated operation, which potentially allows better isel (perhaps
21156 // using immediate forms) or relaxes register reuse requirements.
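//
// A sketch of the relaxation (illustrative; it matches the sqadd case handled
// later in this file):
//   llvm.aarch64.sve.sqadd(ptrue, x, y) --> (ISD::SADDSAT x, y)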
21157 static SDValue convertMergedOpToPredOp(SDNode *N, unsigned Opc,
21158 SelectionDAG &DAG, bool UnpredOp = false,
21159 bool SwapOperands = false) {
21160 assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN && "Expected intrinsic!");
21161 assert(N->getNumOperands() == 4 && "Expected 3 operand intrinsic!");
21162 SDValue Pg = N->getOperand(1);
21163 SDValue Op1 = N->getOperand(SwapOperands ? 3 : 2);
21164 SDValue Op2 = N->getOperand(SwapOperands ? 2 : 3);
21165
21166 // ISD way to specify an all active predicate.
21167 if (isAllActivePredicate(DAG, Pg)) {
21168 if (UnpredOp)
21169 return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Op1, Op2);
21170
21171 return DAG.getNode(Opc, SDLoc(N), N->getValueType(0), Pg, Op1, Op2);
21172 }
21173
21174 // FUTURE: SplatVector(true)
21175 return SDValue();
21176 }
21177
21178 static SDValue tryCombineWhileLo(SDNode *N,
21179 TargetLowering::DAGCombinerInfo &DCI,
21180 const AArch64Subtarget *Subtarget) {
21181 if (DCI.isBeforeLegalize())
21182 return SDValue();
21183
21184 if (!Subtarget->hasSVE2p1())
21185 return SDValue();
21186
21187 if (!N->hasNUsesOfValue(2, 0))
21188 return SDValue();
21189
21190 const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
21191 if (HalfSize < 2)
21192 return SDValue();
21193
21194 auto It = N->use_begin();
21195 SDNode *Lo = *It++;
21196 SDNode *Hi = *It;
21197
21198 if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
21199 Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
21200 return SDValue();
21201
21202 uint64_t OffLo = Lo->getConstantOperandVal(1);
21203 uint64_t OffHi = Hi->getConstantOperandVal(1);
21204
21205 if (OffLo > OffHi) {
21206 std::swap(Lo, Hi);
21207 std::swap(OffLo, OffHi);
21208 }
21209
21210 if (OffLo != 0 || OffHi != HalfSize)
21211 return SDValue();
21212
21213 EVT HalfVec = Lo->getValueType(0);
21214 if (HalfVec != Hi->getValueType(0) ||
21215 HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
21216 return SDValue();
21217
21218 SelectionDAG &DAG = DCI.DAG;
21219 SDLoc DL(N);
21220 SDValue ID =
21221 DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
21222 SDValue Idx = N->getOperand(1);
21223 SDValue TC = N->getOperand(2);
21224 if (Idx.getValueType() != MVT::i64) {
21225 Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
21226 TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
21227 }
21228 auto R =
21229 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
21230 {Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
21231
21232 DCI.CombineTo(Lo, R.getValue(0));
21233 DCI.CombineTo(Hi, R.getValue(1));
21234
21235 return SDValue(N, 0);
21236 }
21237
21238 static SDValue performIntrinsicCombine(SDNode *N,
21239 TargetLowering::DAGCombinerInfo &DCI,
21240 const AArch64Subtarget *Subtarget) {
21241 SelectionDAG &DAG = DCI.DAG;
21242 unsigned IID = getIntrinsicID(N);
21243 switch (IID) {
21244 default:
21245 break;
21246 case Intrinsic::aarch64_neon_vcvtfxs2fp:
21247 case Intrinsic::aarch64_neon_vcvtfxu2fp:
21248 return tryCombineFixedPointConvert(N, DCI, DAG);
21249 case Intrinsic::aarch64_neon_saddv:
21250 return combineAcrossLanesIntrinsic(AArch64ISD::SADDV, N, DAG);
21251 case Intrinsic::aarch64_neon_uaddv:
21252 return combineAcrossLanesIntrinsic(AArch64ISD::UADDV, N, DAG);
21253 case Intrinsic::aarch64_neon_sminv:
21254 return combineAcrossLanesIntrinsic(AArch64ISD::SMINV, N, DAG);
21255 case Intrinsic::aarch64_neon_uminv:
21256 return combineAcrossLanesIntrinsic(AArch64ISD::UMINV, N, DAG);
21257 case Intrinsic::aarch64_neon_smaxv:
21258 return combineAcrossLanesIntrinsic(AArch64ISD::SMAXV, N, DAG);
21259 case Intrinsic::aarch64_neon_umaxv:
21260 return combineAcrossLanesIntrinsic(AArch64ISD::UMAXV, N, DAG);
21261 case Intrinsic::aarch64_neon_fmax:
21262 return DAG.getNode(ISD::FMAXIMUM, SDLoc(N), N->getValueType(0),
21263 N->getOperand(1), N->getOperand(2));
21264 case Intrinsic::aarch64_neon_fmin:
21265 return DAG.getNode(ISD::FMINIMUM, SDLoc(N), N->getValueType(0),
21266 N->getOperand(1), N->getOperand(2));
21267 case Intrinsic::aarch64_neon_fmaxnm:
21268 return DAG.getNode(ISD::FMAXNUM, SDLoc(N), N->getValueType(0),
21269 N->getOperand(1), N->getOperand(2));
21270 case Intrinsic::aarch64_neon_fminnm:
21271 return DAG.getNode(ISD::FMINNUM, SDLoc(N), N->getValueType(0),
21272 N->getOperand(1), N->getOperand(2));
21273 case Intrinsic::aarch64_neon_smull:
21274 return DAG.getNode(AArch64ISD::SMULL, SDLoc(N), N->getValueType(0),
21275 N->getOperand(1), N->getOperand(2));
21276 case Intrinsic::aarch64_neon_umull:
21277 return DAG.getNode(AArch64ISD::UMULL, SDLoc(N), N->getValueType(0),
21278 N->getOperand(1), N->getOperand(2));
21279 case Intrinsic::aarch64_neon_pmull:
21280 return DAG.getNode(AArch64ISD::PMULL, SDLoc(N), N->getValueType(0),
21281 N->getOperand(1), N->getOperand(2));
21282 case Intrinsic::aarch64_neon_sqdmull:
21283 return tryCombineLongOpWithDup(IID, N, DCI, DAG);
21284 case Intrinsic::aarch64_neon_sqshl:
21285 case Intrinsic::aarch64_neon_uqshl:
21286 case Intrinsic::aarch64_neon_sqshlu:
21287 case Intrinsic::aarch64_neon_srshl:
21288 case Intrinsic::aarch64_neon_urshl:
21289 case Intrinsic::aarch64_neon_sshl:
21290 case Intrinsic::aarch64_neon_ushl:
21291 return tryCombineShiftImm(IID, N, DAG);
21292 case Intrinsic::aarch64_neon_sabd:
21293 return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
21294 N->getOperand(1), N->getOperand(2));
21295 case Intrinsic::aarch64_neon_uabd:
21296 return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
21297 N->getOperand(1), N->getOperand(2));
21298 case Intrinsic::aarch64_crc32b:
21299 case Intrinsic::aarch64_crc32cb:
21300 return tryCombineCRC32(0xff, N, DAG);
21301 case Intrinsic::aarch64_crc32h:
21302 case Intrinsic::aarch64_crc32ch:
21303 return tryCombineCRC32(0xffff, N, DAG);
21304 case Intrinsic::aarch64_sve_saddv:
21305 // There is no i64 version of SADDV because the sign is irrelevant.
21306 if (N->getOperand(2)->getValueType(0).getVectorElementType() == MVT::i64)
21307 return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
21308 else
21309 return combineSVEReductionInt(N, AArch64ISD::SADDV_PRED, DAG);
21310 case Intrinsic::aarch64_sve_uaddv:
21311 return combineSVEReductionInt(N, AArch64ISD::UADDV_PRED, DAG);
21312 case Intrinsic::aarch64_sve_smaxv:
21313 return combineSVEReductionInt(N, AArch64ISD::SMAXV_PRED, DAG);
21314 case Intrinsic::aarch64_sve_umaxv:
21315 return combineSVEReductionInt(N, AArch64ISD::UMAXV_PRED, DAG);
21316 case Intrinsic::aarch64_sve_sminv:
21317 return combineSVEReductionInt(N, AArch64ISD::SMINV_PRED, DAG);
21318 case Intrinsic::aarch64_sve_uminv:
21319 return combineSVEReductionInt(N, AArch64ISD::UMINV_PRED, DAG);
21320 case Intrinsic::aarch64_sve_orv:
21321 return combineSVEReductionInt(N, AArch64ISD::ORV_PRED, DAG);
21322 case Intrinsic::aarch64_sve_eorv:
21323 return combineSVEReductionInt(N, AArch64ISD::EORV_PRED, DAG);
21324 case Intrinsic::aarch64_sve_andv:
21325 return combineSVEReductionInt(N, AArch64ISD::ANDV_PRED, DAG);
21326 case Intrinsic::aarch64_sve_index:
21327 return LowerSVEIntrinsicIndex(N, DAG);
21328 case Intrinsic::aarch64_sve_dup:
21329 return LowerSVEIntrinsicDUP(N, DAG);
21330 case Intrinsic::aarch64_sve_dup_x:
21331 return DAG.getNode(ISD::SPLAT_VECTOR, SDLoc(N), N->getValueType(0),
21332 N->getOperand(1));
21333 case Intrinsic::aarch64_sve_ext:
21334 return LowerSVEIntrinsicEXT(N, DAG);
21335 case Intrinsic::aarch64_sve_mul_u:
21336 return DAG.getNode(AArch64ISD::MUL_PRED, SDLoc(N), N->getValueType(0),
21337 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21338 case Intrinsic::aarch64_sve_smulh_u:
21339 return DAG.getNode(AArch64ISD::MULHS_PRED, SDLoc(N), N->getValueType(0),
21340 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21341 case Intrinsic::aarch64_sve_umulh_u:
21342 return DAG.getNode(AArch64ISD::MULHU_PRED, SDLoc(N), N->getValueType(0),
21343 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21344 case Intrinsic::aarch64_sve_smin_u:
21345 return DAG.getNode(AArch64ISD::SMIN_PRED, SDLoc(N), N->getValueType(0),
21346 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21347 case Intrinsic::aarch64_sve_umin_u:
21348 return DAG.getNode(AArch64ISD::UMIN_PRED, SDLoc(N), N->getValueType(0),
21349 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21350 case Intrinsic::aarch64_sve_smax_u:
21351 return DAG.getNode(AArch64ISD::SMAX_PRED, SDLoc(N), N->getValueType(0),
21352 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21353 case Intrinsic::aarch64_sve_umax_u:
21354 return DAG.getNode(AArch64ISD::UMAX_PRED, SDLoc(N), N->getValueType(0),
21355 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21356 case Intrinsic::aarch64_sve_lsl_u:
21357 return DAG.getNode(AArch64ISD::SHL_PRED, SDLoc(N), N->getValueType(0),
21358 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21359 case Intrinsic::aarch64_sve_lsr_u:
21360 return DAG.getNode(AArch64ISD::SRL_PRED, SDLoc(N), N->getValueType(0),
21361 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21362 case Intrinsic::aarch64_sve_asr_u:
21363 return DAG.getNode(AArch64ISD::SRA_PRED, SDLoc(N), N->getValueType(0),
21364 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21365 case Intrinsic::aarch64_sve_fadd_u:
21366 return DAG.getNode(AArch64ISD::FADD_PRED, SDLoc(N), N->getValueType(0),
21367 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21368 case Intrinsic::aarch64_sve_fdiv_u:
21369 return DAG.getNode(AArch64ISD::FDIV_PRED, SDLoc(N), N->getValueType(0),
21370 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21371 case Intrinsic::aarch64_sve_fmax_u:
21372 return DAG.getNode(AArch64ISD::FMAX_PRED, SDLoc(N), N->getValueType(0),
21373 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21374 case Intrinsic::aarch64_sve_fmaxnm_u:
21375 return DAG.getNode(AArch64ISD::FMAXNM_PRED, SDLoc(N), N->getValueType(0),
21376 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21377 case Intrinsic::aarch64_sve_fmla_u:
21378 return DAG.getNode(AArch64ISD::FMA_PRED, SDLoc(N), N->getValueType(0),
21379 N->getOperand(1), N->getOperand(3), N->getOperand(4),
21380 N->getOperand(2));
21381 case Intrinsic::aarch64_sve_fmin_u:
21382 return DAG.getNode(AArch64ISD::FMIN_PRED, SDLoc(N), N->getValueType(0),
21383 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21384 case Intrinsic::aarch64_sve_fminnm_u:
21385 return DAG.getNode(AArch64ISD::FMINNM_PRED, SDLoc(N), N->getValueType(0),
21386 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21387 case Intrinsic::aarch64_sve_fmul_u:
21388 return DAG.getNode(AArch64ISD::FMUL_PRED, SDLoc(N), N->getValueType(0),
21389 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21390 case Intrinsic::aarch64_sve_fsub_u:
21391 return DAG.getNode(AArch64ISD::FSUB_PRED, SDLoc(N), N->getValueType(0),
21392 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21393 case Intrinsic::aarch64_sve_add_u:
21394 return DAG.getNode(ISD::ADD, SDLoc(N), N->getValueType(0), N->getOperand(2),
21395 N->getOperand(3));
21396 case Intrinsic::aarch64_sve_sub_u:
21397 return DAG.getNode(ISD::SUB, SDLoc(N), N->getValueType(0), N->getOperand(2),
21398 N->getOperand(3));
21399 case Intrinsic::aarch64_sve_subr:
21400 return convertMergedOpToPredOp(N, ISD::SUB, DAG, true, true);
21401 case Intrinsic::aarch64_sve_and_u:
21402 return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0), N->getOperand(2),
21403 N->getOperand(3));
21404 case Intrinsic::aarch64_sve_bic_u:
21405 return DAG.getNode(AArch64ISD::BIC, SDLoc(N), N->getValueType(0),
21406 N->getOperand(2), N->getOperand(3));
21407 case Intrinsic::aarch64_sve_eor_u:
21408 return DAG.getNode(ISD::XOR, SDLoc(N), N->getValueType(0), N->getOperand(2),
21409 N->getOperand(3));
21410 case Intrinsic::aarch64_sve_orr_u:
21411 return DAG.getNode(ISD::OR, SDLoc(N), N->getValueType(0), N->getOperand(2),
21412 N->getOperand(3));
21413 case Intrinsic::aarch64_sve_sabd_u:
21414 return DAG.getNode(ISD::ABDS, SDLoc(N), N->getValueType(0),
21415 N->getOperand(2), N->getOperand(3));
21416 case Intrinsic::aarch64_sve_uabd_u:
21417 return DAG.getNode(ISD::ABDU, SDLoc(N), N->getValueType(0),
21418 N->getOperand(2), N->getOperand(3));
21419 case Intrinsic::aarch64_sve_sdiv_u:
21420 return DAG.getNode(AArch64ISD::SDIV_PRED, SDLoc(N), N->getValueType(0),
21421 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21422 case Intrinsic::aarch64_sve_udiv_u:
21423 return DAG.getNode(AArch64ISD::UDIV_PRED, SDLoc(N), N->getValueType(0),
21424 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21425 case Intrinsic::aarch64_sve_sqadd:
21426 return convertMergedOpToPredOp(N, ISD::SADDSAT, DAG, true);
21427 case Intrinsic::aarch64_sve_sqsub_u:
21428 return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
21429 N->getOperand(2), N->getOperand(3));
21430 case Intrinsic::aarch64_sve_uqadd:
21431 return convertMergedOpToPredOp(N, ISD::UADDSAT, DAG, true);
21432 case Intrinsic::aarch64_sve_uqsub_u:
21433 return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
21434 N->getOperand(2), N->getOperand(3));
21435 case Intrinsic::aarch64_sve_sqadd_x:
21436 return DAG.getNode(ISD::SADDSAT, SDLoc(N), N->getValueType(0),
21437 N->getOperand(1), N->getOperand(2));
21438 case Intrinsic::aarch64_sve_sqsub_x:
21439 return DAG.getNode(ISD::SSUBSAT, SDLoc(N), N->getValueType(0),
21440 N->getOperand(1), N->getOperand(2));
21441 case Intrinsic::aarch64_sve_uqadd_x:
21442 return DAG.getNode(ISD::UADDSAT, SDLoc(N), N->getValueType(0),
21443 N->getOperand(1), N->getOperand(2));
21444 case Intrinsic::aarch64_sve_uqsub_x:
21445 return DAG.getNode(ISD::USUBSAT, SDLoc(N), N->getValueType(0),
21446 N->getOperand(1), N->getOperand(2));
21447 case Intrinsic::aarch64_sve_asrd:
21448 return DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, SDLoc(N), N->getValueType(0),
21449 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21450 case Intrinsic::aarch64_sve_cmphs:
21451 if (!N->getOperand(2).getValueType().isFloatingPoint())
21452 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21453 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21454 N->getOperand(3), DAG.getCondCode(ISD::SETUGE));
21455 break;
21456 case Intrinsic::aarch64_sve_cmphi:
21457 if (!N->getOperand(2).getValueType().isFloatingPoint())
21458 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21459 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21460 N->getOperand(3), DAG.getCondCode(ISD::SETUGT));
21461 break;
21462 case Intrinsic::aarch64_sve_fcmpge:
21463 case Intrinsic::aarch64_sve_cmpge:
21464 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21465 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21466 N->getOperand(3), DAG.getCondCode(ISD::SETGE));
21467 break;
21468 case Intrinsic::aarch64_sve_fcmpgt:
21469 case Intrinsic::aarch64_sve_cmpgt:
21470 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21471 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21472 N->getOperand(3), DAG.getCondCode(ISD::SETGT));
21473 break;
21474 case Intrinsic::aarch64_sve_fcmpeq:
21475 case Intrinsic::aarch64_sve_cmpeq:
21476 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21477 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21478 N->getOperand(3), DAG.getCondCode(ISD::SETEQ));
21479 break;
21480 case Intrinsic::aarch64_sve_fcmpne:
21481 case Intrinsic::aarch64_sve_cmpne:
21482 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21483 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21484 N->getOperand(3), DAG.getCondCode(ISD::SETNE));
21485 break;
21486 case Intrinsic::aarch64_sve_fcmpuo:
21487 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, SDLoc(N),
21488 N->getValueType(0), N->getOperand(1), N->getOperand(2),
21489 N->getOperand(3), DAG.getCondCode(ISD::SETUO));
21490 break;
21491 case Intrinsic::aarch64_sve_fadda:
21492 return combineSVEReductionOrderedFP(N, AArch64ISD::FADDA_PRED, DAG);
21493 case Intrinsic::aarch64_sve_faddv:
21494 return combineSVEReductionFP(N, AArch64ISD::FADDV_PRED, DAG);
21495 case Intrinsic::aarch64_sve_fmaxnmv:
21496 return combineSVEReductionFP(N, AArch64ISD::FMAXNMV_PRED, DAG);
21497 case Intrinsic::aarch64_sve_fmaxv:
21498 return combineSVEReductionFP(N, AArch64ISD::FMAXV_PRED, DAG);
21499 case Intrinsic::aarch64_sve_fminnmv:
21500 return combineSVEReductionFP(N, AArch64ISD::FMINNMV_PRED, DAG);
21501 case Intrinsic::aarch64_sve_fminv:
21502 return combineSVEReductionFP(N, AArch64ISD::FMINV_PRED, DAG);
21503 case Intrinsic::aarch64_sve_sel:
21504 return DAG.getNode(ISD::VSELECT, SDLoc(N), N->getValueType(0),
21505 N->getOperand(1), N->getOperand(2), N->getOperand(3));
21506 case Intrinsic::aarch64_sve_cmpeq_wide:
21507 return tryConvertSVEWideCompare(N, ISD::SETEQ, DCI, DAG);
21508 case Intrinsic::aarch64_sve_cmpne_wide:
21509 return tryConvertSVEWideCompare(N, ISD::SETNE, DCI, DAG);
21510 case Intrinsic::aarch64_sve_cmpge_wide:
21511 return tryConvertSVEWideCompare(N, ISD::SETGE, DCI, DAG);
21512 case Intrinsic::aarch64_sve_cmpgt_wide:
21513 return tryConvertSVEWideCompare(N, ISD::SETGT, DCI, DAG);
21514 case Intrinsic::aarch64_sve_cmplt_wide:
21515 return tryConvertSVEWideCompare(N, ISD::SETLT, DCI, DAG);
21516 case Intrinsic::aarch64_sve_cmple_wide:
21517 return tryConvertSVEWideCompare(N, ISD::SETLE, DCI, DAG);
21518 case Intrinsic::aarch64_sve_cmphs_wide:
21519 return tryConvertSVEWideCompare(N, ISD::SETUGE, DCI, DAG);
21520 case Intrinsic::aarch64_sve_cmphi_wide:
21521 return tryConvertSVEWideCompare(N, ISD::SETUGT, DCI, DAG);
21522 case Intrinsic::aarch64_sve_cmplo_wide:
21523 return tryConvertSVEWideCompare(N, ISD::SETULT, DCI, DAG);
21524 case Intrinsic::aarch64_sve_cmpls_wide:
21525 return tryConvertSVEWideCompare(N, ISD::SETULE, DCI, DAG);
21526 case Intrinsic::aarch64_sve_ptest_any:
21527 return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21528 AArch64CC::ANY_ACTIVE);
21529 case Intrinsic::aarch64_sve_ptest_first:
21530 return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21531 AArch64CC::FIRST_ACTIVE);
21532 case Intrinsic::aarch64_sve_ptest_last:
21533 return getPTest(DAG, N->getValueType(0), N->getOperand(1), N->getOperand(2),
21534 AArch64CC::LAST_ACTIVE);
21535 case Intrinsic::aarch64_sve_whilelo:
21536 return tryCombineWhileLo(N, DCI, Subtarget);
21537 }
21538 return SDValue();
21539 }
21540
21541 static bool isCheapToExtend(const SDValue &N) {
21542 unsigned OC = N->getOpcode();
21543 return OC == ISD::LOAD || OC == ISD::MLOAD ||
21544 ISD::isConstantSplatVectorAllZeros(N.getNode());
21545 }
21546
21547 static SDValue
21548 performSignExtendSetCCCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
21549 SelectionDAG &DAG) {
21550 // If we have (sext (setcc A B)) and A and B are cheap to extend,
21551 // we can move the sext into the arguments and have the same result. For
21552 // example, if A and B are both loads, we can make those extending loads and
21553 // avoid an extra instruction. This pattern appears often in VLS code
21554 // generation where the inputs to the setcc have a different size to the
21555 // instruction that wants to use the result of the setcc.
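  //
  // For example (illustrative; the vector types are assumptions):
  //   (v4i32 sext (setcc slt (v4i16 load [p]), (v4i16 load [q])))
  //     --> (setcc slt (v4i32 sextload [p]), (v4i32 sextload [q]))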
21556 assert(N->getOpcode() == ISD::SIGN_EXTEND &&
21557 N->getOperand(0)->getOpcode() == ISD::SETCC);
21558 const SDValue SetCC = N->getOperand(0);
21559
21560 const SDValue CCOp0 = SetCC.getOperand(0);
21561 const SDValue CCOp1 = SetCC.getOperand(1);
21562 if (!CCOp0->getValueType(0).isInteger() ||
21563 !CCOp1->getValueType(0).isInteger())
21564 return SDValue();
21565
21566 ISD::CondCode Code =
21567 cast<CondCodeSDNode>(SetCC->getOperand(2).getNode())->get();
21568
21569 ISD::NodeType ExtType =
21570 isSignedIntSetCC(Code) ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
21571
21572 if (isCheapToExtend(SetCC.getOperand(0)) &&
21573 isCheapToExtend(SetCC.getOperand(1))) {
21574 const SDValue Ext1 =
21575 DAG.getNode(ExtType, SDLoc(N), N->getValueType(0), CCOp0);
21576 const SDValue Ext2 =
21577 DAG.getNode(ExtType, SDLoc(N), N->getValueType(0), CCOp1);
21578
21579 return DAG.getSetCC(
21580 SDLoc(SetCC), N->getValueType(0), Ext1, Ext2,
21581 cast<CondCodeSDNode>(SetCC->getOperand(2).getNode())->get());
21582 }
21583
21584 return SDValue();
21585 }
21586
21587 static SDValue performExtendCombine(SDNode *N,
21588 TargetLowering::DAGCombinerInfo &DCI,
21589 SelectionDAG &DAG) {
21590 // If we see something like (zext (sabd (extract_high ...), (DUP ...))) then
21591 // we can convert that DUP into another extract_high (of a bigger DUP), which
21592 // helps the backend to decide that an sabdl2 would be useful, saving a real
21593 // extract_high operation.
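  //
  // For example (illustrative):
  //   zext(uabd(extract_high(v16i8 x), dup(w0)))
  //     --> zext(uabd(extract_high(x), extract_high(dup128(w0))))
  // which isel can then match as uabdl2.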
21594 if (!DCI.isBeforeLegalizeOps() && N->getOpcode() == ISD::ZERO_EXTEND &&
21595 (N->getOperand(0).getOpcode() == ISD::ABDU ||
21596 N->getOperand(0).getOpcode() == ISD::ABDS)) {
21597 SDNode *ABDNode = N->getOperand(0).getNode();
21598 SDValue NewABD =
21599 tryCombineLongOpWithDup(Intrinsic::not_intrinsic, ABDNode, DCI, DAG);
21600 if (!NewABD.getNode())
21601 return SDValue();
21602
21603 return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N), N->getValueType(0), NewABD);
21604 }
21605
21606 if (N->getValueType(0).isFixedLengthVector() &&
21607 N->getOpcode() == ISD::SIGN_EXTEND &&
21608 N->getOperand(0)->getOpcode() == ISD::SETCC)
21609 return performSignExtendSetCCCombine(N, DCI, DAG);
21610
21611 return SDValue();
21612 }
21613
21614 static SDValue splitStoreSplat(SelectionDAG &DAG, StoreSDNode &St,
21615 SDValue SplatVal, unsigned NumVecElts) {
21616 assert(!St.isTruncatingStore() && "cannot split truncating vector store");
21617 Align OrigAlignment = St.getAlign();
21618 unsigned EltOffset = SplatVal.getValueType().getSizeInBits() / 8;
21619
21620 // Create scalar stores. This is at least as good as the code sequence for a
21621 // split unaligned store which is a dup.s, ext.b, and two stores.
21622 // Most of the time the three stores should be replaced by store pair
21623 // instructions (stp).
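  //
  // For example (illustrative), splatting w1 into a store of 4 x i32 at [x0]
  // should ideally end up as:
  //   stp w1, w1, [x0]
  //   stp w1, w1, [x0, #8]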
21624 SDLoc DL(&St);
21625 SDValue BasePtr = St.getBasePtr();
21626 uint64_t BaseOffset = 0;
21627
21628 const MachinePointerInfo &PtrInfo = St.getPointerInfo();
21629 SDValue NewST1 =
21630 DAG.getStore(St.getChain(), DL, SplatVal, BasePtr, PtrInfo,
21631 OrigAlignment, St.getMemOperand()->getFlags());
21632
21633   // As this is in ISel, we will not merge this add, which may degrade results.
21634 if (BasePtr->getOpcode() == ISD::ADD &&
21635 isa<ConstantSDNode>(BasePtr->getOperand(1))) {
21636 BaseOffset = cast<ConstantSDNode>(BasePtr->getOperand(1))->getSExtValue();
21637 BasePtr = BasePtr->getOperand(0);
21638 }
21639
21640 unsigned Offset = EltOffset;
21641 while (--NumVecElts) {
21642 Align Alignment = commonAlignment(OrigAlignment, Offset);
21643 SDValue OffsetPtr =
21644 DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
21645 DAG.getConstant(BaseOffset + Offset, DL, MVT::i64));
21646 NewST1 = DAG.getStore(NewST1.getValue(0), DL, SplatVal, OffsetPtr,
21647 PtrInfo.getWithOffset(Offset), Alignment,
21648 St.getMemOperand()->getFlags());
21649 Offset += EltOffset;
21650 }
21651 return NewST1;
21652 }
21653
21654 // Returns an SVE type that ContentTy can be trivially sign or zero extended
21655 // into.
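//
// For example (illustrative): nxv2i16 and nxv2f64 both map to nxv2i64, while
// nxv4i8 maps to nxv4i32.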
21656 static MVT getSVEContainerType(EVT ContentTy) {
21657 assert(ContentTy.isSimple() && "No SVE containers for extended types");
21658
21659 switch (ContentTy.getSimpleVT().SimpleTy) {
21660 default:
21661 llvm_unreachable("No known SVE container for this MVT type");
21662 case MVT::nxv2i8:
21663 case MVT::nxv2i16:
21664 case MVT::nxv2i32:
21665 case MVT::nxv2i64:
21666 case MVT::nxv2f32:
21667 case MVT::nxv2f64:
21668 return MVT::nxv2i64;
21669 case MVT::nxv4i8:
21670 case MVT::nxv4i16:
21671 case MVT::nxv4i32:
21672 case MVT::nxv4f32:
21673 return MVT::nxv4i32;
21674 case MVT::nxv8i8:
21675 case MVT::nxv8i16:
21676 case MVT::nxv8f16:
21677 case MVT::nxv8bf16:
21678 return MVT::nxv8i16;
21679 case MVT::nxv16i8:
21680 return MVT::nxv16i8;
21681 }
21682 }
21683
21684 static SDValue performLD1Combine(SDNode *N, SelectionDAG &DAG, unsigned Opc) {
21685 SDLoc DL(N);
21686 EVT VT = N->getValueType(0);
21687
21688 if (VT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
21689 return SDValue();
21690
21691 EVT ContainerVT = VT;
21692 if (ContainerVT.isInteger())
21693 ContainerVT = getSVEContainerType(ContainerVT);
21694
21695 SDVTList VTs = DAG.getVTList(ContainerVT, MVT::Other);
21696 SDValue Ops[] = { N->getOperand(0), // Chain
21697 N->getOperand(2), // Pg
21698 N->getOperand(3), // Base
21699 DAG.getValueType(VT) };
21700
21701 SDValue Load = DAG.getNode(Opc, DL, VTs, Ops);
21702 SDValue LoadChain = SDValue(Load.getNode(), 1);
21703
21704 if (ContainerVT.isInteger() && (VT != ContainerVT))
21705 Load = DAG.getNode(ISD::TRUNCATE, DL, VT, Load.getValue(0));
21706
21707 return DAG.getMergeValues({ Load, LoadChain }, DL);
21708 }
21709
21710 static SDValue performLDNT1Combine(SDNode *N, SelectionDAG &DAG) {
21711 SDLoc DL(N);
21712 EVT VT = N->getValueType(0);
21713 EVT PtrTy = N->getOperand(3).getValueType();
21714
21715 EVT LoadVT = VT;
21716 if (VT.isFloatingPoint())
21717 LoadVT = VT.changeTypeToInteger();
21718
21719 auto *MINode = cast<MemIntrinsicSDNode>(N);
21720 SDValue PassThru = DAG.getConstant(0, DL, LoadVT);
21721 SDValue L = DAG.getMaskedLoad(LoadVT, DL, MINode->getChain(),
21722 MINode->getOperand(3), DAG.getUNDEF(PtrTy),
21723 MINode->getOperand(2), PassThru,
21724 MINode->getMemoryVT(), MINode->getMemOperand(),
21725 ISD::UNINDEXED, ISD::NON_EXTLOAD, false);
21726
21727 if (VT.isFloatingPoint()) {
21728 SDValue Ops[] = { DAG.getNode(ISD::BITCAST, DL, VT, L), L.getValue(1) };
21729 return DAG.getMergeValues(Ops, DL);
21730 }
21731
21732 return L;
21733 }
21734
21735 template <unsigned Opcode>
21736 static SDValue performLD1ReplicateCombine(SDNode *N, SelectionDAG &DAG) {
21737 static_assert(Opcode == AArch64ISD::LD1RQ_MERGE_ZERO ||
21738 Opcode == AArch64ISD::LD1RO_MERGE_ZERO,
21739 "Unsupported opcode.");
21740 SDLoc DL(N);
21741 EVT VT = N->getValueType(0);
21742
21743 EVT LoadVT = VT;
21744 if (VT.isFloatingPoint())
21745 LoadVT = VT.changeTypeToInteger();
21746
21747 SDValue Ops[] = {N->getOperand(0), N->getOperand(2), N->getOperand(3)};
21748 SDValue Load = DAG.getNode(Opcode, DL, {LoadVT, MVT::Other}, Ops);
21749 SDValue LoadChain = SDValue(Load.getNode(), 1);
21750
21751 if (VT.isFloatingPoint())
21752 Load = DAG.getNode(ISD::BITCAST, DL, VT, Load.getValue(0));
21753
21754 return DAG.getMergeValues({Load, LoadChain}, DL);
21755 }
21756
21757 static SDValue performST1Combine(SDNode *N, SelectionDAG &DAG) {
21758 SDLoc DL(N);
21759 SDValue Data = N->getOperand(2);
21760 EVT DataVT = Data.getValueType();
21761 EVT HwSrcVt = getSVEContainerType(DataVT);
21762 SDValue InputVT = DAG.getValueType(DataVT);
21763
21764 if (DataVT.isFloatingPoint())
21765 InputVT = DAG.getValueType(HwSrcVt);
21766
21767 SDValue SrcNew;
21768 if (Data.getValueType().isFloatingPoint())
21769 SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Data);
21770 else
21771 SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Data);
21772
21773 SDValue Ops[] = { N->getOperand(0), // Chain
21774 SrcNew,
21775 N->getOperand(4), // Base
21776 N->getOperand(3), // Pg
21777 InputVT
21778 };
21779
21780 return DAG.getNode(AArch64ISD::ST1_PRED, DL, N->getValueType(0), Ops);
21781 }
21782
21783 static SDValue performSTNT1Combine(SDNode *N, SelectionDAG &DAG) {
21784 SDLoc DL(N);
21785
21786 SDValue Data = N->getOperand(2);
21787 EVT DataVT = Data.getValueType();
21788 EVT PtrTy = N->getOperand(4).getValueType();
21789
21790 if (DataVT.isFloatingPoint())
21791 Data = DAG.getNode(ISD::BITCAST, DL, DataVT.changeTypeToInteger(), Data);
21792
21793 auto *MINode = cast<MemIntrinsicSDNode>(N);
21794 return DAG.getMaskedStore(MINode->getChain(), DL, Data, MINode->getOperand(4),
21795 DAG.getUNDEF(PtrTy), MINode->getOperand(3),
21796 MINode->getMemoryVT(), MINode->getMemOperand(),
21797 ISD::UNINDEXED, false, false);
21798 }
21799
21800 /// Replace a store of a zero splat vector with scalar stores of WZR/XZR. The
21801 /// load store optimizer pass will merge them to store pair stores. This should
21802 /// be better than a movi to create the vector zero followed by a vector store
21803 /// if the zero constant is not re-used, since one instruction and one register
21804 /// live range will be removed.
21805 ///
21806 /// For example, the final generated code should be:
21807 ///
21808 /// stp xzr, xzr, [x0]
21809 ///
21810 /// instead of:
21811 ///
21812 /// movi v0.2d, #0
21813 /// str q0, [x0]
21814 ///
21815 static SDValue replaceZeroVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
21816 SDValue StVal = St.getValue();
21817 EVT VT = StVal.getValueType();
21818
21819 // Avoid scalarizing zero splat stores for scalable vectors.
21820 if (VT.isScalableVector())
21821 return SDValue();
21822
21823 // It is beneficial to scalarize a zero splat store for 2 or 3 i64 elements or
21824 // 2, 3 or 4 i32 elements.
21825 int NumVecElts = VT.getVectorNumElements();
21826 if (!(((NumVecElts == 2 || NumVecElts == 3) &&
21827 VT.getVectorElementType().getSizeInBits() == 64) ||
21828 ((NumVecElts == 2 || NumVecElts == 3 || NumVecElts == 4) &&
21829 VT.getVectorElementType().getSizeInBits() == 32)))
21830 return SDValue();
21831
21832 if (StVal.getOpcode() != ISD::BUILD_VECTOR)
21833 return SDValue();
21834
21835 // If the zero constant has more than one use then the vector store could be
21836 // better since the constant mov will be amortized and stp q instructions
21837 // should be able to be formed.
21838 if (!StVal.hasOneUse())
21839 return SDValue();
21840
21841 // If the store is truncating then it's going down to i16 or smaller, which
21842 // means it can be implemented in a single store anyway.
21843 if (St.isTruncatingStore())
21844 return SDValue();
21845
21846 // If the immediate offset of the address operand is too large for the stp
21847 // instruction, then bail out.
21848 if (DAG.isBaseWithConstantOffset(St.getBasePtr())) {
21849 int64_t Offset = St.getBasePtr()->getConstantOperandVal(1);
21850 if (Offset < -512 || Offset > 504)
21851 return SDValue();
21852 }
21853
21854 for (int I = 0; I < NumVecElts; ++I) {
21855 SDValue EltVal = StVal.getOperand(I);
21856 if (!isNullConstant(EltVal) && !isNullFPConstant(EltVal))
21857 return SDValue();
21858 }
21859
21860 // Use a CopyFromReg WZR/XZR here to prevent
21861 // DAGCombiner::MergeConsecutiveStores from undoing this transformation.
21862 SDLoc DL(&St);
21863 unsigned ZeroReg;
21864 EVT ZeroVT;
21865 if (VT.getVectorElementType().getSizeInBits() == 32) {
21866 ZeroReg = AArch64::WZR;
21867 ZeroVT = MVT::i32;
21868 } else {
21869 ZeroReg = AArch64::XZR;
21870 ZeroVT = MVT::i64;
21871 }
21872 SDValue SplatVal =
21873 DAG.getCopyFromReg(DAG.getEntryNode(), DL, ZeroReg, ZeroVT);
21874 return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
21875 }
21876
21877 /// Replace a store of a splatted scalar with scalar stores of the splatted
21878 /// value. The load store optimizer pass will merge them to store pair stores.
21879 /// This has better performance than a splat of the scalar followed by a split
21880 /// vector store. Even if the stores are not merged, it is four stores vs a dup
21881 /// followed by an ext.b and two stores.
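///
/// For example (illustrative), storing a v2i64 splat of x1 to [x0] should
/// become:
///
///   stp x1, x1, [x0]
///
/// instead of a dup v0.2d, x1 followed by str q0, [x0].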
21882 static SDValue replaceSplatVectorStore(SelectionDAG &DAG, StoreSDNode &St) {
21883 SDValue StVal = St.getValue();
21884 EVT VT = StVal.getValueType();
21885
21886 // Don't replace floating point stores, they possibly won't be transformed to
21887 // stp because of the store pair suppress pass.
21888 if (VT.isFloatingPoint())
21889 return SDValue();
21890
21891 // We can express a splat as store pair(s) for 2 or 4 elements.
21892 unsigned NumVecElts = VT.getVectorNumElements();
21893 if (NumVecElts != 4 && NumVecElts != 2)
21894 return SDValue();
21895
21896 // If the store is truncating then it's going down to i16 or smaller, which
21897 // means it can be implemented in a single store anyway.
21898 if (St.isTruncatingStore())
21899 return SDValue();
21900
21901 // Check that this is a splat.
21902 // Make sure that each of the relevant vector element locations are inserted
21903 // to, i.e. 0 and 1 for v2i64 and 0, 1, 2, 3 for v4i32.
21904 std::bitset<4> IndexNotInserted((1 << NumVecElts) - 1);
21905 SDValue SplatVal;
21906 for (unsigned I = 0; I < NumVecElts; ++I) {
21907 // Check for insert vector elements.
21908 if (StVal.getOpcode() != ISD::INSERT_VECTOR_ELT)
21909 return SDValue();
21910
21911 // Check that same value is inserted at each vector element.
21912 if (I == 0)
21913 SplatVal = StVal.getOperand(1);
21914 else if (StVal.getOperand(1) != SplatVal)
21915 return SDValue();
21916
21917 // Check insert element index.
21918 ConstantSDNode *CIndex = dyn_cast<ConstantSDNode>(StVal.getOperand(2));
21919 if (!CIndex)
21920 return SDValue();
21921 uint64_t IndexVal = CIndex->getZExtValue();
21922 if (IndexVal >= NumVecElts)
21923 return SDValue();
21924 IndexNotInserted.reset(IndexVal);
21925
21926 StVal = StVal.getOperand(0);
21927 }
21928 // Check that all vector element locations were inserted to.
21929 if (IndexNotInserted.any())
21930 return SDValue();
21931
21932 return splitStoreSplat(DAG, St, SplatVal, NumVecElts);
21933 }
21934
21935 static SDValue splitStores(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
21936 SelectionDAG &DAG,
21937 const AArch64Subtarget *Subtarget) {
21938
21939 StoreSDNode *S = cast<StoreSDNode>(N);
21940 if (S->isVolatile() || S->isIndexed())
21941 return SDValue();
21942
21943 SDValue StVal = S->getValue();
21944 EVT VT = StVal.getValueType();
21945
21946 if (!VT.isFixedLengthVector())
21947 return SDValue();
21948
21949 // If we get a splat of zeros, convert this vector store to a store of
21950 // scalars. They will be merged into store pairs of xzr thereby removing one
21951 // instruction and one register.
21952 if (SDValue ReplacedZeroSplat = replaceZeroVectorStore(DAG, *S))
21953 return ReplacedZeroSplat;
21954
21955 // FIXME: The logic for deciding if an unaligned store should be split should
21956 // be included in TLI.allowsMisalignedMemoryAccesses(), and there should be
21957 // a call to that function here.
21958
21959 if (!Subtarget->isMisaligned128StoreSlow())
21960 return SDValue();
21961
21962 // Don't split at -Oz.
21963 if (DAG.getMachineFunction().getFunction().hasMinSize())
21964 return SDValue();
21965
21966 // Don't split v2i64 vectors. Memcpy lowering produces those and splitting
21967 // those up regresses performance on micro-benchmarks and olden/bh.
21968 if (VT.getVectorNumElements() < 2 || VT == MVT::v2i64)
21969 return SDValue();
21970
21971 // Split unaligned 16B stores. They are terrible for performance.
21972 // Don't split stores with alignment of 1 or 2. Code that uses clang vector
21973 // extensions can use this to mark that it does not want splitting to happen
21974 // (by underspecifying alignment to be 1 or 2). Furthermore, the chance of
21975 // eliminating alignment hazards is only 1 in 8 for alignment of 2.
21976 if (VT.getSizeInBits() != 128 || S->getAlign() >= Align(16) ||
21977 S->getAlign() <= Align(2))
21978 return SDValue();
21979
21980 // If we get a splat of a scalar convert this vector store to a store of
21981 // scalars. They will be merged into store pairs thereby removing two
21982 // instructions.
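// For example (illustrative), a v4i32 store of splat(%w) becomes four scalar
// stores of %w, which can later be merged into two stp instructions.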
21983 if (SDValue ReplacedSplat = replaceSplatVectorStore(DAG, *S))
21984 return ReplacedSplat;
21985
21986 SDLoc DL(S);
21987
21988 // Split VT into two.
21989 EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
21990 unsigned NumElts = HalfVT.getVectorNumElements();
21991 SDValue SubVector0 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
21992 DAG.getConstant(0, DL, MVT::i64));
21993 SDValue SubVector1 = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, HalfVT, StVal,
21994 DAG.getConstant(NumElts, DL, MVT::i64));
21995 SDValue BasePtr = S->getBasePtr();
21996 SDValue NewST1 =
21997 DAG.getStore(S->getChain(), DL, SubVector0, BasePtr, S->getPointerInfo(),
21998 S->getAlign(), S->getMemOperand()->getFlags());
21999 SDValue OffsetPtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr,
22000 DAG.getConstant(8, DL, MVT::i64));
22001 return DAG.getStore(NewST1.getValue(0), DL, SubVector1, OffsetPtr,
22002 S->getPointerInfo(), S->getAlign(),
22003 S->getMemOperand()->getFlags());
22004 }
22005
22006 static SDValue performSpliceCombine(SDNode *N, SelectionDAG &DAG) {
22007 assert(N->getOpcode() == AArch64ISD::SPLICE && "Unexpected Opcode!");
22008
22009 // splice(pg, op1, undef) -> op1
22010 if (N->getOperand(2).isUndef())
22011 return N->getOperand(1);
22012
22013 return SDValue();
22014 }
22015
22016 static SDValue performUnpackCombine(SDNode *N, SelectionDAG &DAG,
22017 const AArch64Subtarget *Subtarget) {
22018 assert((N->getOpcode() == AArch64ISD::UUNPKHI ||
22019 N->getOpcode() == AArch64ISD::UUNPKLO) &&
22020 "Unexpected Opcode!");
22021
22022 // uunpklo/hi undef -> undef
22023 if (N->getOperand(0).isUndef())
22024 return DAG.getUNDEF(N->getValueType(0));
22025
22026 // If this is a masked load followed by an UUNPKLO, fold this into a masked
22027 // extending load. We can do this even if this is already a masked
22028 // {z,}extload.
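// Rough sketch of the fold (illustrative):
//   uunpklo(masked_load(ptr, ptrue(pat), passthru = undef/zero))
//     -> masked_zextload(ptr, widened ptrue(pat), passthru = zero)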
22029 if (N->getOperand(0).getOpcode() == ISD::MLOAD &&
22030 N->getOpcode() == AArch64ISD::UUNPKLO) {
22031 MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N->getOperand(0));
22032 SDValue Mask = MLD->getMask();
22033 SDLoc DL(N);
22034
22035 if (MLD->isUnindexed() && MLD->getExtensionType() != ISD::SEXTLOAD &&
22036 SDValue(MLD, 0).hasOneUse() && Mask->getOpcode() == AArch64ISD::PTRUE &&
22037 (MLD->getPassThru()->isUndef() ||
22038 isZerosVector(MLD->getPassThru().getNode()))) {
22039 unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
22040 unsigned PgPattern = Mask->getConstantOperandVal(0);
22041 EVT VT = N->getValueType(0);
22042
22043 // Ensure we can double the size of the predicate pattern
22044 unsigned NumElts = getNumElementsFromSVEPredPattern(PgPattern);
22045 if (NumElts &&
22046 NumElts * VT.getVectorElementType().getSizeInBits() <= MinSVESize) {
22047 Mask =
22048 getPTrue(DAG, DL, VT.changeVectorElementType(MVT::i1), PgPattern);
22049 SDValue PassThru = DAG.getConstant(0, DL, VT);
22050 SDValue NewLoad = DAG.getMaskedLoad(
22051 VT, DL, MLD->getChain(), MLD->getBasePtr(), MLD->getOffset(), Mask,
22052 PassThru, MLD->getMemoryVT(), MLD->getMemOperand(),
22053 MLD->getAddressingMode(), ISD::ZEXTLOAD);
22054
22055 DAG.ReplaceAllUsesOfValueWith(SDValue(MLD, 1), NewLoad.getValue(1));
22056
22057 return NewLoad;
22058 }
22059 }
22060 }
22061
22062 return SDValue();
22063 }
22064
22065 static bool isHalvingTruncateAndConcatOfLegalIntScalableType(SDNode *N) {
22066 if (N->getOpcode() != AArch64ISD::UZP1)
22067 return false;
22068 SDValue Op0 = N->getOperand(0);
22069 EVT SrcVT = Op0->getValueType(0);
22070 EVT DstVT = N->getValueType(0);
22071 return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv16i8) ||
22072 (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv8i16) ||
22073 (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv4i32);
22074 }
22075
22076 // Try to combine rounding shifts where the operands come from an extend, and
22077 // the result is truncated and combined into one vector.
22078 // uzp1(rshrnb(uunpklo(X),C), rshrnb(uunpkhi(X), C)) -> urshr(X, C)
22079 static SDValue tryCombineExtendRShTrunc(SDNode *N, SelectionDAG &DAG) {
22080 assert(N->getOpcode() == AArch64ISD::UZP1 && "Only UZP1 expected.");
22081 SDValue Op0 = N->getOperand(0);
22082 SDValue Op1 = N->getOperand(1);
22083 EVT ResVT = N->getValueType(0);
22084
22085 unsigned RshOpc = Op0.getOpcode();
22086 if (RshOpc != AArch64ISD::RSHRNB_I)
22087 return SDValue();
22088
22089 // Same op code and imm value?
22090 SDValue ShiftValue = Op0.getOperand(1);
22091 if (RshOpc != Op1.getOpcode() || ShiftValue != Op1.getOperand(1))
22092 return SDValue();
22093
22094 // Same unextended operand value?
22095 SDValue Lo = Op0.getOperand(0);
22096 SDValue Hi = Op1.getOperand(0);
22097 if (Lo.getOpcode() != AArch64ISD::UUNPKLO ||
22098 Hi.getOpcode() != AArch64ISD::UUNPKHI)
22099 return SDValue();
22100 SDValue OrigArg = Lo.getOperand(0);
22101 if (OrigArg != Hi.getOperand(0))
22102 return SDValue();
22103
22104 SDLoc DL(N);
22105 return DAG.getNode(AArch64ISD::URSHR_I_PRED, DL, ResVT,
22106 getPredicateForVector(DAG, DL, ResVT), OrigArg,
22107 ShiftValue);
22108 }
22109
22110 // Try to simplify:
22111 // t1 = nxv8i16 add(X, 1 << (ShiftValue - 1))
22112 // t2 = nxv8i16 srl(t1, ShiftValue)
22113 // to
22114 // t1 = nxv8i16 rshrnb(X, shiftvalue).
22115 // rshrnb will zero the top half bits of each element. Therefore, this combine
22116 // should only be performed when a following instruction with the rshrnb
22117 // as an operand does not care about the top half of each element. For example,
22118 // a uzp1 or a truncating store.
22119 static SDValue trySimplifySrlAddToRshrnb(SDValue Srl, SelectionDAG &DAG,
22120 const AArch64Subtarget *Subtarget) {
22121 EVT VT = Srl->getValueType(0);
22122 if (!VT.isScalableVector() || !Subtarget->hasSVE2())
22123 return SDValue();
22124
22125 EVT ResVT;
22126 if (VT == MVT::nxv8i16)
22127 ResVT = MVT::nxv16i8;
22128 else if (VT == MVT::nxv4i32)
22129 ResVT = MVT::nxv8i16;
22130 else if (VT == MVT::nxv2i64)
22131 ResVT = MVT::nxv4i32;
22132 else
22133 return SDValue();
22134
22135 SDLoc DL(Srl);
22136 unsigned ShiftValue;
22137 SDValue RShOperand;
22138 if (!canLowerSRLToRoundingShiftForVT(Srl, ResVT, DAG, ShiftValue, RShOperand))
22139 return SDValue();
22140 SDValue Rshrnb = DAG.getNode(
22141 AArch64ISD::RSHRNB_I, DL, ResVT,
22142 {RShOperand, DAG.getTargetConstant(ShiftValue, DL, MVT::i32)});
22143 return DAG.getNode(ISD::BITCAST, DL, VT, Rshrnb);
22144 }
22145
22146 static SDValue performUzpCombine(SDNode *N, SelectionDAG &DAG,
22147 const AArch64Subtarget *Subtarget) {
22148 SDLoc DL(N);
22149 SDValue Op0 = N->getOperand(0);
22150 SDValue Op1 = N->getOperand(1);
22151 EVT ResVT = N->getValueType(0);
22152
22153 // uzp(extract_lo(x), extract_hi(x)) -> extract_lo(uzp x, x)
22154 if (Op0.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22155 Op1.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
22156 Op0.getOperand(0) == Op1.getOperand(0)) {
22157
22158 SDValue SourceVec = Op0.getOperand(0);
22159 uint64_t ExtIdx0 = Op0.getConstantOperandVal(1);
22160 uint64_t ExtIdx1 = Op1.getConstantOperandVal(1);
22161 uint64_t NumElements = SourceVec.getValueType().getVectorMinNumElements();
22162 if (ExtIdx0 == 0 && ExtIdx1 == NumElements / 2) {
22163 EVT OpVT = Op0.getOperand(1).getValueType();
22164 EVT WidenedResVT = ResVT.getDoubleNumVectorElementsVT(*DAG.getContext());
22165 SDValue Uzp = DAG.getNode(N->getOpcode(), DL, WidenedResVT, SourceVec,
22166 DAG.getUNDEF(WidenedResVT));
22167 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResVT, Uzp,
22168 DAG.getConstant(0, DL, OpVT));
22169 }
22170 }
22171
22172 // Following optimizations only work with uzp1.
22173 if (N->getOpcode() == AArch64ISD::UZP2)
22174 return SDValue();
22175
22176 // uzp1(x, undef) -> concat(truncate(x), undef)
22177 if (Op1.getOpcode() == ISD::UNDEF) {
22178 EVT BCVT = MVT::Other, HalfVT = MVT::Other;
22179 switch (ResVT.getSimpleVT().SimpleTy) {
22180 default:
22181 break;
22182 case MVT::v16i8:
22183 BCVT = MVT::v8i16;
22184 HalfVT = MVT::v8i8;
22185 break;
22186 case MVT::v8i16:
22187 BCVT = MVT::v4i32;
22188 HalfVT = MVT::v4i16;
22189 break;
22190 case MVT::v4i32:
22191 BCVT = MVT::v2i64;
22192 HalfVT = MVT::v2i32;
22193 break;
22194 }
22195 if (BCVT != MVT::Other) {
22196 SDValue BC = DAG.getBitcast(BCVT, Op0);
22197 SDValue Trunc = DAG.getNode(ISD::TRUNCATE, DL, HalfVT, BC);
22198 return DAG.getNode(ISD::CONCAT_VECTORS, DL, ResVT, Trunc,
22199 DAG.getUNDEF(HalfVT));
22200 }
22201 }
22202
22203 if (SDValue Urshr = tryCombineExtendRShTrunc(N, DAG))
22204 return Urshr;
22205
22206 if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op0, DAG, Subtarget))
22207 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Rshrnb, Op1);
22208
22209 if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Op1, DAG, Subtarget))
22210 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Rshrnb);
22211
22212 // uzp1(unpklo(uzp1(x, y)), z) => uzp1(x, z)
22213 if (Op0.getOpcode() == AArch64ISD::UUNPKLO) {
22214 if (Op0.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22215 SDValue X = Op0.getOperand(0).getOperand(0);
22216 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, X, Op1);
22217 }
22218 }
22219
22220 // uzp1(x, unpkhi(uzp1(y, z))) => uzp1(x, z)
22221 if (Op1.getOpcode() == AArch64ISD::UUNPKHI) {
22222 if (Op1.getOperand(0).getOpcode() == AArch64ISD::UZP1) {
22223 SDValue Z = Op1.getOperand(0).getOperand(1);
22224 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0, Z);
22225 }
22226 }
22227
22228 // These optimizations only work on little endian.
22229 if (!DAG.getDataLayout().isLittleEndian())
22230 return SDValue();
22231
22232 // uzp1(bitcast(x), bitcast(y)) -> uzp1(x, y)
22233 // Example:
22234 // nxv4i32 = uzp1 bitcast(nxv4i32 x to nxv2i64), bitcast(nxv4i32 y to nxv2i64)
22235 // to
22236 // nxv4i32 = uzp1 nxv4i32 x, nxv4i32 y
22237 if (isHalvingTruncateAndConcatOfLegalIntScalableType(N) &&
22238 Op0.getOpcode() == ISD::BITCAST && Op1.getOpcode() == ISD::BITCAST) {
22239 if (Op0.getOperand(0).getValueType() == Op1.getOperand(0).getValueType()) {
22240 return DAG.getNode(AArch64ISD::UZP1, DL, ResVT, Op0.getOperand(0),
22241 Op1.getOperand(0));
22242 }
22243 }
22244
22245 if (ResVT != MVT::v2i32 && ResVT != MVT::v4i16 && ResVT != MVT::v8i8)
22246 return SDValue();
22247
22248 SDValue SourceOp0 = peekThroughBitcasts(Op0);
22249 SDValue SourceOp1 = peekThroughBitcasts(Op1);
22250
22251 // truncating uzp1(x, y) -> xtn(concat (x, y))
22252 if (SourceOp0.getValueType() == SourceOp1.getValueType()) {
22253 EVT Op0Ty = SourceOp0.getValueType();
22254 if ((ResVT == MVT::v4i16 && Op0Ty == MVT::v2i32) ||
22255 (ResVT == MVT::v8i8 && Op0Ty == MVT::v4i16)) {
22256 SDValue Concat =
22257 DAG.getNode(ISD::CONCAT_VECTORS, DL,
22258 Op0Ty.getDoubleNumVectorElementsVT(*DAG.getContext()),
22259 SourceOp0, SourceOp1);
22260 return DAG.getNode(ISD::TRUNCATE, DL, ResVT, Concat);
22261 }
22262 }
22263
22264 // uzp1(xtn x, xtn y) -> xtn(uzp1 (x, y))
22265 if (SourceOp0.getOpcode() != ISD::TRUNCATE ||
22266 SourceOp1.getOpcode() != ISD::TRUNCATE)
22267 return SDValue();
22268 SourceOp0 = SourceOp0.getOperand(0);
22269 SourceOp1 = SourceOp1.getOperand(0);
22270
22271 if (SourceOp0.getValueType() != SourceOp1.getValueType() ||
22272 !SourceOp0.getValueType().isSimple())
22273 return SDValue();
22274
22275 EVT ResultTy;
22276
22277 switch (SourceOp0.getSimpleValueType().SimpleTy) {
22278 case MVT::v2i64:
22279 ResultTy = MVT::v4i32;
22280 break;
22281 case MVT::v4i32:
22282 ResultTy = MVT::v8i16;
22283 break;
22284 case MVT::v8i16:
22285 ResultTy = MVT::v16i8;
22286 break;
22287 default:
22288 return SDValue();
22289 }
22290
22291 SDValue UzpOp0 = DAG.getNode(ISD::BITCAST, DL, ResultTy, SourceOp0);
22292 SDValue UzpOp1 = DAG.getNode(ISD::BITCAST, DL, ResultTy, SourceOp1);
22293 SDValue UzpResult =
22294 DAG.getNode(AArch64ISD::UZP1, DL, UzpOp0.getValueType(), UzpOp0, UzpOp1);
22295
22296 EVT BitcastResultTy;
22297
22298 switch (ResVT.getSimpleVT().SimpleTy) {
22299 case MVT::v2i32:
22300 BitcastResultTy = MVT::v2i64;
22301 break;
22302 case MVT::v4i16:
22303 BitcastResultTy = MVT::v4i32;
22304 break;
22305 case MVT::v8i8:
22306 BitcastResultTy = MVT::v8i16;
22307 break;
22308 default:
22309 llvm_unreachable("Should be one of {v2i32, v4i16, v8i8}");
22310 }
22311
22312 return DAG.getNode(ISD::TRUNCATE, DL, ResVT,
22313 DAG.getNode(ISD::BITCAST, DL, BitcastResultTy, UzpResult));
22314 }
22315
22316 static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
22317 unsigned Opc = N->getOpcode();
22318
22319 assert(((Opc >= AArch64ISD::GLD1_MERGE_ZERO && // unsigned gather loads
22320 Opc <= AArch64ISD::GLD1_IMM_MERGE_ZERO) ||
22321 (Opc >= AArch64ISD::GLD1S_MERGE_ZERO && // signed gather loads
22322 Opc <= AArch64ISD::GLD1S_IMM_MERGE_ZERO)) &&
22323 "Invalid opcode.");
22324
22325 const bool Scaled = Opc == AArch64ISD::GLD1_SCALED_MERGE_ZERO ||
22326 Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
22327 const bool Signed = Opc == AArch64ISD::GLD1S_MERGE_ZERO ||
22328 Opc == AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
22329 const bool Extended = Opc == AArch64ISD::GLD1_SXTW_MERGE_ZERO ||
22330 Opc == AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO ||
22331 Opc == AArch64ISD::GLD1_UXTW_MERGE_ZERO ||
22332 Opc == AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO;
22333
22334 SDLoc DL(N);
22335 SDValue Chain = N->getOperand(0);
22336 SDValue Pg = N->getOperand(1);
22337 SDValue Base = N->getOperand(2);
22338 SDValue Offset = N->getOperand(3);
22339 SDValue Ty = N->getOperand(4);
22340
22341 EVT ResVT = N->getValueType(0);
22342
22343 const auto OffsetOpc = Offset.getOpcode();
22344 const bool OffsetIsZExt =
22345 OffsetOpc == AArch64ISD::ZERO_EXTEND_INREG_MERGE_PASSTHRU;
22346 const bool OffsetIsSExt =
22347 OffsetOpc == AArch64ISD::SIGN_EXTEND_INREG_MERGE_PASSTHRU;
22348
22349 // Fold sign/zero extensions of vector offsets into GLD1 nodes where possible.
22350 if (!Extended && (OffsetIsSExt || OffsetIsZExt)) {
22351 SDValue ExtPg = Offset.getOperand(0);
22352 VTSDNode *ExtFrom = cast<VTSDNode>(Offset.getOperand(2).getNode());
22353 EVT ExtFromEVT = ExtFrom->getVT().getVectorElementType();
22354
22355 // If the predicate for the sign- or zero-extended offset is the
22356 // same as the predicate used for this load and the sign-/zero-extension
22357 // was from 32 bits...
22358 if (ExtPg == Pg && ExtFromEVT == MVT::i32) {
22359 SDValue UnextendedOffset = Offset.getOperand(1);
22360
22361 unsigned NewOpc = getGatherVecOpcode(Scaled, OffsetIsSExt, true);
22362 if (Signed)
22363 NewOpc = getSignExtendedGatherOpcode(NewOpc);
22364
22365 return DAG.getNode(NewOpc, DL, {ResVT, MVT::Other},
22366 {Chain, Pg, Base, UnextendedOffset, Ty});
22367 }
22368 }
22369
22370 return SDValue();
22371 }
22372
22373 /// Optimize a vector shift instruction and its operand if shifted out
22374 /// bits are not used.
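/// For example (illustrative), vashr(vshl(x, C), C) can be folded to x when x
/// already has more than C known sign bits; for other shifts the shifted-out
/// bits are simply dropped from the demanded bits of the operand.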
22375 static SDValue performVectorShiftCombine(SDNode *N,
22376 const AArch64TargetLowering &TLI,
22377 TargetLowering::DAGCombinerInfo &DCI) {
22378 assert(N->getOpcode() == AArch64ISD::VASHR ||
22379 N->getOpcode() == AArch64ISD::VLSHR);
22380
22381 SDValue Op = N->getOperand(0);
22382 unsigned OpScalarSize = Op.getScalarValueSizeInBits();
22383
22384 unsigned ShiftImm = N->getConstantOperandVal(1);
22385 assert(OpScalarSize > ShiftImm && "Invalid shift imm");
22386
22387 // Remove sign_extend_inreg (ashr(shl(x)) based on the number of sign bits.
22388 if (N->getOpcode() == AArch64ISD::VASHR &&
22389 Op.getOpcode() == AArch64ISD::VSHL &&
22390 N->getOperand(1) == Op.getOperand(1))
22391 if (DCI.DAG.ComputeNumSignBits(Op.getOperand(0)) > ShiftImm)
22392 return Op.getOperand(0);
22393
22394 // If the shift is exact, the shifted out bits matter.
22395 if (N->getFlags().hasExact())
22396 return SDValue();
22397
22398 APInt ShiftedOutBits = APInt::getLowBitsSet(OpScalarSize, ShiftImm);
22399 APInt DemandedMask = ~ShiftedOutBits;
22400
22401 if (TLI.SimplifyDemandedBits(Op, DemandedMask, DCI))
22402 return SDValue(N, 0);
22403
22404 return SDValue();
22405 }
22406
22407 static SDValue performSunpkloCombine(SDNode *N, SelectionDAG &DAG) {
22408 // sunpklo(sext(pred)) -> sext(extract_low_half(pred))
22409 // This transform works in partnership with performSetCCPunpkCombine to
22410 // remove unnecessary transfer of predicates into standard registers and back
22411 if (N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND &&
22412 N->getOperand(0)->getOperand(0)->getValueType(0).getScalarType() ==
22413 MVT::i1) {
22414 SDValue CC = N->getOperand(0)->getOperand(0);
22415 auto VT = CC->getValueType(0).getHalfNumVectorElementsVT(*DAG.getContext());
22416 SDValue Unpk = DAG.getNode(ISD::EXTRACT_SUBVECTOR, SDLoc(N), VT, CC,
22417 DAG.getVectorIdxConstant(0, SDLoc(N)));
22418 return DAG.getNode(ISD::SIGN_EXTEND, SDLoc(N), N->getValueType(0), Unpk);
22419 }
22420
22421 return SDValue();
22422 }
22423
22424 /// Target-specific DAG combine function for post-increment LD1 (lane) and
22425 /// post-increment LD1R.
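/// For example (illustrative), a pattern such as
///   %v = load i32, ptr %p
///   %vec.new = insertelement <4 x i32> %vec, i32 %v, i32 1
///   %p.next = getelementptr i8, ptr %p, i64 4
/// can become a single post-incrementing LD1LANEpost node that also produces
/// the updated pointer.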
22426 static SDValue performPostLD1Combine(SDNode *N,
22427 TargetLowering::DAGCombinerInfo &DCI,
22428 bool IsLaneOp) {
22429 if (DCI.isBeforeLegalizeOps())
22430 return SDValue();
22431
22432 SelectionDAG &DAG = DCI.DAG;
22433 EVT VT = N->getValueType(0);
22434
22435 if (!VT.is128BitVector() && !VT.is64BitVector())
22436 return SDValue();
22437
22438 unsigned LoadIdx = IsLaneOp ? 1 : 0;
22439 SDNode *LD = N->getOperand(LoadIdx).getNode();
22440 // If it is not a LOAD, we cannot do this combine.
22441 if (LD->getOpcode() != ISD::LOAD)
22442 return SDValue();
22443
22444 // The vector lane must be a constant in the LD1LANE opcode.
22445 SDValue Lane;
22446 if (IsLaneOp) {
22447 Lane = N->getOperand(2);
22448 auto *LaneC = dyn_cast<ConstantSDNode>(Lane);
22449 if (!LaneC || LaneC->getZExtValue() >= VT.getVectorNumElements())
22450 return SDValue();
22451 }
22452
22453 LoadSDNode *LoadSDN = cast<LoadSDNode>(LD);
22454 EVT MemVT = LoadSDN->getMemoryVT();
22455 // Check if memory operand is the same type as the vector element.
22456 if (MemVT != VT.getVectorElementType())
22457 return SDValue();
22458
22459 // Check if there are other uses. If so, do not combine as it will introduce
22460 // an extra load.
22461 for (SDNode::use_iterator UI = LD->use_begin(), UE = LD->use_end(); UI != UE;
22462 ++UI) {
22463 if (UI.getUse().getResNo() == 1) // Ignore uses of the chain result.
22464 continue;
22465 if (*UI != N)
22466 return SDValue();
22467 }
22468
22469 // If there is one use and it can splat the value, prefer that operation.
22470 // TODO: This could be expanded to more operations if they reliably use the
22471 // index variants.
22472 if (N->hasOneUse()) {
22473 unsigned UseOpc = N->use_begin()->getOpcode();
22474 if (UseOpc == ISD::FMUL || UseOpc == ISD::FMA)
22475 return SDValue();
22476 }
22477
22478 SDValue Addr = LD->getOperand(1);
22479 SDValue Vector = N->getOperand(0);
22480 // Search for a use of the address operand that is an increment.
22481 for (SDNode::use_iterator UI = Addr.getNode()->use_begin(), UE =
22482 Addr.getNode()->use_end(); UI != UE; ++UI) {
22483 SDNode *User = *UI;
22484 if (User->getOpcode() != ISD::ADD
22485 || UI.getUse().getResNo() != Addr.getResNo())
22486 continue;
22487
22488 // If the increment is a constant, it must match the memory ref size.
22489 SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
22490 if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
22491 uint32_t IncVal = CInc->getZExtValue();
22492 unsigned NumBytes = VT.getScalarSizeInBits() / 8;
22493 if (IncVal != NumBytes)
22494 continue;
22495 Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
22496 }
22497
22498 // To avoid cycle construction, make sure that neither the load nor the add
22499 // is a predecessor of the other or of the Vector.
22500 SmallPtrSet<const SDNode *, 32> Visited;
22501 SmallVector<const SDNode *, 16> Worklist;
22502 Visited.insert(Addr.getNode());
22503 Worklist.push_back(User);
22504 Worklist.push_back(LD);
22505 Worklist.push_back(Vector.getNode());
22506 if (SDNode::hasPredecessorHelper(LD, Visited, Worklist) ||
22507 SDNode::hasPredecessorHelper(User, Visited, Worklist))
22508 continue;
22509
22510 SmallVector<SDValue, 8> Ops;
22511 Ops.push_back(LD->getOperand(0)); // Chain
22512 if (IsLaneOp) {
22513 Ops.push_back(Vector); // The vector to be inserted
22514 Ops.push_back(Lane); // The lane to be inserted in the vector
22515 }
22516 Ops.push_back(Addr);
22517 Ops.push_back(Inc);
22518
22519 EVT Tys[3] = { VT, MVT::i64, MVT::Other };
22520 SDVTList SDTys = DAG.getVTList(Tys);
22521 unsigned NewOp = IsLaneOp ? AArch64ISD::LD1LANEpost : AArch64ISD::LD1DUPpost;
22522 SDValue UpdN = DAG.getMemIntrinsicNode(NewOp, SDLoc(N), SDTys, Ops,
22523 MemVT,
22524 LoadSDN->getMemOperand());
22525
22526 // Update the uses.
22527 SDValue NewResults[] = {
22528 SDValue(LD, 0), // The result of load
22529 SDValue(UpdN.getNode(), 2) // Chain
22530 };
22531 DCI.CombineTo(LD, NewResults);
22532 DCI.CombineTo(N, SDValue(UpdN.getNode(), 0)); // Dup/Inserted Result
22533 DCI.CombineTo(User, SDValue(UpdN.getNode(), 1)); // Write back register
22534
22535 break;
22536 }
22537 return SDValue();
22538 }
22539
22540 /// Simplify ``Addr`` given that the top byte of it is ignored by HW during
22541 /// address translation.
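/// For example (illustrative), with TBI the combiner can drop a mask that only
/// clears the tag byte of a pointer used as an address:
///   (and X, 0x00ffffffffffffff) --> X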
22542 static bool performTBISimplification(SDValue Addr,
22543 TargetLowering::DAGCombinerInfo &DCI,
22544 SelectionDAG &DAG) {
22545 APInt DemandedMask = APInt::getLowBitsSet(64, 56);
22546 KnownBits Known;
22547 TargetLowering::TargetLoweringOpt TLO(DAG, !DCI.isBeforeLegalize(),
22548 !DCI.isBeforeLegalizeOps());
22549 const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22550 if (TLI.SimplifyDemandedBits(Addr, DemandedMask, Known, TLO)) {
22551 DCI.CommitTargetLoweringOpt(TLO);
22552 return true;
22553 }
22554 return false;
22555 }
22556
22557 static SDValue foldTruncStoreOfExt(SelectionDAG &DAG, SDNode *N) {
22558 assert((N->getOpcode() == ISD::STORE || N->getOpcode() == ISD::MSTORE) &&
22559 "Expected STORE dag node in input!");
22560
22561 if (auto Store = dyn_cast<StoreSDNode>(N)) {
22562 if (!Store->isTruncatingStore() || Store->isIndexed())
22563 return SDValue();
22564 SDValue Ext = Store->getValue();
22565 auto ExtOpCode = Ext.getOpcode();
22566 if (ExtOpCode != ISD::ZERO_EXTEND && ExtOpCode != ISD::SIGN_EXTEND &&
22567 ExtOpCode != ISD::ANY_EXTEND)
22568 return SDValue();
22569 SDValue Orig = Ext->getOperand(0);
22570 if (Store->getMemoryVT() != Orig.getValueType())
22571 return SDValue();
22572 return DAG.getStore(Store->getChain(), SDLoc(Store), Orig,
22573 Store->getBasePtr(), Store->getMemOperand());
22574 }
22575
22576 return SDValue();
22577 }
22578
22579 // A custom combine to lower load <3 x i8> as the more efficient sequence
22580 // below:
22581 // ldrb wX, [x0, #2]
22582 // ldrh wY, [x0]
22583 // orr wX, wY, wX, lsl #16
22584 // fmov s0, wX
22585 //
22586 // Note that an alternative sequence with even fewer (although usually more
22587 // complex/expensive) instructions would be:
22588 // ld1r.4h { v0 }, [x0], #2
22589 // ld1.b { v0 }[2], [x0]
22590 //
22591 // Generating this sequence unfortunately results in noticeably worse codegen
22592 // for code that extends the loaded v3i8, due to legalization breaking vector
22593 // shuffle detection in a way that is very difficult to work around.
22594 // TODO: Revisit once v3i8 legalization has been improved in general.
22595 static SDValue combineV3I8LoadExt(LoadSDNode *LD, SelectionDAG &DAG) {
22596 EVT MemVT = LD->getMemoryVT();
22597 if (MemVT != EVT::getVectorVT(*DAG.getContext(), MVT::i8, 3) ||
22598 LD->getOriginalAlign() >= 4)
22599 return SDValue();
22600
22601 SDLoc DL(LD);
22602 MachineFunction &MF = DAG.getMachineFunction();
22603 SDValue Chain = LD->getChain();
22604 SDValue BasePtr = LD->getBasePtr();
22605 MachineMemOperand *MMO = LD->getMemOperand();
22606 assert(LD->getOffset().isUndef() && "undef offset expected");
22607
22608 // Load 2 x i8, then 1 x i8.
22609 SDValue L16 = DAG.getLoad(MVT::i16, DL, Chain, BasePtr, MMO);
22610 TypeSize Offset2 = TypeSize::getFixed(2);
22611 SDValue L8 = DAG.getLoad(MVT::i8, DL, Chain,
22612 DAG.getMemBasePlusOffset(BasePtr, Offset2, DL),
22613 MF.getMachineMemOperand(MMO, 2, 1));
22614
22615 // Extend to i32.
22616 SDValue Ext16 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, L16);
22617 SDValue Ext8 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i32, L8);
22618
22619 // Pack 2 x i8 and 1 x i8 in an i32 and convert to v4i8.
22620 SDValue Shl = DAG.getNode(ISD::SHL, DL, MVT::i32, Ext8,
22621 DAG.getConstant(16, DL, MVT::i32));
22622 SDValue Or = DAG.getNode(ISD::OR, DL, MVT::i32, Ext16, Shl);
22623 SDValue Cast = DAG.getNode(ISD::BITCAST, DL, MVT::v4i8, Or);
22624
22625 // Extract v3i8 again.
22626 SDValue Extract = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT, Cast,
22627 DAG.getConstant(0, DL, MVT::i64));
22628 SDValue TokenFactor = DAG.getNode(
22629 ISD::TokenFactor, DL, MVT::Other,
22630 {SDValue(cast<SDNode>(L16), 1), SDValue(cast<SDNode>(L8), 1)});
22631 return DAG.getMergeValues({Extract, TokenFactor}, DL);
22632 }
22633
22634 // Perform TBI simplification if supported by the target and try to break up
22635 // nontemporal loads larger than 256 bits for odd types so LDNPQ 256-bit
22636 // load instructions can be selected.
22637 static SDValue performLOADCombine(SDNode *N,
22638 TargetLowering::DAGCombinerInfo &DCI,
22639 SelectionDAG &DAG,
22640 const AArch64Subtarget *Subtarget) {
22641 if (Subtarget->supportsAddressTopByteIgnored())
22642 performTBISimplification(N->getOperand(1), DCI, DAG);
22643
22644 LoadSDNode *LD = cast<LoadSDNode>(N);
22645 if (LD->isVolatile() || !Subtarget->isLittleEndian())
22646 return SDValue(N, 0);
22647
22648 if (SDValue Res = combineV3I8LoadExt(LD, DAG))
22649 return Res;
22650
22651 if (!LD->isNonTemporal())
22652 return SDValue(N, 0);
22653
22654 EVT MemVT = LD->getMemoryVT();
22655 if (MemVT.isScalableVector() || MemVT.getSizeInBits() <= 256 ||
22656 MemVT.getSizeInBits() % 256 == 0 ||
22657 256 % MemVT.getScalarSizeInBits() != 0)
22658 return SDValue(N, 0);
22659
22660 SDLoc DL(LD);
22661 SDValue Chain = LD->getChain();
22662 SDValue BasePtr = LD->getBasePtr();
22663 SDNodeFlags Flags = LD->getFlags();
22664 SmallVector<SDValue, 4> LoadOps;
22665 SmallVector<SDValue, 4> LoadOpsChain;
22666 // Replace any non-temporal load over 256 bits with a series of 256-bit loads
22667 // and a scalar/vector load of less than 256 bits. This way we can utilize
22668 // 256-bit loads and reduce the number of load instructions generated.
22669 MVT NewVT =
22670 MVT::getVectorVT(MemVT.getVectorElementType().getSimpleVT(),
22671 256 / MemVT.getVectorElementType().getSizeInBits());
22672 unsigned Num256Loads = MemVT.getSizeInBits() / 256;
22673 // Create all 256-bit loads starting from offset 0 up to (Num256Loads - 1) * 32.
22674 for (unsigned I = 0; I < Num256Loads; I++) {
22675 unsigned PtrOffset = I * 32;
22676 SDValue NewPtr = DAG.getMemBasePlusOffset(
22677 BasePtr, TypeSize::getFixed(PtrOffset), DL, Flags);
22678 Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
22679 SDValue NewLoad = DAG.getLoad(
22680 NewVT, DL, Chain, NewPtr, LD->getPointerInfo().getWithOffset(PtrOffset),
22681 NewAlign, LD->getMemOperand()->getFlags(), LD->getAAInfo());
22682 LoadOps.push_back(NewLoad);
22683 LoadOpsChain.push_back(SDValue(cast<SDNode>(NewLoad), 1));
22684 }
22685
22686 // Process remaining bits of the load operation.
22687 // This is done by creating an UNDEF vector to match the size of the
22688 // 256-bit loads and inserting the remaining load into it. We extract the
22689 // original load type at the end using an EXTRACT_SUBVECTOR.
22690 unsigned BitsRemaining = MemVT.getSizeInBits() % 256;
22691 unsigned PtrOffset = (MemVT.getSizeInBits() - BitsRemaining) / 8;
22692 MVT RemainingVT = MVT::getVectorVT(
22693 MemVT.getVectorElementType().getSimpleVT(),
22694 BitsRemaining / MemVT.getVectorElementType().getSizeInBits());
22695 SDValue NewPtr = DAG.getMemBasePlusOffset(
22696 BasePtr, TypeSize::getFixed(PtrOffset), DL, Flags);
22697 Align NewAlign = commonAlignment(LD->getAlign(), PtrOffset);
22698 SDValue RemainingLoad =
22699 DAG.getLoad(RemainingVT, DL, Chain, NewPtr,
22700 LD->getPointerInfo().getWithOffset(PtrOffset), NewAlign,
22701 LD->getMemOperand()->getFlags(), LD->getAAInfo());
22702 SDValue UndefVector = DAG.getUNDEF(NewVT);
22703 SDValue InsertIdx = DAG.getVectorIdxConstant(0, DL);
22704 SDValue ExtendedRemainingLoad =
22705 DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewVT,
22706 {UndefVector, RemainingLoad, InsertIdx});
22707 LoadOps.push_back(ExtendedRemainingLoad);
22708 LoadOpsChain.push_back(SDValue(cast<SDNode>(RemainingLoad), 1));
22709 EVT ConcatVT =
22710 EVT::getVectorVT(*DAG.getContext(), MemVT.getScalarType(),
22711 LoadOps.size() * NewVT.getVectorNumElements());
22712 SDValue ConcatVectors =
22713 DAG.getNode(ISD::CONCAT_VECTORS, DL, ConcatVT, LoadOps);
22714 // Extract the original vector type size.
22715 SDValue ExtractSubVector =
22716 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MemVT,
22717 {ConcatVectors, DAG.getVectorIdxConstant(0, DL)});
22718 SDValue TokenFactor =
22719 DAG.getNode(ISD::TokenFactor, DL, MVT::Other, LoadOpsChain);
22720 return DAG.getMergeValues({ExtractSubVector, TokenFactor}, DL);
22721 }
22722
22723 static EVT tryGetOriginalBoolVectorType(SDValue Op, int Depth = 0) {
22724 EVT VecVT = Op.getValueType();
22725 assert(VecVT.isVector() && VecVT.getVectorElementType() == MVT::i1 &&
22726 "Need boolean vector type.");
22727
22728 if (Depth > 3)
22729 return MVT::INVALID_SIMPLE_VALUE_TYPE;
22730
22731 // We can get the base type from a vector compare or truncate.
22732 if (Op.getOpcode() == ISD::SETCC || Op.getOpcode() == ISD::TRUNCATE)
22733 return Op.getOperand(0).getValueType();
22734
22735 // If an operand is a bool vector, continue looking.
22736 EVT BaseVT = MVT::INVALID_SIMPLE_VALUE_TYPE;
22737 for (SDValue Operand : Op->op_values()) {
22738 if (Operand.getValueType() != VecVT)
22739 continue;
22740
22741 EVT OperandVT = tryGetOriginalBoolVectorType(Operand, Depth + 1);
22742 if (!BaseVT.isSimple())
22743 BaseVT = OperandVT;
22744 else if (OperandVT != BaseVT)
22745 return MVT::INVALID_SIMPLE_VALUE_TYPE;
22746 }
22747
22748 return BaseVT;
22749 }
22750
22751 // When converting a <N x iX> vector to <N x i1> to store or use as a scalar
22752 // iN, we can use a trick that extracts the i^th bit from the i^th element and
22753 // then performs a vector add to get a scalar bitmask. This requires that each
22754 // element's bits are either all 1 or all 0.
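// For example (illustrative), for a v4i32 input whose lanes are all-ones or
// all-zeros:
//   and  result, {1, 2, 4, 8}   ; keep bit i of lane i
//   vecreduce_add               ; produces a 4-bit lane mask in a scalar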
22755 static SDValue vectorToScalarBitmask(SDNode *N, SelectionDAG &DAG) {
22756 SDLoc DL(N);
22757 SDValue ComparisonResult(N, 0);
22758 EVT VecVT = ComparisonResult.getValueType();
22759 assert(VecVT.isVector() && "Must be a vector type");
22760
22761 unsigned NumElts = VecVT.getVectorNumElements();
22762 if (NumElts != 2 && NumElts != 4 && NumElts != 8 && NumElts != 16)
22763 return SDValue();
22764
22765 if (VecVT.getVectorElementType() != MVT::i1 &&
22766 !DAG.getTargetLoweringInfo().isTypeLegal(VecVT))
22767 return SDValue();
22768
22769 // If we can find the original types to work on instead of a vector of i1,
22770 // we can avoid extend/extract conversion instructions.
22771 if (VecVT.getVectorElementType() == MVT::i1) {
22772 VecVT = tryGetOriginalBoolVectorType(ComparisonResult);
22773 if (!VecVT.isSimple()) {
22774 unsigned BitsPerElement = std::max(64 / NumElts, 8u); // >= 64-bit vector
22775 VecVT = MVT::getVectorVT(MVT::getIntegerVT(BitsPerElement), NumElts);
22776 }
22777 }
22778 VecVT = VecVT.changeVectorElementTypeToInteger();
22779
22780 // Large vectors don't map directly to this conversion, so to avoid too many
22781 // edge cases, we don't apply it here. The conversion will likely still be
22782 // applied later via multiple smaller vectors, whose results are concatenated.
22783 if (VecVT.getSizeInBits() > 128)
22784 return SDValue();
22785
22786 // Ensure that all elements' bits are either 0s or 1s.
22787 ComparisonResult = DAG.getSExtOrTrunc(ComparisonResult, DL, VecVT);
22788
22789 SmallVector<SDValue, 16> MaskConstants;
22790 if (DAG.getSubtarget<AArch64Subtarget>().isNeonAvailable() &&
22791 VecVT == MVT::v16i8) {
22792 // v16i8 is a special case, as we have 16 entries but only 8 positional bits
22793 // per entry. We split it into two halves, apply the mask, zip the halves to
22794 // create 8x 16-bit values, and then perform the vector reduce.
22795 for (unsigned Half = 0; Half < 2; ++Half) {
22796 for (unsigned MaskBit = 1; MaskBit <= 128; MaskBit *= 2) {
22797 MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i32));
22798 }
22799 }
22800 SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
22801 SDValue RepresentativeBits =
22802 DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
22803
22804 SDValue UpperRepresentativeBits =
22805 DAG.getNode(AArch64ISD::EXT, DL, VecVT, RepresentativeBits,
22806 RepresentativeBits, DAG.getConstant(8, DL, MVT::i32));
22807 SDValue Zipped = DAG.getNode(AArch64ISD::ZIP1, DL, VecVT,
22808 RepresentativeBits, UpperRepresentativeBits);
22809 Zipped = DAG.getNode(ISD::BITCAST, DL, MVT::v8i16, Zipped);
22810 return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i16, Zipped);
22811 }
22812
22813 // All other vector sizes.
22814 unsigned MaxBitMask = 1u << (VecVT.getVectorNumElements() - 1);
22815 for (unsigned MaskBit = 1; MaskBit <= MaxBitMask; MaskBit *= 2) {
22816 MaskConstants.push_back(DAG.getConstant(MaskBit, DL, MVT::i64));
22817 }
22818
22819 SDValue Mask = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, MaskConstants);
22820 SDValue RepresentativeBits =
22821 DAG.getNode(ISD::AND, DL, VecVT, ComparisonResult, Mask);
22822 EVT ResultVT = MVT::getIntegerVT(std::max<unsigned>(
22823 NumElts, VecVT.getVectorElementType().getSizeInBits()));
22824 return DAG.getNode(ISD::VECREDUCE_ADD, DL, ResultVT, RepresentativeBits);
22825 }
22826
22827 static SDValue combineBoolVectorAndTruncateStore(SelectionDAG &DAG,
22828 StoreSDNode *Store) {
22829 if (!Store->isTruncatingStore())
22830 return SDValue();
22831
22832 SDLoc DL(Store);
22833 SDValue VecOp = Store->getValue();
22834 EVT VT = VecOp.getValueType();
22835 EVT MemVT = Store->getMemoryVT();
22836
22837 if (!MemVT.isVector() || !VT.isVector() ||
22838 MemVT.getVectorElementType() != MVT::i1)
22839 return SDValue();
22840
22841 // If we are storing a vector that we are currently building, let
22842 // `scalarizeVectorStore()` handle this more efficiently.
22843 if (VecOp.getOpcode() == ISD::BUILD_VECTOR)
22844 return SDValue();
22845
22846 VecOp = DAG.getNode(ISD::TRUNCATE, DL, MemVT, VecOp);
22847 SDValue VectorBits = vectorToScalarBitmask(VecOp.getNode(), DAG);
22848 if (!VectorBits)
22849 return SDValue();
22850
22851 EVT StoreVT =
22852 EVT::getIntegerVT(*DAG.getContext(), MemVT.getStoreSizeInBits());
22853 SDValue ExtendedBits = DAG.getZExtOrTrunc(VectorBits, DL, StoreVT);
22854 return DAG.getStore(Store->getChain(), DL, ExtendedBits, Store->getBasePtr(),
22855 Store->getMemOperand());
22856 }
22857
22858 bool isHalvingTruncateOfLegalScalableType(EVT SrcVT, EVT DstVT) {
22859 return (SrcVT == MVT::nxv8i16 && DstVT == MVT::nxv8i8) ||
22860 (SrcVT == MVT::nxv4i32 && DstVT == MVT::nxv4i16) ||
22861 (SrcVT == MVT::nxv2i64 && DstVT == MVT::nxv2i32);
22862 }
22863
22864 // Combine store (trunc X to <3 x i8>) to sequence of ST1.b.
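// The value is widened to 4 lanes, bitcast to bytes, and stored as three
// single-byte stores at offsets #2, #1 and #0 from the base pointer (see
// below).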
22865 static SDValue combineI8TruncStore(StoreSDNode *ST, SelectionDAG &DAG,
22866 const AArch64Subtarget *Subtarget) {
22867 SDValue Value = ST->getValue();
22868 EVT ValueVT = Value.getValueType();
22869
22870 if (ST->isVolatile() || !Subtarget->isLittleEndian() ||
22871 Value.getOpcode() != ISD::TRUNCATE ||
22872 ValueVT != EVT::getVectorVT(*DAG.getContext(), MVT::i8, 3))
22873 return SDValue();
22874
22875 assert(ST->getOffset().isUndef() && "undef offset expected");
22876 SDLoc DL(ST);
22877 auto WideVT = EVT::getVectorVT(
22878 *DAG.getContext(),
22879 Value->getOperand(0).getValueType().getVectorElementType(), 4);
22880 SDValue UndefVector = DAG.getUNDEF(WideVT);
22881 SDValue WideTrunc = DAG.getNode(
22882 ISD::INSERT_SUBVECTOR, DL, WideVT,
22883 {UndefVector, Value->getOperand(0), DAG.getVectorIdxConstant(0, DL)});
22884 SDValue Cast = DAG.getNode(
22885 ISD::BITCAST, DL, WideVT.getSizeInBits() == 64 ? MVT::v8i8 : MVT::v16i8,
22886 WideTrunc);
22887
22888 MachineFunction &MF = DAG.getMachineFunction();
22889 SDValue Chain = ST->getChain();
22890 MachineMemOperand *MMO = ST->getMemOperand();
22891 unsigned IdxScale = WideVT.getScalarSizeInBits() / 8;
22892 SDValue E2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22893 DAG.getConstant(2 * IdxScale, DL, MVT::i64));
22894 TypeSize Offset2 = TypeSize::getFixed(2);
22895 SDValue Ptr2 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset2, DL);
22896 Chain = DAG.getStore(Chain, DL, E2, Ptr2, MF.getMachineMemOperand(MMO, 2, 1));
22897
22898 SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22899 DAG.getConstant(1 * IdxScale, DL, MVT::i64));
22900 TypeSize Offset1 = TypeSize::getFixed(1);
22901 SDValue Ptr1 = DAG.getMemBasePlusOffset(ST->getBasePtr(), Offset1, DL);
22902 Chain = DAG.getStore(Chain, DL, E1, Ptr1, MF.getMachineMemOperand(MMO, 1, 1));
22903
22904 SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8, Cast,
22905 DAG.getConstant(0, DL, MVT::i64));
22906 Chain = DAG.getStore(Chain, DL, E0, ST->getBasePtr(),
22907 MF.getMachineMemOperand(MMO, 0, 1));
22908 return Chain;
22909 }
22910
22911 static SDValue performSTORECombine(SDNode *N,
22912 TargetLowering::DAGCombinerInfo &DCI,
22913 SelectionDAG &DAG,
22914 const AArch64Subtarget *Subtarget) {
22915 StoreSDNode *ST = cast<StoreSDNode>(N);
22916 SDValue Chain = ST->getChain();
22917 SDValue Value = ST->getValue();
22918 SDValue Ptr = ST->getBasePtr();
22919 EVT ValueVT = Value.getValueType();
22920
22921 auto hasValidElementTypeForFPTruncStore = [](EVT VT) {
22922 EVT EltVT = VT.getVectorElementType();
22923 return EltVT == MVT::f32 || EltVT == MVT::f64;
22924 };
22925
22926 if (SDValue Res = combineI8TruncStore(ST, DAG, Subtarget))
22927 return Res;
22928
22929 // If this is an FP_ROUND followed by a store, fold this into a truncating
22930 // store. We can do this even if this is already a truncstore.
22931 // We purposefully don't care about legality of the nodes here as we know
22932 // they can be split down into something legal.
22933 if (DCI.isBeforeLegalizeOps() && Value.getOpcode() == ISD::FP_ROUND &&
22934 Value.getNode()->hasOneUse() && ST->isUnindexed() &&
22935 Subtarget->useSVEForFixedLengthVectors() &&
22936 ValueVT.isFixedLengthVector() &&
22937 ValueVT.getFixedSizeInBits() >= Subtarget->getMinSVEVectorSizeInBits() &&
22938 hasValidElementTypeForFPTruncStore(Value.getOperand(0).getValueType()))
22939 return DAG.getTruncStore(Chain, SDLoc(N), Value.getOperand(0), Ptr,
22940 ST->getMemoryVT(), ST->getMemOperand());
22941
22942 if (SDValue Split = splitStores(N, DCI, DAG, Subtarget))
22943 return Split;
22944
22945 if (Subtarget->supportsAddressTopByteIgnored() &&
22946 performTBISimplification(N->getOperand(2), DCI, DAG))
22947 return SDValue(N, 0);
22948
22949 if (SDValue Store = foldTruncStoreOfExt(DAG, N))
22950 return Store;
22951
22952 if (SDValue Store = combineBoolVectorAndTruncateStore(DAG, ST))
22953 return Store;
22954
22955 if (ST->isTruncatingStore()) {
22956 EVT StoreVT = ST->getMemoryVT();
22957 if (!isHalvingTruncateOfLegalScalableType(ValueVT, StoreVT))
22958 return SDValue();
22959 if (SDValue Rshrnb =
22960 trySimplifySrlAddToRshrnb(ST->getOperand(1), DAG, Subtarget)) {
22961 return DAG.getTruncStore(ST->getChain(), ST, Rshrnb, ST->getBasePtr(),
22962 StoreVT, ST->getMemOperand());
22963 }
22964 }
22965
22966 return SDValue();
22967 }
22968
22969 static SDValue performMSTORECombine(SDNode *N,
22970 TargetLowering::DAGCombinerInfo &DCI,
22971 SelectionDAG &DAG,
22972 const AArch64Subtarget *Subtarget) {
22973 MaskedStoreSDNode *MST = cast<MaskedStoreSDNode>(N);
22974 SDValue Value = MST->getValue();
22975 SDValue Mask = MST->getMask();
22976 SDLoc DL(N);
22977
22978 // If this is a UZP1 followed by a masked store, fold this into a masked
22979 // truncating store. We can do this even if this is already a masked
22980 // truncstore.
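// Rough sketch (illustrative): when the ptrue mask only covers elements that
// come from the first uzp1 operand, masked_store(uzp1(bitcast(X), Y)) can
// become a truncating masked_store of X with a widened predicate.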
22981 if (Value.getOpcode() == AArch64ISD::UZP1 && Value->hasOneUse() &&
22982 MST->isUnindexed() && Mask->getOpcode() == AArch64ISD::PTRUE &&
22983 Value.getValueType().isInteger()) {
22984 Value = Value.getOperand(0);
22985 if (Value.getOpcode() == ISD::BITCAST) {
22986 EVT HalfVT =
22987 Value.getValueType().getHalfNumVectorElementsVT(*DAG.getContext());
22988 EVT InVT = Value.getOperand(0).getValueType();
22989
22990 if (HalfVT.widenIntegerVectorElementType(*DAG.getContext()) == InVT) {
22991 unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
22992 unsigned PgPattern = Mask->getConstantOperandVal(0);
22993
22994 // Ensure we can double the size of the predicate pattern
22995 unsigned NumElts = getNumElementsFromSVEPredPattern(PgPattern);
22996 if (NumElts && NumElts * InVT.getVectorElementType().getSizeInBits() <=
22997 MinSVESize) {
22998 Mask = getPTrue(DAG, DL, InVT.changeVectorElementType(MVT::i1),
22999 PgPattern);
23000 return DAG.getMaskedStore(MST->getChain(), DL, Value.getOperand(0),
23001 MST->getBasePtr(), MST->getOffset(), Mask,
23002 MST->getMemoryVT(), MST->getMemOperand(),
23003 MST->getAddressingMode(),
23004 /*IsTruncating=*/true);
23005 }
23006 }
23007 }
23008 }
23009
23010 if (MST->isTruncatingStore()) {
23011 EVT ValueVT = Value->getValueType(0);
23012 EVT MemVT = MST->getMemoryVT();
23013 if (!isHalvingTruncateOfLegalScalableType(ValueVT, MemVT))
23014 return SDValue();
23015 if (SDValue Rshrnb = trySimplifySrlAddToRshrnb(Value, DAG, Subtarget)) {
23016 return DAG.getMaskedStore(MST->getChain(), DL, Rshrnb, MST->getBasePtr(),
23017 MST->getOffset(), MST->getMask(),
23018 MST->getMemoryVT(), MST->getMemOperand(),
23019 MST->getAddressingMode(), true);
23020 }
23021 }
23022
23023 return SDValue();
23024 }
23025
23026 /// \return true if part of the index was folded into the Base.
23027 static bool foldIndexIntoBase(SDValue &BasePtr, SDValue &Index, SDValue Scale,
23028 SDLoc DL, SelectionDAG &DAG) {
23029 // This function assumes a vector of i64 indices.
23030 EVT IndexVT = Index.getValueType();
23031 if (!IndexVT.isVector() || IndexVT.getVectorElementType() != MVT::i64)
23032 return false;
23033
23034 // Simplify:
23035 // BasePtr = Ptr
23036 // Index = X + splat(Offset)
23037 // ->
23038 // BasePtr = Ptr + Offset * scale.
23039 // Index = X
23040 if (Index.getOpcode() == ISD::ADD) {
23041 if (auto Offset = DAG.getSplatValue(Index.getOperand(1))) {
23042 Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
23043 BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
23044 Index = Index.getOperand(0);
23045 return true;
23046 }
23047 }
23048
23049 // Simplify:
23050 // BasePtr = Ptr
23051 // Index = (X + splat(Offset)) << splat(Shift)
23052 // ->
23053 // BasePtr = Ptr + (Offset << Shift) * scale)
23054 // Index = X << splat(shift)
23055 if (Index.getOpcode() == ISD::SHL &&
23056 Index.getOperand(0).getOpcode() == ISD::ADD) {
23057 SDValue Add = Index.getOperand(0);
23058 SDValue ShiftOp = Index.getOperand(1);
23059 SDValue OffsetOp = Add.getOperand(1);
23060 if (auto Shift = DAG.getSplatValue(ShiftOp))
23061 if (auto Offset = DAG.getSplatValue(OffsetOp)) {
23062 Offset = DAG.getNode(ISD::SHL, DL, MVT::i64, Offset, Shift);
23063 Offset = DAG.getNode(ISD::MUL, DL, MVT::i64, Offset, Scale);
23064 BasePtr = DAG.getNode(ISD::ADD, DL, MVT::i64, BasePtr, Offset);
23065 Index = DAG.getNode(ISD::SHL, DL, Index.getValueType(),
23066 Add.getOperand(0), ShiftOp);
23067 return true;
23068 }
23069 }
23070
23071 return false;
23072 }
23073
23074 // Analyse the specified address returning true if a more optimal addressing
23075 // mode is available. When returning true all parameters are updated to reflect
23076 // their recommended values.
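// For example (illustrative), Index = X + splat(C) is rewritten so the
// constant part is folded into BasePtr and the index becomes X, and a 64-bit
// stepvector index with a small stride may be shrunk to 32-bit elements.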
23077 static bool findMoreOptimalIndexType(const MaskedGatherScatterSDNode *N,
23078 SDValue &BasePtr, SDValue &Index,
23079 SelectionDAG &DAG) {
23080 // Try to iteratively fold parts of the index into the base pointer to
23081 // simplify the index as much as possible.
23082 bool Changed = false;
23083 while (foldIndexIntoBase(BasePtr, Index, N->getScale(), SDLoc(N), DAG))
23084 Changed = true;
23085
23086 // Only consider element types that are pointer sized as smaller types can
23087 // be easily promoted.
23088 EVT IndexVT = Index.getValueType();
23089 if (IndexVT.getVectorElementType() != MVT::i64 || IndexVT == MVT::nxv2i64)
23090 return Changed;
23091
23092 // Can indices be trivially shrunk?
23093 EVT DataVT = N->getOperand(1).getValueType();
23094 // Don't attempt to shrink the index for fixed vectors of 64 bit data since it
23095 // will later be re-extended to 64 bits in legalization
23096 if (DataVT.isFixedLengthVector() && DataVT.getScalarSizeInBits() == 64)
23097 return Changed;
23098 if (ISD::isVectorShrinkable(Index.getNode(), 32, N->isIndexSigned())) {
23099 EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
23100 Index = DAG.getNode(ISD::TRUNCATE, SDLoc(N), NewIndexVT, Index);
23101 return true;
23102 }
23103
23104 // Match:
23105 // Index = step(const)
23106 int64_t Stride = 0;
23107 if (Index.getOpcode() == ISD::STEP_VECTOR) {
23108 Stride = cast<ConstantSDNode>(Index.getOperand(0))->getSExtValue();
23109 }
23110 // Match:
23111 // Index = step(const) << shift(const)
23112 else if (Index.getOpcode() == ISD::SHL &&
23113 Index.getOperand(0).getOpcode() == ISD::STEP_VECTOR) {
23114 SDValue RHS = Index.getOperand(1);
23115 if (auto *Shift =
23116 dyn_cast_or_null<ConstantSDNode>(DAG.getSplatValue(RHS))) {
23117 int64_t Step = (int64_t)Index.getOperand(0).getConstantOperandVal(1);
23118 Stride = Step << Shift->getZExtValue();
23119 }
23120 }
23121
23122 // Return early because no supported pattern is found.
23123 if (Stride == 0)
23124 return Changed;
23125
23126 if (Stride < std::numeric_limits<int32_t>::min() ||
23127 Stride > std::numeric_limits<int32_t>::max())
23128 return Changed;
23129
23130 const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
23131 unsigned MaxVScale =
23132 Subtarget.getMaxSVEVectorSizeInBits() / AArch64::SVEBitsPerBlock;
23133 int64_t LastElementOffset =
23134 IndexVT.getVectorMinNumElements() * Stride * MaxVScale;
23135
23136 if (LastElementOffset < std::numeric_limits<int32_t>::min() ||
23137 LastElementOffset > std::numeric_limits<int32_t>::max())
23138 return Changed;
23139
23140 EVT NewIndexVT = IndexVT.changeVectorElementType(MVT::i32);
23141 // Stride is not scaled explicitly by 'Scale' here, because that scaling
23142 // happens in the gather/scatter addressing mode.
23143 Index = DAG.getStepVector(SDLoc(N), NewIndexVT, APInt(32, Stride));
23144 return true;
23145 }
23146
23147 static SDValue performMaskedGatherScatterCombine(
23148 SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
23149 MaskedGatherScatterSDNode *MGS = cast<MaskedGatherScatterSDNode>(N);
23150 assert(MGS && "Can only combine gather load or scatter store nodes");
23151
23152 if (!DCI.isBeforeLegalize())
23153 return SDValue();
23154
23155 SDLoc DL(MGS);
23156 SDValue Chain = MGS->getChain();
23157 SDValue Scale = MGS->getScale();
23158 SDValue Index = MGS->getIndex();
23159 SDValue Mask = MGS->getMask();
23160 SDValue BasePtr = MGS->getBasePtr();
23161 ISD::MemIndexType IndexType = MGS->getIndexType();
23162
23163 if (!findMoreOptimalIndexType(MGS, BasePtr, Index, DAG))
23164 return SDValue();
23165
23166 // A more optimal index type was found, so rebuild the gather/scatter with
23167 // the updated BasePtr and Index, which are more legalisation friendly.
23168 if (auto *MGT = dyn_cast<MaskedGatherSDNode>(MGS)) {
23169 SDValue PassThru = MGT->getPassThru();
23170 SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
23171 return DAG.getMaskedGather(
23172 DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
23173 Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
23174 }
23175 auto *MSC = cast<MaskedScatterSDNode>(MGS);
23176 SDValue Data = MSC->getValue();
23177 SDValue Ops[] = {Chain, Data, Mask, BasePtr, Index, Scale};
23178 return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL,
23179 Ops, MSC->getMemOperand(), IndexType,
23180 MSC->isTruncatingStore());
23181 }
23182
23183 /// Target-specific DAG combine function for NEON load/store intrinsics
23184 /// to merge base address updates.
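/// For example (illustrative), an aarch64.neon.ld2 whose address is also
/// incremented by the size of the data loaded can be replaced by a single
/// LD2post node that additionally returns the updated base register.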
23185 static SDValue performNEONPostLDSTCombine(SDNode *N,
23186 TargetLowering::DAGCombinerInfo &DCI,
23187 SelectionDAG &DAG) {
23188 if (DCI.isBeforeLegalize() || DCI.isCalledByLegalizer())
23189 return SDValue();
23190
23191 unsigned AddrOpIdx = N->getNumOperands() - 1;
23192 SDValue Addr = N->getOperand(AddrOpIdx);
23193
23194 // Search for a use of the address operand that is an increment.
23195 for (SDNode::use_iterator UI = Addr.getNode()->use_begin(),
23196 UE = Addr.getNode()->use_end(); UI != UE; ++UI) {
23197 SDNode *User = *UI;
23198 if (User->getOpcode() != ISD::ADD ||
23199 UI.getUse().getResNo() != Addr.getResNo())
23200 continue;
23201
23202 // Check that the add is independent of the load/store. Otherwise, folding
23203 // it would create a cycle.
23204 SmallPtrSet<const SDNode *, 32> Visited;
23205 SmallVector<const SDNode *, 16> Worklist;
23206 Visited.insert(Addr.getNode());
23207 Worklist.push_back(N);
23208 Worklist.push_back(User);
23209 if (SDNode::hasPredecessorHelper(N, Visited, Worklist) ||
23210 SDNode::hasPredecessorHelper(User, Visited, Worklist))
23211 continue;
23212
23213 // Find the new opcode for the updating load/store.
23214 bool IsStore = false;
23215 bool IsLaneOp = false;
23216 bool IsDupOp = false;
23217 unsigned NewOpc = 0;
23218 unsigned NumVecs = 0;
23219 unsigned IntNo = N->getConstantOperandVal(1);
23220 switch (IntNo) {
23221 default: llvm_unreachable("unexpected intrinsic for Neon base update");
23222 case Intrinsic::aarch64_neon_ld2: NewOpc = AArch64ISD::LD2post;
23223 NumVecs = 2; break;
23224 case Intrinsic::aarch64_neon_ld3: NewOpc = AArch64ISD::LD3post;
23225 NumVecs = 3; break;
23226 case Intrinsic::aarch64_neon_ld4: NewOpc = AArch64ISD::LD4post;
23227 NumVecs = 4; break;
23228 case Intrinsic::aarch64_neon_st2: NewOpc = AArch64ISD::ST2post;
23229 NumVecs = 2; IsStore = true; break;
23230 case Intrinsic::aarch64_neon_st3: NewOpc = AArch64ISD::ST3post;
23231 NumVecs = 3; IsStore = true; break;
23232 case Intrinsic::aarch64_neon_st4: NewOpc = AArch64ISD::ST4post;
23233 NumVecs = 4; IsStore = true; break;
23234 case Intrinsic::aarch64_neon_ld1x2: NewOpc = AArch64ISD::LD1x2post;
23235 NumVecs = 2; break;
23236 case Intrinsic::aarch64_neon_ld1x3: NewOpc = AArch64ISD::LD1x3post;
23237 NumVecs = 3; break;
23238 case Intrinsic::aarch64_neon_ld1x4: NewOpc = AArch64ISD::LD1x4post;
23239 NumVecs = 4; break;
23240 case Intrinsic::aarch64_neon_st1x2: NewOpc = AArch64ISD::ST1x2post;
23241 NumVecs = 2; IsStore = true; break;
23242 case Intrinsic::aarch64_neon_st1x3: NewOpc = AArch64ISD::ST1x3post;
23243 NumVecs = 3; IsStore = true; break;
23244 case Intrinsic::aarch64_neon_st1x4: NewOpc = AArch64ISD::ST1x4post;
23245 NumVecs = 4; IsStore = true; break;
23246 case Intrinsic::aarch64_neon_ld2r: NewOpc = AArch64ISD::LD2DUPpost;
23247 NumVecs = 2; IsDupOp = true; break;
23248 case Intrinsic::aarch64_neon_ld3r: NewOpc = AArch64ISD::LD3DUPpost;
23249 NumVecs = 3; IsDupOp = true; break;
23250 case Intrinsic::aarch64_neon_ld4r: NewOpc = AArch64ISD::LD4DUPpost;
23251 NumVecs = 4; IsDupOp = true; break;
23252 case Intrinsic::aarch64_neon_ld2lane: NewOpc = AArch64ISD::LD2LANEpost;
23253 NumVecs = 2; IsLaneOp = true; break;
23254 case Intrinsic::aarch64_neon_ld3lane: NewOpc = AArch64ISD::LD3LANEpost;
23255 NumVecs = 3; IsLaneOp = true; break;
23256 case Intrinsic::aarch64_neon_ld4lane: NewOpc = AArch64ISD::LD4LANEpost;
23257 NumVecs = 4; IsLaneOp = true; break;
23258 case Intrinsic::aarch64_neon_st2lane: NewOpc = AArch64ISD::ST2LANEpost;
23259 NumVecs = 2; IsStore = true; IsLaneOp = true; break;
23260 case Intrinsic::aarch64_neon_st3lane: NewOpc = AArch64ISD::ST3LANEpost;
23261 NumVecs = 3; IsStore = true; IsLaneOp = true; break;
23262 case Intrinsic::aarch64_neon_st4lane: NewOpc = AArch64ISD::ST4LANEpost;
23263 NumVecs = 4; IsStore = true; IsLaneOp = true; break;
23264 }
23265
23266 EVT VecTy;
23267 if (IsStore)
23268 VecTy = N->getOperand(2).getValueType();
23269 else
23270 VecTy = N->getValueType(0);
23271
23272 // If the increment is a constant, it must match the memory ref size.
23273 SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
23274 if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
23275 uint32_t IncVal = CInc->getZExtValue();
23276 unsigned NumBytes = NumVecs * VecTy.getSizeInBits() / 8;
23277 if (IsLaneOp || IsDupOp)
23278 NumBytes /= VecTy.getVectorNumElements();
23279 if (IncVal != NumBytes)
23280 continue;
23281 Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
23282 }
23283 SmallVector<SDValue, 8> Ops;
23284 Ops.push_back(N->getOperand(0)); // Incoming chain
23285 // Load lane and store have vector list as input.
23286 if (IsLaneOp || IsStore)
23287 for (unsigned i = 2; i < AddrOpIdx; ++i)
23288 Ops.push_back(N->getOperand(i));
23289 Ops.push_back(Addr); // Base register
23290 Ops.push_back(Inc);
23291
23292 // Return Types.
23293 EVT Tys[6];
23294 unsigned NumResultVecs = (IsStore ? 0 : NumVecs);
23295 unsigned n;
23296 for (n = 0; n < NumResultVecs; ++n)
23297 Tys[n] = VecTy;
23298 Tys[n++] = MVT::i64; // Type of write back register
23299 Tys[n] = MVT::Other; // Type of the chain
23300 SDVTList SDTys = DAG.getVTList(ArrayRef(Tys, NumResultVecs + 2));
23301
23302 MemIntrinsicSDNode *MemInt = cast<MemIntrinsicSDNode>(N);
23303 SDValue UpdN = DAG.getMemIntrinsicNode(NewOpc, SDLoc(N), SDTys, Ops,
23304 MemInt->getMemoryVT(),
23305 MemInt->getMemOperand());
23306
23307 // Update the uses.
23308 std::vector<SDValue> NewResults;
23309 for (unsigned i = 0; i < NumResultVecs; ++i) {
23310 NewResults.push_back(SDValue(UpdN.getNode(), i));
23311 }
23312 NewResults.push_back(SDValue(UpdN.getNode(), NumResultVecs + 1));
23313 DCI.CombineTo(N, NewResults);
23314 DCI.CombineTo(User, SDValue(UpdN.getNode(), NumResultVecs));
23315
23316 break;
23317 }
23318 return SDValue();
23319 }
23320
23321 // Checks to see if the value is the prescribed width and returns information
23322 // about its extension mode.
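// For example (illustrative), an i32 value produced by an extending i8 load
// reports true for width == 8, with ExtType set to the load's extension kind.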
23323 static
23324 bool checkValueWidth(SDValue V, unsigned width, ISD::LoadExtType &ExtType) {
23325 ExtType = ISD::NON_EXTLOAD;
23326 switch(V.getNode()->getOpcode()) {
23327 default:
23328 return false;
23329 case ISD::LOAD: {
23330 LoadSDNode *LoadNode = cast<LoadSDNode>(V.getNode());
23331 if ((LoadNode->getMemoryVT() == MVT::i8 && width == 8)
23332 || (LoadNode->getMemoryVT() == MVT::i16 && width == 16)) {
23333 ExtType = LoadNode->getExtensionType();
23334 return true;
23335 }
23336 return false;
23337 }
23338 case ISD::AssertSext: {
23339 VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
23340 if ((TypeNode->getVT() == MVT::i8 && width == 8)
23341 || (TypeNode->getVT() == MVT::i16 && width == 16)) {
23342 ExtType = ISD::SEXTLOAD;
23343 return true;
23344 }
23345 return false;
23346 }
23347 case ISD::AssertZext: {
23348 VTSDNode *TypeNode = cast<VTSDNode>(V.getNode()->getOperand(1));
23349 if ((TypeNode->getVT() == MVT::i8 && width == 8)
23350 || (TypeNode->getVT() == MVT::i16 && width == 16)) {
23351 ExtType = ISD::ZEXTLOAD;
23352 return true;
23353 }
23354 return false;
23355 }
23356 case ISD::Constant:
23357 case ISD::TargetConstant: {
23358 return std::abs(cast<ConstantSDNode>(V.getNode())->getSExtValue()) <
23359 1LL << (width - 1);
23360 }
23361 }
23362
23363 return true;
23364 }
23365
23366 // This function does a whole lot of voodoo to determine if the tests are
23367 // equivalent without and with a mask. Essentially what happens is that given a
23368 // DAG resembling:
23369 //
23370 // +-------------+ +-------------+ +-------------+ +-------------+
23371 // | Input | | AddConstant | | CompConstant| | CC |
23372 // +-------------+ +-------------+ +-------------+ +-------------+
23373 // | | | |
23374 // V V | +----------+
23375 // +-------------+ +----+ | |
23376 // | ADD | |0xff| | |
23377 // +-------------+ +----+ | |
23378 // | | | |
23379 // V V | |
23380 // +-------------+ | |
23381 // | AND | | |
23382 // +-------------+ | |
23383 // | | |
23384 // +-----+ | |
23385 // | | |
23386 // V V V
23387 // +-------------+
23388 // | CMP |
23389 // +-------------+
23390 //
23391 // The AND node may be safely removed for some combinations of inputs. In
23392 // particular we need to take into account the extension type of the Input,
23393 // the exact values of AddConstant, CompConstant, and CC, along with the nominal
23394 // width of the input (this can work for any width inputs, the above graph is
23395 // specific to 8 bits).
23396 //
23397 // The specific equations were worked out by generating output tables for each
23398 // AArch64CC value in terms of the AddConstant (w1) and CompConstant (w2). The
23399 // problem was simplified by working with 4 bit inputs, which means we only
23400 // needed to reason about 24 distinct bit patterns: 8 patterns unique to zero
23401 // extension (8,15), 8 patterns unique to sign extensions (-8,-1), and 8
23402 // patterns present in both extensions (0,7). For every distinct set of
23403 // AddConstant and CompConstants bit patterns we can consider the masked and
23404 // unmasked versions to be equivalent if the result of this function is true for
23405 // all 16 distinct bit patterns for the current extension type of Input (w0).
23406 //
23407 // sub w8, w0, w1
23408 // and w10, w8, #0x0f
23409 // cmp w8, w2
23410 // cset w9, AArch64CC
23411 // cmp w10, w2
23412 // cset w11, AArch64CC
23413 // cmp w9, w11
23414 // cset w0, eq
23415 // ret
23416 //
23417 // Since the above function shows when the outputs are equivalent it defines
23418 // when it is safe to remove the AND. Unfortunately it only runs on AArch64 and
23419 // would be expensive to run during compiles. The equations below were written
23420 // in a test harness that confirmed they gave outputs equivalent to the above
23421 // function for all inputs, so they can be used to determine whether the
23422 // removal is legal instead.
23423 //
23424 // isEquivalentMaskless() is the code for testing if the AND can be removed,
23425 // factored out of the DAG recognition because the DAG can take several forms.
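// A simple sanity check (illustrative, not from the original comment): for a
// zero-extended 8-bit input, adding AddConstant == 0 leaves the value in
// [0, 255], so ANDing with 0xff changes nothing and the masked and unmasked
// compares are trivially equivalent, consistent with the AddConstant == 0
// clauses below.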
23426
23427 static bool isEquivalentMaskless(unsigned CC, unsigned width,
23428 ISD::LoadExtType ExtType, int AddConstant,
23429 int CompConstant) {
23430 // By being careful about our equations and only writing them in terms of
23431 // symbolic values and well-known constants (0, 1, -1, MaxUInt) we can
23432 // make them generally applicable to all bit widths.
23433 int MaxUInt = (1 << width);
23434
23435 // For the purposes of these comparisons sign extending the type is
23436 // equivalent to zero extending the add and displacing it by half the integer
23437 // width. Provided we are careful and make sure our equations are valid over
23438 // the whole range we can just adjust the input and avoid writing equations
23439 // for sign extended inputs.
23440 if (ExtType == ISD::SEXTLOAD)
23441 AddConstant -= (1 << (width-1));
23442
23443 switch(CC) {
23444 case AArch64CC::LE:
23445 case AArch64CC::GT:
23446 if ((AddConstant == 0) ||
23447 (CompConstant == MaxUInt - 1 && AddConstant < 0) ||
23448 (AddConstant >= 0 && CompConstant < 0) ||
23449 (AddConstant <= 0 && CompConstant <= 0 && CompConstant < AddConstant))
23450 return true;
23451 break;
23452 case AArch64CC::LT:
23453 case AArch64CC::GE:
23454 if ((AddConstant == 0) ||
23455 (AddConstant >= 0 && CompConstant <= 0) ||
23456 (AddConstant <= 0 && CompConstant <= 0 && CompConstant <= AddConstant))
23457 return true;
23458 break;
23459 case AArch64CC::HI:
23460 case AArch64CC::LS:
23461 if ((AddConstant >= 0 && CompConstant < 0) ||
23462 (AddConstant <= 0 && CompConstant >= -1 &&
23463 CompConstant < AddConstant + MaxUInt))
23464 return true;
23465 break;
23466 case AArch64CC::PL:
23467 case AArch64CC::MI:
23468 if ((AddConstant == 0) ||
23469 (AddConstant > 0 && CompConstant <= 0) ||
23470 (AddConstant < 0 && CompConstant <= AddConstant))
23471 return true;
23472 break;
23473 case AArch64CC::LO:
23474 case AArch64CC::HS:
23475 if ((AddConstant >= 0 && CompConstant <= 0) ||
23476 (AddConstant <= 0 && CompConstant >= 0 &&
23477 CompConstant <= AddConstant + MaxUInt))
23478 return true;
23479 break;
23480 case AArch64CC::EQ:
23481 case AArch64CC::NE:
23482 if ((AddConstant > 0 && CompConstant < 0) ||
23483 (AddConstant < 0 && CompConstant >= 0 &&
23484 CompConstant < AddConstant + MaxUInt) ||
23485 (AddConstant >= 0 && CompConstant >= 0 &&
23486 CompConstant >= AddConstant) ||
23487 (AddConstant <= 0 && CompConstant < 0 && CompConstant < AddConstant))
23488 return true;
23489 break;
23490 case AArch64CC::VS:
23491 case AArch64CC::VC:
23492 case AArch64CC::AL:
23493 case AArch64CC::NV:
23494 return true;
23495 case AArch64CC::Invalid:
23496 break;
23497 }
23498
23499 return false;
23500 }
23501
23502 // (X & C) >u Mask --> (X & (C & ~Mask)) != 0
23503 // (X & C) <u Pow2 --> (X & (C & ~(Pow2-1))) == 0
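// Worked example (illustrative values): with C == 0xFF and Mask == 0x0F,
// "(X & 0xFF) >u 0x0F" holds exactly when some bit of X in 0xF0 is set,
// i.e. "(X & 0xF0) != 0"; likewise "(X & 0xFF) <u 16" is "(X & 0xF0) == 0".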
23504 static SDValue performSubsToAndsCombine(SDNode *N, SDNode *SubsNode,
23505 SDNode *AndNode, SelectionDAG &DAG,
23506 unsigned CCIndex, unsigned CmpIndex,
23507 unsigned CC) {
23508 ConstantSDNode *SubsC = dyn_cast<ConstantSDNode>(SubsNode->getOperand(1));
23509 if (!SubsC)
23510 return SDValue();
23511
23512 APInt SubsAP = SubsC->getAPIntValue();
23513 if (CC == AArch64CC::HI) {
23514 if (!SubsAP.isMask())
23515 return SDValue();
23516 } else if (CC == AArch64CC::LO) {
23517 if (!SubsAP.isPowerOf2())
23518 return SDValue();
23519 } else
23520 return SDValue();
23521
23522 ConstantSDNode *AndC = dyn_cast<ConstantSDNode>(AndNode->getOperand(1));
23523 if (!AndC)
23524 return SDValue();
23525
23526 APInt MaskAP = CC == AArch64CC::HI ? SubsAP : (SubsAP - 1);
23527
23528 SDLoc DL(N);
23529 APInt AndSMask = (~MaskAP) & AndC->getAPIntValue();
23530 SDValue ANDS = DAG.getNode(
23531 AArch64ISD::ANDS, DL, SubsNode->getVTList(), AndNode->getOperand(0),
23532 DAG.getConstant(AndSMask, DL, SubsC->getValueType(0)));
23533 SDValue AArch64_CC =
23534 DAG.getConstant(CC == AArch64CC::HI ? AArch64CC::NE : AArch64CC::EQ, DL,
23535 N->getOperand(CCIndex)->getValueType(0));
23536
23537 // For now, only performCSELCombine and performBRCONDCombine call this
23538 // function, and both of them pass 2 for CCIndex and 3 for CmpIndex with 4
23539 // operands, so just initialize the operands directly to keep the code simple.
23540 // If a caller with a different CCIndex or CmpIndex ever appears, this will
23541 // need to be rewritten with a loop.
23542 // TODO: Do we need to assert that the number of operands is 4 here?
23543 assert((CCIndex == 2 && CmpIndex == 3) &&
23544 "Expected CCIndex to be 2 and CmpIndex to be 3.");
23545 SDValue Ops[] = {N->getOperand(0), N->getOperand(1), AArch64_CC,
23546 ANDS.getValue(1)};
23547 return DAG.getNode(N->getOpcode(), N, N->getVTList(), Ops);
23548 }
23549
23550 static
23551 SDValue performCONDCombine(SDNode *N,
23552 TargetLowering::DAGCombinerInfo &DCI,
23553 SelectionDAG &DAG, unsigned CCIndex,
23554 unsigned CmpIndex) {
23555 unsigned CC = cast<ConstantSDNode>(N->getOperand(CCIndex))->getSExtValue();
23556 SDNode *SubsNode = N->getOperand(CmpIndex).getNode();
23557 unsigned CondOpcode = SubsNode->getOpcode();
23558
23559 if (CondOpcode != AArch64ISD::SUBS || SubsNode->hasAnyUseOfValue(0) ||
23560 !SubsNode->hasOneUse())
23561 return SDValue();
23562
23563 // There is a SUBS feeding this condition. Is it fed by a mask we can
23564 // use?
23565
23566 SDNode *AndNode = SubsNode->getOperand(0).getNode();
23567 unsigned MaskBits = 0;
23568
23569 if (AndNode->getOpcode() != ISD::AND)
23570 return SDValue();
23571
23572 if (SDValue Val = performSubsToAndsCombine(N, SubsNode, AndNode, DAG, CCIndex,
23573 CmpIndex, CC))
23574 return Val;
23575
23576 if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(AndNode->getOperand(1))) {
23577 uint32_t CNV = CN->getZExtValue();
23578 if (CNV == 255)
23579 MaskBits = 8;
23580 else if (CNV == 65535)
23581 MaskBits = 16;
23582 }
23583
23584 if (!MaskBits)
23585 return SDValue();
23586
23587 SDValue AddValue = AndNode->getOperand(0);
23588
23589 if (AddValue.getOpcode() != ISD::ADD)
23590 return SDValue();
23591
23592 // The basic dag structure is correct, grab the inputs and validate them.
23593
23594 SDValue AddInputValue1 = AddValue.getNode()->getOperand(0);
23595 SDValue AddInputValue2 = AddValue.getNode()->getOperand(1);
23596 SDValue SubsInputValue = SubsNode->getOperand(1);
23597
23598 // The mask is present and the provenance of all the values is a smaller type;
23599 // let's see if the mask is superfluous.
23600
23601 if (!isa<ConstantSDNode>(AddInputValue2.getNode()) ||
23602 !isa<ConstantSDNode>(SubsInputValue.getNode()))
23603 return SDValue();
23604
23605 ISD::LoadExtType ExtType;
23606
23607 if (!checkValueWidth(SubsInputValue, MaskBits, ExtType) ||
23608 !checkValueWidth(AddInputValue2, MaskBits, ExtType) ||
23609 !checkValueWidth(AddInputValue1, MaskBits, ExtType) )
23610 return SDValue();
23611
23612 if(!isEquivalentMaskless(CC, MaskBits, ExtType,
23613 cast<ConstantSDNode>(AddInputValue2.getNode())->getSExtValue(),
23614 cast<ConstantSDNode>(SubsInputValue.getNode())->getSExtValue()))
23615 return SDValue();
23616
23617 // The AND is not necessary, remove it.
23618
23619 SDVTList VTs = DAG.getVTList(SubsNode->getValueType(0),
23620 SubsNode->getValueType(1));
23621 SDValue Ops[] = { AddValue, SubsNode->getOperand(1) };
23622
23623 SDValue NewValue = DAG.getNode(CondOpcode, SDLoc(SubsNode), VTs, Ops);
23624 DAG.ReplaceAllUsesWith(SubsNode, NewValue.getNode());
23625
23626 return SDValue(N, 0);
23627 }
23628
23629 // Optimize compare with zero and branch.
23630 static SDValue performBRCONDCombine(SDNode *N,
23631 TargetLowering::DAGCombinerInfo &DCI,
23632 SelectionDAG &DAG) {
23633 MachineFunction &MF = DAG.getMachineFunction();
23634 // Speculation tracking/SLH assumes that optimized TB(N)Z/CB(N)Z instructions
23635 // will not be produced, as they are conditional branch instructions that do
23636 // not set flags.
23637 if (MF.getFunction().hasFnAttribute(Attribute::SpeculativeLoadHardening))
23638 return SDValue();
23639
23640 if (SDValue NV = performCONDCombine(N, DCI, DAG, 2, 3))
23641 N = NV.getNode();
23642 SDValue Chain = N->getOperand(0);
23643 SDValue Dest = N->getOperand(1);
23644 SDValue CCVal = N->getOperand(2);
23645 SDValue Cmp = N->getOperand(3);
23646
23647 assert(isa<ConstantSDNode>(CCVal) && "Expected a ConstantSDNode here!");
23648 unsigned CC = CCVal->getAsZExtVal();
23649 if (CC != AArch64CC::EQ && CC != AArch64CC::NE)
23650 return SDValue();
23651
23652 unsigned CmpOpc = Cmp.getOpcode();
23653 if (CmpOpc != AArch64ISD::ADDS && CmpOpc != AArch64ISD::SUBS)
23654 return SDValue();
23655
23656 // Only attempt folding if there is only one use of the flag and no use of the
23657 // value.
23658 if (!Cmp->hasNUsesOfValue(0, 0) || !Cmp->hasNUsesOfValue(1, 1))
23659 return SDValue();
23660
23661 SDValue LHS = Cmp.getOperand(0);
23662 SDValue RHS = Cmp.getOperand(1);
23663
23664 assert(LHS.getValueType() == RHS.getValueType() &&
23665 "Expected the value type to be the same for both operands!");
23666 if (LHS.getValueType() != MVT::i32 && LHS.getValueType() != MVT::i64)
23667 return SDValue();
23668
23669 if (isNullConstant(LHS))
23670 std::swap(LHS, RHS);
23671
23672 if (!isNullConstant(RHS))
23673 return SDValue();
23674
23675 if (LHS.getOpcode() == ISD::SHL || LHS.getOpcode() == ISD::SRA ||
23676 LHS.getOpcode() == ISD::SRL)
23677 return SDValue();
23678
23679 // Fold the compare into the branch instruction.
23680 SDValue BR;
23681 if (CC == AArch64CC::EQ)
23682 BR = DAG.getNode(AArch64ISD::CBZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
23683 else
23684 BR = DAG.getNode(AArch64ISD::CBNZ, SDLoc(N), MVT::Other, Chain, LHS, Dest);
23685
23686 // Do not add new nodes to DAG combiner worklist.
23687 DCI.CombineTo(N, BR, false);
23688
23689 return SDValue();
23690 }
23691
23692 static SDValue foldCSELofCTTZ(SDNode *N, SelectionDAG &DAG) {
23693 unsigned CC = N->getConstantOperandVal(2);
23694 SDValue SUBS = N->getOperand(3);
23695 SDValue Zero, CTTZ;
23696
23697 if (CC == AArch64CC::EQ && SUBS.getOpcode() == AArch64ISD::SUBS) {
23698 Zero = N->getOperand(0);
23699 CTTZ = N->getOperand(1);
23700 } else if (CC == AArch64CC::NE && SUBS.getOpcode() == AArch64ISD::SUBS) {
23701 Zero = N->getOperand(1);
23702 CTTZ = N->getOperand(0);
23703 } else
23704 return SDValue();
23705
23706 if ((CTTZ.getOpcode() != ISD::CTTZ && CTTZ.getOpcode() != ISD::TRUNCATE) ||
23707 (CTTZ.getOpcode() == ISD::TRUNCATE &&
23708 CTTZ.getOperand(0).getOpcode() != ISD::CTTZ))
23709 return SDValue();
23710
23711 assert((CTTZ.getValueType() == MVT::i32 || CTTZ.getValueType() == MVT::i64) &&
23712 "Illegal type in CTTZ folding");
23713
23714 if (!isNullConstant(Zero) || !isNullConstant(SUBS.getOperand(1)))
23715 return SDValue();
23716
23717 SDValue X = CTTZ.getOpcode() == ISD::TRUNCATE
23718 ? CTTZ.getOperand(0).getOperand(0)
23719 : CTTZ.getOperand(0);
23720
23721 if (X != SUBS.getOperand(0))
23722 return SDValue();
23723
23724 unsigned BitWidth = CTTZ.getOpcode() == ISD::TRUNCATE
23725 ? CTTZ.getOperand(0).getValueSizeInBits()
23726 : CTTZ.getValueSizeInBits();
23727 SDValue BitWidthMinusOne =
23728 DAG.getConstant(BitWidth - 1, SDLoc(N), CTTZ.getValueType());
23729 return DAG.getNode(ISD::AND, SDLoc(N), CTTZ.getValueType(), CTTZ,
23730 BitWidthMinusOne);
23731 }
23732
23733 // (CSEL l r EQ (CMP (CSEL x y cc2 cond) x)) => (CSEL l r cc2 cond)
23734 // (CSEL l r EQ (CMP (CSEL x y cc2 cond) y)) => (CSEL l r !cc2 cond)
23735 // Where x and y are constants and x != y
23736
23737 // (CSEL l r NE (CMP (CSEL x y cc2 cond) x)) => (CSEL l r !cc2 cond)
23738 // (CSEL l r NE (CMP (CSEL x y cc2 cond) y)) => (CSEL l r cc2 cond)
23739 // Where x and y are constants and x != y
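// Illustrative instance (hypothetical constants): with x == 1, y == 0 and
// cc2 == lt, (CSEL l r EQ (CMP (CSEL 1 0 lt cond) 1)) picks l exactly when
// lt holds for cond's flags, so it folds to (CSEL l r lt cond).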
23740 static SDValue foldCSELOfCSEL(SDNode *Op, SelectionDAG &DAG) {
23741 SDValue L = Op->getOperand(0);
23742 SDValue R = Op->getOperand(1);
23743 AArch64CC::CondCode OpCC =
23744 static_cast<AArch64CC::CondCode>(Op->getConstantOperandVal(2));
23745
23746 SDValue OpCmp = Op->getOperand(3);
23747 if (!isCMP(OpCmp))
23748 return SDValue();
23749
23750 SDValue CmpLHS = OpCmp.getOperand(0);
23751 SDValue CmpRHS = OpCmp.getOperand(1);
23752
23753 if (CmpRHS.getOpcode() == AArch64ISD::CSEL)
23754 std::swap(CmpLHS, CmpRHS);
23755 else if (CmpLHS.getOpcode() != AArch64ISD::CSEL)
23756 return SDValue();
23757
23758 SDValue X = CmpLHS->getOperand(0);
23759 SDValue Y = CmpLHS->getOperand(1);
23760 if (!isa<ConstantSDNode>(X) || !isa<ConstantSDNode>(Y) || X == Y) {
23761 return SDValue();
23762 }
23763
23764 // If one of the constants is an opaque constant, the X and Y SDNodes can
23765 // still be distinct even though the underlying values are the same, so
23766 // compare the APInt values here to make sure the code is correct.
23767 ConstantSDNode *CX = cast<ConstantSDNode>(X);
23768 ConstantSDNode *CY = cast<ConstantSDNode>(Y);
23769 if (CX->getAPIntValue() == CY->getAPIntValue())
23770 return SDValue();
23771
23772 AArch64CC::CondCode CC =
23773 static_cast<AArch64CC::CondCode>(CmpLHS->getConstantOperandVal(2));
23774 SDValue Cond = CmpLHS->getOperand(3);
23775
23776 if (CmpRHS == Y)
23777 CC = AArch64CC::getInvertedCondCode(CC);
23778 else if (CmpRHS != X)
23779 return SDValue();
23780
23781 if (OpCC == AArch64CC::NE)
23782 CC = AArch64CC::getInvertedCondCode(CC);
23783 else if (OpCC != AArch64CC::EQ)
23784 return SDValue();
23785
23786 SDLoc DL(Op);
23787 EVT VT = Op->getValueType(0);
23788
23789 SDValue CCValue = DAG.getConstant(CC, DL, MVT::i32);
23790 return DAG.getNode(AArch64ISD::CSEL, DL, VT, L, R, CCValue, Cond);
23791 }
23792
23793 // Optimize CSEL instructions
23794 static SDValue performCSELCombine(SDNode *N,
23795 TargetLowering::DAGCombinerInfo &DCI,
23796 SelectionDAG &DAG) {
23797 // CSEL x, x, cc -> x
23798 if (N->getOperand(0) == N->getOperand(1))
23799 return N->getOperand(0);
23800
23801 if (SDValue R = foldCSELOfCSEL(N, DAG))
23802 return R;
23803
23804 // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
23805 // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
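// (Illustrative reasoning, not from the original comment: when X == 0,
// cttz(X) equals the bit width, and bitwidth & (bitwidth - 1) == 0 for the
// 32- and 64-bit cases handled here, so the AND reproduces the selected 0;
// when X != 0, cttz(X) < bitwidth and the AND is a no-op.)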
23806 if (SDValue Folded = foldCSELofCTTZ(N, DAG))
23807 return Folded;
23808
23809 return performCONDCombine(N, DCI, DAG, 2, 3);
23810 }
23811
23812 // Try to re-use an already extended operand of a vector SetCC feeding an
23813 // extended select. Doing so avoids requiring another full extension of the
23814 // SET_CC result when lowering the select.
23815 static SDValue tryToWidenSetCCOperands(SDNode *Op, SelectionDAG &DAG) {
23816 EVT Op0MVT = Op->getOperand(0).getValueType();
23817 if (!Op0MVT.isVector() || Op->use_empty())
23818 return SDValue();
23819
23820 // Make sure that all uses of Op are VSELECTs with result matching types where
23821 // the result type has a larger element type than the SetCC operand.
23822 SDNode *FirstUse = *Op->use_begin();
23823 if (FirstUse->getOpcode() != ISD::VSELECT)
23824 return SDValue();
23825 EVT UseMVT = FirstUse->getValueType(0);
23826 if (UseMVT.getScalarSizeInBits() <= Op0MVT.getScalarSizeInBits())
23827 return SDValue();
23828 if (any_of(Op->uses(), [&UseMVT](const SDNode *N) {
23829 return N->getOpcode() != ISD::VSELECT || N->getValueType(0) != UseMVT;
23830 }))
23831 return SDValue();
23832
23833 APInt V;
23834 if (!ISD::isConstantSplatVector(Op->getOperand(1).getNode(), V))
23835 return SDValue();
23836
23837 SDLoc DL(Op);
23838 SDValue Op0ExtV;
23839 SDValue Op1ExtV;
23840 ISD::CondCode CC = cast<CondCodeSDNode>(Op->getOperand(2))->get();
23841 // Check if the first operand of the SET_CC is already extended. If it is,
23842 // split the SET_CC and re-use the extended version of the operand.
23843 SDNode *Op0SExt = DAG.getNodeIfExists(ISD::SIGN_EXTEND, DAG.getVTList(UseMVT),
23844 Op->getOperand(0));
23845 SDNode *Op0ZExt = DAG.getNodeIfExists(ISD::ZERO_EXTEND, DAG.getVTList(UseMVT),
23846 Op->getOperand(0));
23847 if (Op0SExt && (isSignedIntSetCC(CC) || isIntEqualitySetCC(CC))) {
23848 Op0ExtV = SDValue(Op0SExt, 0);
23849 Op1ExtV = DAG.getNode(ISD::SIGN_EXTEND, DL, UseMVT, Op->getOperand(1));
23850 } else if (Op0ZExt && (isUnsignedIntSetCC(CC) || isIntEqualitySetCC(CC))) {
23851 Op0ExtV = SDValue(Op0ZExt, 0);
23852 Op1ExtV = DAG.getNode(ISD::ZERO_EXTEND, DL, UseMVT, Op->getOperand(1));
23853 } else
23854 return SDValue();
23855
23856 return DAG.getNode(ISD::SETCC, DL, UseMVT.changeVectorElementType(MVT::i1),
23857 Op0ExtV, Op1ExtV, Op->getOperand(2));
23858 }
23859
23860 static SDValue
23861 performVecReduceBitwiseCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
23862 SelectionDAG &DAG) {
23863 SDValue Vec = N->getOperand(0);
23864 if (DCI.isBeforeLegalize() &&
23865 Vec.getValueType().getVectorElementType() == MVT::i1 &&
23866 Vec.getValueType().isFixedLengthVector() &&
23867 Vec.getValueType().isPow2VectorType()) {
23868 SDLoc DL(N);
23869 return getVectorBitwiseReduce(N->getOpcode(), Vec, N->getValueType(0), DL,
23870 DAG);
23871 }
23872
23873 return SDValue();
23874 }
23875
23876 static SDValue performSETCCCombine(SDNode *N,
23877 TargetLowering::DAGCombinerInfo &DCI,
23878 SelectionDAG &DAG) {
23879 assert(N->getOpcode() == ISD::SETCC && "Unexpected opcode!");
23880 SDValue LHS = N->getOperand(0);
23881 SDValue RHS = N->getOperand(1);
23882 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(2))->get();
23883 SDLoc DL(N);
23884 EVT VT = N->getValueType(0);
23885
23886 if (SDValue V = tryToWidenSetCCOperands(N, DAG))
23887 return V;
23888
23889 // setcc (csel 0, 1, cond, X), 1, ne ==> csel 0, 1, !cond, X
23890 if (Cond == ISD::SETNE && isOneConstant(RHS) &&
23891 LHS->getOpcode() == AArch64ISD::CSEL &&
23892 isNullConstant(LHS->getOperand(0)) && isOneConstant(LHS->getOperand(1)) &&
23893 LHS->hasOneUse()) {
23894 // Invert CSEL's condition.
23895 auto OldCond =
23896 static_cast<AArch64CC::CondCode>(LHS.getConstantOperandVal(2));
23897 auto NewCond = getInvertedCondCode(OldCond);
23898
23899 // csel 0, 1, !cond, X
23900 SDValue CSEL =
23901 DAG.getNode(AArch64ISD::CSEL, DL, LHS.getValueType(), LHS.getOperand(0),
23902 LHS.getOperand(1), DAG.getConstant(NewCond, DL, MVT::i32),
23903 LHS.getOperand(3));
23904 return DAG.getZExtOrTrunc(CSEL, DL, VT);
23905 }
23906
23907 // setcc (srl x, imm), 0, ne ==> setcc (and x, (-1 << imm)), 0, ne
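// Illustrative example (hypothetical values): for an i32 x with imm == 4,
// "(x srl 4) != 0" becomes "(x & 0xFFFFFFF0) != 0", a mask test that, as the
// comment below notes, emitComparison can lower more efficiently (typically
// as a flag-setting AND / TST).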
23908 if (Cond == ISD::SETNE && isNullConstant(RHS) &&
23909 LHS->getOpcode() == ISD::SRL && isa<ConstantSDNode>(LHS->getOperand(1)) &&
23910 LHS->getConstantOperandVal(1) < VT.getScalarSizeInBits() &&
23911 LHS->hasOneUse()) {
23912 EVT TstVT = LHS->getValueType(0);
23913 if (TstVT.isScalarInteger() && TstVT.getFixedSizeInBits() <= 64) {
23914 // This pattern gets better optimization in emitComparison.
23915 uint64_t TstImm = -1ULL << LHS->getConstantOperandVal(1);
23916 SDValue TST = DAG.getNode(ISD::AND, DL, TstVT, LHS->getOperand(0),
23917 DAG.getConstant(TstImm, DL, TstVT));
23918 return DAG.getNode(ISD::SETCC, DL, VT, TST, RHS, N->getOperand(2));
23919 }
23920 }
23921
23922 // setcc (iN (bitcast (vNi1 X))), 0, (eq|ne)
23923 // ==> setcc (iN (zext (i1 (vecreduce_or (vNi1 X))))), 0, (eq|ne)
23924 // setcc (iN (bitcast (vNi1 X))), -1, (eq|ne)
23925 // ==> setcc (iN (sext (i1 (vecreduce_and (vNi1 X))))), -1, (eq|ne)
23926 if (DCI.isBeforeLegalize() && VT.isScalarInteger() &&
23927 (Cond == ISD::SETEQ || Cond == ISD::SETNE) &&
23928 (isNullConstant(RHS) || isAllOnesConstant(RHS)) &&
23929 LHS->getOpcode() == ISD::BITCAST) {
23930 EVT ToVT = LHS->getValueType(0);
23931 EVT FromVT = LHS->getOperand(0).getValueType();
23932 if (FromVT.isFixedLengthVector() &&
23933 FromVT.getVectorElementType() == MVT::i1) {
23934 bool IsNull = isNullConstant(RHS);
23935 LHS = DAG.getNode(IsNull ? ISD::VECREDUCE_OR : ISD::VECREDUCE_AND,
23936 DL, MVT::i1, LHS->getOperand(0));
23937 LHS = DAG.getNode(IsNull ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND, DL, ToVT,
23938 LHS);
23939 return DAG.getSetCC(DL, VT, LHS, RHS, Cond);
23940 }
23941 }
23942
23943 // Try to perform the memcmp when the result is tested for [in]equality with 0
23944 if (SDValue V = performOrXorChainCombine(N, DAG))
23945 return V;
23946
23947 return SDValue();
23948 }
23949
23950 // Replace a flag-setting operator (e.g. ANDS) with the generic version
23951 // (e.g. AND) if the flag is unused.
23952 static SDValue performFlagSettingCombine(SDNode *N,
23953 TargetLowering::DAGCombinerInfo &DCI,
23954 unsigned GenericOpcode) {
23955 SDLoc DL(N);
23956 SDValue LHS = N->getOperand(0);
23957 SDValue RHS = N->getOperand(1);
23958 EVT VT = N->getValueType(0);
23959
23960 // If the flag result isn't used, convert back to a generic opcode.
23961 if (!N->hasAnyUseOfValue(1)) {
23962 SDValue Res = DCI.DAG.getNode(GenericOpcode, DL, VT, N->ops());
23963 return DCI.DAG.getMergeValues({Res, DCI.DAG.getConstant(0, DL, MVT::i32)},
23964 DL);
23965 }
23966
23967 // Combine identical generic nodes into this node, re-using the result.
23968 if (SDNode *Generic = DCI.DAG.getNodeIfExists(
23969 GenericOpcode, DCI.DAG.getVTList(VT), {LHS, RHS}))
23970 DCI.CombineTo(Generic, SDValue(N, 0));
23971
23972 return SDValue();
23973 }
23974
23975 static SDValue performSetCCPunpkCombine(SDNode *N, SelectionDAG &DAG) {
23976 // setcc_merge_zero pred
23977 // (sign_extend (extract_subvector (setcc_merge_zero ... pred ...))), 0, ne
23978 // => extract_subvector (inner setcc_merge_zero)
23979 SDValue Pred = N->getOperand(0);
23980 SDValue LHS = N->getOperand(1);
23981 SDValue RHS = N->getOperand(2);
23982 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
23983
23984 if (Cond != ISD::SETNE || !isZerosVector(RHS.getNode()) ||
23985 LHS->getOpcode() != ISD::SIGN_EXTEND)
23986 return SDValue();
23987
23988 SDValue Extract = LHS->getOperand(0);
23989 if (Extract->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
23990 Extract->getValueType(0) != N->getValueType(0) ||
23991 Extract->getConstantOperandVal(1) != 0)
23992 return SDValue();
23993
23994 SDValue InnerSetCC = Extract->getOperand(0);
23995 if (InnerSetCC->getOpcode() != AArch64ISD::SETCC_MERGE_ZERO)
23996 return SDValue();
23997
23998 // By this point we've effectively got
23999 // zero_inactive_lanes_and_trunc_i1(sext_i1(A)). If we can prove A's inactive
24000 // lanes are already zero then the trunc(sext()) sequence is redundant and we
24001 // can operate on A directly.
24002 SDValue InnerPred = InnerSetCC.getOperand(0);
24003 if (Pred.getOpcode() == AArch64ISD::PTRUE &&
24004 InnerPred.getOpcode() == AArch64ISD::PTRUE &&
24005 Pred.getConstantOperandVal(0) == InnerPred.getConstantOperandVal(0) &&
24006 Pred->getConstantOperandVal(0) >= AArch64SVEPredPattern::vl1 &&
24007 Pred->getConstantOperandVal(0) <= AArch64SVEPredPattern::vl256)
24008 return Extract;
24009
24010 return SDValue();
24011 }
24012
24013 static SDValue
24014 performSetccMergeZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
24015 assert(N->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
24016 "Unexpected opcode!");
24017
24018 SelectionDAG &DAG = DCI.DAG;
24019 SDValue Pred = N->getOperand(0);
24020 SDValue LHS = N->getOperand(1);
24021 SDValue RHS = N->getOperand(2);
24022 ISD::CondCode Cond = cast<CondCodeSDNode>(N->getOperand(3))->get();
24023
24024 if (SDValue V = performSetCCPunpkCombine(N, DAG))
24025 return V;
24026
24027 if (Cond == ISD::SETNE && isZerosVector(RHS.getNode()) &&
24028 LHS->getOpcode() == ISD::SIGN_EXTEND &&
24029 LHS->getOperand(0)->getValueType(0) == N->getValueType(0)) {
24030 // setcc_merge_zero(
24031 // pred, extend(setcc_merge_zero(pred, ...)), != splat(0))
24032 // => setcc_merge_zero(pred, ...)
24033 if (LHS->getOperand(0)->getOpcode() == AArch64ISD::SETCC_MERGE_ZERO &&
24034 LHS->getOperand(0)->getOperand(0) == Pred)
24035 return LHS->getOperand(0);
24036
24037 // setcc_merge_zero(
24038 // all_active, extend(nxvNi1 ...), != splat(0))
24039 // -> nxvNi1 ...
24040 if (isAllActivePredicate(DAG, Pred))
24041 return LHS->getOperand(0);
24042
24043 // setcc_merge_zero(
24044 // pred, extend(nxvNi1 ...), != splat(0))
24045 // -> nxvNi1 and(pred, ...)
24046 if (DCI.isAfterLegalizeDAG())
24047 // Do this after legalization to allow more folds on setcc_merge_zero
24048 // to be recognized.
24049 return DAG.getNode(ISD::AND, SDLoc(N), N->getValueType(0),
24050 LHS->getOperand(0), Pred);
24051 }
24052
24053 return SDValue();
24054 }
24055
24056 // Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test
24057 // as well as whether the test should be inverted. This code is required to
24058 // catch these cases (as opposed to standard dag combines) because
24059 // AArch64ISD::TBZ is matched during legalization.
24060 static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert,
24061 SelectionDAG &DAG) {
24062
24063 if (!Op->hasOneUse())
24064 return Op;
24065
24066 // We don't handle undef/constant-fold cases below, as they should have
24067 // already been taken care of (e.g. and of 0, test of undefined shifted bits,
24068 // etc.)
24069
24070 // (tbz (trunc x), b) -> (tbz x, b)
24071 // This case is just here to enable more of the below cases to be caught.
24072 if (Op->getOpcode() == ISD::TRUNCATE &&
24073 Bit < Op->getValueType(0).getSizeInBits()) {
24074 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24075 }
24076
24077 // (tbz (any_ext x), b) -> (tbz x, b) if we don't use the extended bits.
24078 if (Op->getOpcode() == ISD::ANY_EXTEND &&
24079 Bit < Op->getOperand(0).getValueSizeInBits()) {
24080 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24081 }
24082
24083 if (Op->getNumOperands() != 2)
24084 return Op;
24085
24086 auto *C = dyn_cast<ConstantSDNode>(Op->getOperand(1));
24087 if (!C)
24088 return Op;
24089
24090 switch (Op->getOpcode()) {
24091 default:
24092 return Op;
24093
24094 // (tbz (and x, m), b) -> (tbz x, b)
24095 case ISD::AND:
24096 if ((C->getZExtValue() >> Bit) & 1)
24097 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24098 return Op;
24099
24100 // (tbz (shl x, c), b) -> (tbz x, b-c)
24101 case ISD::SHL:
24102 if (C->getZExtValue() <= Bit &&
24103 (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
24104 Bit = Bit - C->getZExtValue();
24105 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24106 }
24107 return Op;
24108
24109 // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x
24110 case ISD::SRA:
24111 Bit = Bit + C->getZExtValue();
24112 if (Bit >= Op->getValueType(0).getSizeInBits())
24113 Bit = Op->getValueType(0).getSizeInBits() - 1;
24114 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24115
24116 // (tbz (srl x, c), b) -> (tbz x, b+c)
24117 case ISD::SRL:
24118 if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) {
24119 Bit = Bit + C->getZExtValue();
24120 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24121 }
24122 return Op;
24123
24124 // (tbz (xor x, -1), b) -> (tbnz x, b)
24125 case ISD::XOR:
24126 if ((C->getZExtValue() >> Bit) & 1)
24127 Invert = !Invert;
24128 return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG);
24129 }
24130 }
24131
24132 // Optimize test single bit zero/non-zero and branch.
24133 static SDValue performTBZCombine(SDNode *N,
24134 TargetLowering::DAGCombinerInfo &DCI,
24135 SelectionDAG &DAG) {
24136 unsigned Bit = N->getConstantOperandVal(2);
24137 bool Invert = false;
24138 SDValue TestSrc = N->getOperand(1);
24139 SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG);
24140
24141 if (TestSrc == NewTestSrc)
24142 return SDValue();
24143
24144 unsigned NewOpc = N->getOpcode();
24145 if (Invert) {
24146 if (NewOpc == AArch64ISD::TBZ)
24147 NewOpc = AArch64ISD::TBNZ;
24148 else {
24149 assert(NewOpc == AArch64ISD::TBNZ);
24150 NewOpc = AArch64ISD::TBZ;
24151 }
24152 }
24153
24154 SDLoc DL(N);
24155 return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc,
24156 DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3));
24157 }
24158
24159 // Swap vselect operands where it may allow a predicated operation to achieve
24160 // the `sel`.
24161 //
24162 // (vselect (setcc ( condcode) (_) (_)) (a) (op (a) (b)))
24163 // => (vselect (setcc (!condcode) (_) (_)) (op (a) (b)) (a))
24164 static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
24165 auto SelectA = N->getOperand(1);
24166 auto SelectB = N->getOperand(2);
24167 auto NTy = N->getValueType(0);
24168
24169 if (!NTy.isScalableVector())
24170 return SDValue();
24171 SDValue SetCC = N->getOperand(0);
24172 if (SetCC.getOpcode() != ISD::SETCC || !SetCC.hasOneUse())
24173 return SDValue();
24174
24175 switch (SelectB.getOpcode()) {
24176 default:
24177 return SDValue();
24178 case ISD::FMUL:
24179 case ISD::FSUB:
24180 case ISD::FADD:
24181 break;
24182 }
24183 if (SelectA != SelectB.getOperand(0))
24184 return SDValue();
24185
24186 ISD::CondCode CC = cast<CondCodeSDNode>(SetCC.getOperand(2))->get();
24187 ISD::CondCode InverseCC =
24188 ISD::getSetCCInverse(CC, SetCC.getOperand(0).getValueType());
24189 auto InverseSetCC =
24190 DAG.getSetCC(SDLoc(SetCC), SetCC.getValueType(), SetCC.getOperand(0),
24191 SetCC.getOperand(1), InverseCC);
24192
24193 return DAG.getNode(ISD::VSELECT, SDLoc(N), NTy,
24194 {InverseSetCC, SelectB, SelectA});
24195 }
24196
24197 // vselect (v1i1 setcc) ->
24198 // vselect (v1iXX setcc) (XX is the size of the compared operand type)
24199 // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
24200 // the condition. Once it can legalize "VSELECT v1i1" correctly, this combine
24201 // will no longer be needed.
24202 static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
24203 if (auto SwapResult = trySwapVSelectOperands(N, DAG))
24204 return SwapResult;
24205
24206 SDValue N0 = N->getOperand(0);
24207 EVT CCVT = N0.getValueType();
24208
24209 if (isAllActivePredicate(DAG, N0))
24210 return N->getOperand(1);
24211
24212 if (isAllInactivePredicate(N0))
24213 return N->getOperand(2);
24214
24215 // Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
24216 // into (OR (ASR lhs, N-1), 1), which requires fewer instructions for the
24217 // supported types.
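// Illustrative reasoning (not from the original comment): for a 32-bit lane,
// ASR by 31 yields 0 for non-negative lanes and -1 for negative lanes; ORing
// with 1 then gives 1 or -1 respectively, which matches selecting 1 when
// lhs > -1 and -1 otherwise.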
24218 SDValue SetCC = N->getOperand(0);
24219 if (SetCC.getOpcode() == ISD::SETCC &&
24220 SetCC.getOperand(2) == DAG.getCondCode(ISD::SETGT)) {
24221 SDValue CmpLHS = SetCC.getOperand(0);
24222 EVT VT = CmpLHS.getValueType();
24223 SDNode *CmpRHS = SetCC.getOperand(1).getNode();
24224 SDNode *SplatLHS = N->getOperand(1).getNode();
24225 SDNode *SplatRHS = N->getOperand(2).getNode();
24226 APInt SplatLHSVal;
24227 if (CmpLHS.getValueType() == N->getOperand(1).getValueType() &&
24228 VT.isSimple() &&
24229 is_contained(ArrayRef({MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16,
24230 MVT::v2i32, MVT::v4i32, MVT::v2i64}),
24231 VT.getSimpleVT().SimpleTy) &&
24232 ISD::isConstantSplatVector(SplatLHS, SplatLHSVal) &&
24233 SplatLHSVal.isOne() && ISD::isConstantSplatVectorAllOnes(CmpRHS) &&
24234 ISD::isConstantSplatVectorAllOnes(SplatRHS)) {
24235 unsigned NumElts = VT.getVectorNumElements();
24236 SmallVector<SDValue, 8> Ops(
24237 NumElts, DAG.getConstant(VT.getScalarSizeInBits() - 1, SDLoc(N),
24238 VT.getScalarType()));
24239 SDValue Val = DAG.getBuildVector(VT, SDLoc(N), Ops);
24240
24241 auto Shift = DAG.getNode(ISD::SRA, SDLoc(N), VT, CmpLHS, Val);
24242 auto Or = DAG.getNode(ISD::OR, SDLoc(N), VT, Shift, N->getOperand(1));
24243 return Or;
24244 }
24245 }
24246
24247 EVT CmpVT = N0.getOperand(0).getValueType();
24248 if (N0.getOpcode() != ISD::SETCC ||
24249 CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
24250 CCVT.getVectorElementType() != MVT::i1 ||
24251 CmpVT.getVectorElementType().isFloatingPoint())
24252 return SDValue();
24253
24254 EVT ResVT = N->getValueType(0);
24255 // Only combine when the result type is of the same size as the compared
24256 // operands.
24257 if (ResVT.getSizeInBits() != CmpVT.getSizeInBits())
24258 return SDValue();
24259
24260 SDValue IfTrue = N->getOperand(1);
24261 SDValue IfFalse = N->getOperand(2);
24262 SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
24263 N0.getOperand(0), N0.getOperand(1),
24264 cast<CondCodeSDNode>(N0.getOperand(2))->get());
24265 return DAG.getNode(ISD::VSELECT, SDLoc(N), ResVT, SetCC,
24266 IfTrue, IfFalse);
24267 }
24268
24269 /// A vector select: "(select vL, vR, (setcc LHS, RHS))" is best performed with
24270 /// the compare-mask instructions rather than going via NZCV, even if LHS and
24271 /// RHS are really scalar. This replaces any scalar setcc in the above pattern
24272 /// with a vector one followed by a DUP shuffle on the result.
24273 static SDValue performSelectCombine(SDNode *N,
24274 TargetLowering::DAGCombinerInfo &DCI) {
24275 SelectionDAG &DAG = DCI.DAG;
24276 SDValue N0 = N->getOperand(0);
24277 EVT ResVT = N->getValueType(0);
24278
24279 if (N0.getOpcode() != ISD::SETCC)
24280 return SDValue();
24281
24282 if (ResVT.isScalableVT())
24283 return SDValue();
24284
24285 // Make sure the SETCC result is either i1 (initial DAG), or i32, the lowered
24286 // scalar SetCCResultType. We also don't expect vectors, because we assume
24287 // that selects fed by vector SETCCs are canonicalized to VSELECT.
24288 assert((N0.getValueType() == MVT::i1 || N0.getValueType() == MVT::i32) &&
24289 "Scalar-SETCC feeding SELECT has unexpected result type!");
24290
24291 // If NumMaskElts == 0, the comparison is larger than the select result. The
24292 // largest real NEON comparison is 64 bits per lane, which means the result is
24293 // at most 32 bits and an illegal vector. Just bail out for now.
24294 EVT SrcVT = N0.getOperand(0).getValueType();
24295
24296 // Don't try to do this optimization when the setcc itself has i1 operands.
24297 // There are no legal vectors of i1, so this would be pointless. v1f16 is
24298 // ruled out to prevent the creation of setcc that need to be scalarized.
24299 if (SrcVT == MVT::i1 ||
24300 (SrcVT.isFloatingPoint() && SrcVT.getSizeInBits() <= 16))
24301 return SDValue();
24302
24303 int NumMaskElts = ResVT.getSizeInBits() / SrcVT.getSizeInBits();
24304 if (!ResVT.isVector() || NumMaskElts == 0)
24305 return SDValue();
24306
24307 SrcVT = EVT::getVectorVT(*DAG.getContext(), SrcVT, NumMaskElts);
24308 EVT CCVT = SrcVT.changeVectorElementTypeToInteger();
24309
24310 // Also bail out if the vector CCVT isn't the same size as ResVT.
24311 // This can happen if the SETCC operand size doesn't divide the ResVT size
24312 // (e.g., f64 vs v3f32).
24313 if (CCVT.getSizeInBits() != ResVT.getSizeInBits())
24314 return SDValue();
24315
24316 // Make sure we didn't create illegal types, if we're not supposed to.
24317 assert(DCI.isBeforeLegalize() ||
24318 DAG.getTargetLoweringInfo().isTypeLegal(SrcVT));
24319
24320 // First perform a vector comparison, where lane 0 is the one we're interested
24321 // in.
24322 SDLoc DL(N0);
24323 SDValue LHS =
24324 DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(0));
24325 SDValue RHS =
24326 DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, SrcVT, N0.getOperand(1));
24327 SDValue SetCC = DAG.getNode(ISD::SETCC, DL, CCVT, LHS, RHS, N0.getOperand(2));
24328
24329 // Now duplicate the comparison mask we want across all other lanes.
24330 SmallVector<int, 8> DUPMask(CCVT.getVectorNumElements(), 0);
24331 SDValue Mask = DAG.getVectorShuffle(CCVT, DL, SetCC, SetCC, DUPMask);
24332 Mask = DAG.getNode(ISD::BITCAST, DL,
24333 ResVT.changeVectorElementTypeToInteger(), Mask);
24334
24335 return DAG.getSelect(DL, ResVT, Mask, N->getOperand(1), N->getOperand(2));
24336 }
24337
24338 static SDValue performDUPCombine(SDNode *N,
24339 TargetLowering::DAGCombinerInfo &DCI) {
24340 EVT VT = N->getValueType(0);
24341 SDLoc DL(N);
24342 // If "v2i32 DUP(x)" and "v4i32 DUP(x)" both exist, use an extract from the
24343 // 128bit vector version.
24344 if (VT.is64BitVector() && DCI.isAfterLegalizeDAG()) {
24345 EVT LVT = VT.getDoubleNumVectorElementsVT(*DCI.DAG.getContext());
24346 SmallVector<SDValue> Ops(N->ops());
24347 if (SDNode *LN = DCI.DAG.getNodeIfExists(N->getOpcode(),
24348 DCI.DAG.getVTList(LVT), Ops)) {
24349 return DCI.DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, SDValue(LN, 0),
24350 DCI.DAG.getConstant(0, DL, MVT::i64));
24351 }
24352 }
24353
24354 if (N->getOpcode() == AArch64ISD::DUP) {
24355 if (DCI.isAfterLegalizeDAG()) {
24356 // If scalar dup's operand is extract_vector_elt, try to combine them into
24357 // duplane. For example,
24358 //
24359 // t21: i32 = extract_vector_elt t19, Constant:i64<0>
24360 // t18: v4i32 = AArch64ISD::DUP t21
24361 // ==>
24362 // t22: v4i32 = AArch64ISD::DUPLANE32 t19, Constant:i64<0>
24363 SDValue EXTRACT_VEC_ELT = N->getOperand(0);
24364 if (EXTRACT_VEC_ELT.getOpcode() == ISD::EXTRACT_VECTOR_ELT) {
24365 if (VT == EXTRACT_VEC_ELT.getOperand(0).getValueType()) {
24366 unsigned Opcode = getDUPLANEOp(VT.getVectorElementType());
24367 return DCI.DAG.getNode(Opcode, DL, VT, EXTRACT_VEC_ELT.getOperand(0),
24368 EXTRACT_VEC_ELT.getOperand(1));
24369 }
24370 }
24371 }
24372
24373 return performPostLD1Combine(N, DCI, false);
24374 }
24375
24376 return SDValue();
24377 }
24378
24379 /// Get rid of unnecessary NVCASTs (that don't change the type).
24380 static SDValue performNVCASTCombine(SDNode *N, SelectionDAG &DAG) {
24381 if (N->getValueType(0) == N->getOperand(0).getValueType())
24382 return N->getOperand(0);
24383 if (N->getOperand(0).getOpcode() == AArch64ISD::NVCAST)
24384 return DAG.getNode(AArch64ISD::NVCAST, SDLoc(N), N->getValueType(0),
24385 N->getOperand(0).getOperand(0));
24386
24387 return SDValue();
24388 }
24389
24390 // If all users of the globaladdr are of the form (globaladdr + constant), find
24391 // the smallest constant, fold it into the globaladdr's offset and rewrite the
24392 // globaladdr as (globaladdr + constant) - constant.
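// Illustrative example (hypothetical uses): if the only uses are
// (globaladdr + 8) and (globaladdr + 16), MinOffset is 8 and the node is
// rewritten as (globaladdr + 8) - 8, so the smallest offset can be folded
// into the address-generation relocation itself.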
24393 static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG,
24394 const AArch64Subtarget *Subtarget,
24395 const TargetMachine &TM) {
24396 auto *GN = cast<GlobalAddressSDNode>(N);
24397 if (Subtarget->ClassifyGlobalReference(GN->getGlobal(), TM) !=
24398 AArch64II::MO_NO_FLAG)
24399 return SDValue();
24400
24401 uint64_t MinOffset = -1ull;
24402 for (SDNode *N : GN->uses()) {
24403 if (N->getOpcode() != ISD::ADD)
24404 return SDValue();
24405 auto *C = dyn_cast<ConstantSDNode>(N->getOperand(0));
24406 if (!C)
24407 C = dyn_cast<ConstantSDNode>(N->getOperand(1));
24408 if (!C)
24409 return SDValue();
24410 MinOffset = std::min(MinOffset, C->getZExtValue());
24411 }
24412 uint64_t Offset = MinOffset + GN->getOffset();
24413
24414 // Require that the new offset is larger than the existing one. Otherwise, we
24415 // can end up oscillating between two possible DAGs, for example,
24416 // (add (add globaladdr + 10, -1), 1) and (add globaladdr + 9, 1).
24417 if (Offset <= uint64_t(GN->getOffset()))
24418 return SDValue();
24419
24420 // Check whether folding this offset is legal. It must not go out of bounds of
24421 // the referenced object to avoid violating the code model, and must be
24422 // smaller than 2^20 because this is the largest offset expressible in all
24423 // object formats. (The IMAGE_REL_ARM64_PAGEBASE_REL21 relocation in COFF
24424 // stores an immediate signed 21 bit offset.)
24425 //
24426 // This check also prevents us from folding negative offsets, which will end
24427 // up being treated in the same way as large positive ones. They could also
24428 // cause code model violations, and aren't really common enough to matter.
24429 if (Offset >= (1 << 20))
24430 return SDValue();
24431
24432 const GlobalValue *GV = GN->getGlobal();
24433 Type *T = GV->getValueType();
24434 if (!T->isSized() ||
24435 Offset > GV->getDataLayout().getTypeAllocSize(T))
24436 return SDValue();
24437
24438 SDLoc DL(GN);
24439 SDValue Result = DAG.getGlobalAddress(GV, DL, MVT::i64, Offset);
24440 return DAG.getNode(ISD::SUB, DL, MVT::i64, Result,
24441 DAG.getConstant(MinOffset, DL, MVT::i64));
24442 }
24443
24444 static SDValue performCTLZCombine(SDNode *N, SelectionDAG &DAG,
24445 const AArch64Subtarget *Subtarget) {
24446 SDValue BR = N->getOperand(0);
24447 if (!Subtarget->hasCSSC() || BR.getOpcode() != ISD::BITREVERSE ||
24448 !BR.getValueType().isScalarInteger())
24449 return SDValue();
24450
24451 SDLoc DL(N);
24452 return DAG.getNode(ISD::CTTZ, DL, BR.getValueType(), BR.getOperand(0));
24453 }
24454
24455 // Turns the vector of indices into a vector of byte offsets by scaling Offset
24456 // by (BitWidth / 8).
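// For example (illustrative): with BitWidth == 32 the shift amount is
// Log2_32(32 / 8) == 2, so an index vector {0, 1, 2, 3} becomes the byte
// offsets {0, 4, 8, 12}.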
24457 static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
24458 SDLoc DL, unsigned BitWidth) {
24459 assert(Offset.getValueType().isScalableVector() &&
24460 "This method is only for scalable vectors of offsets");
24461
24462 SDValue Shift = DAG.getConstant(Log2_32(BitWidth / 8), DL, MVT::i64);
24463 SDValue SplatShift = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, Shift);
24464
24465 return DAG.getNode(ISD::SHL, DL, MVT::nxv2i64, Offset, SplatShift);
24466 }
24467
24468 /// Check if the value of \p OffsetInBytes can be used as an immediate for
24469 /// the gather load/prefetch and scatter store instructions with vector base and
24470 /// immediate offset addressing mode:
24471 ///
24472 /// [<Zn>.[S|D]{, #<imm>}]
24473 ///
24474 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
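/// For example (illustrative): with 4-byte elements (<T> = i32/f32), the legal
/// immediates are the byte offsets 0, 4, 8, ..., 124 (31 * 4).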
24475 inline static bool isValidImmForSVEVecImmAddrMode(unsigned OffsetInBytes,
24476 unsigned ScalarSizeInBytes) {
24477 // The immediate is not a multiple of the scalar size.
24478 if (OffsetInBytes % ScalarSizeInBytes)
24479 return false;
24480
24481 // The immediate is out of range.
24482 if (OffsetInBytes / ScalarSizeInBytes > 31)
24483 return false;
24484
24485 return true;
24486 }
24487
24488 /// Check if the value of \p Offset represents a valid immediate for the SVE
24489 /// gather load/prefetch and scatter store instructions with vector base and
24490 /// immediate offset addressing mode:
24491 ///
24492 /// [<Zn>.[S|D]{, #<imm>}]
24493 ///
24494 /// where <imm> = sizeof(<T>) * k, for k = 0, 1, ..., 31.
24495 static bool isValidImmForSVEVecImmAddrMode(SDValue Offset,
24496 unsigned ScalarSizeInBytes) {
24497 ConstantSDNode *OffsetConst = dyn_cast<ConstantSDNode>(Offset.getNode());
24498 return OffsetConst && isValidImmForSVEVecImmAddrMode(
24499 OffsetConst->getZExtValue(), ScalarSizeInBytes);
24500 }
24501
24502 static SDValue performScatterStoreCombine(SDNode *N, SelectionDAG &DAG,
24503 unsigned Opcode,
24504 bool OnlyPackedOffsets = true) {
24505 const SDValue Src = N->getOperand(2);
24506 const EVT SrcVT = Src->getValueType(0);
24507 assert(SrcVT.isScalableVector() &&
24508 "Scatter stores are only possible for SVE vectors");
24509
24510 SDLoc DL(N);
24511 MVT SrcElVT = SrcVT.getVectorElementType().getSimpleVT();
24512
24513 // Make sure that source data will fit into an SVE register
24514 if (SrcVT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
24515 return SDValue();
24516
24517 // For FPs, ACLE only supports _packed_ single and double precision types.
24518 // SST1Q_[INDEX_]PRED is the ST1Q for sve2p1 and should allow all sizes.
24519 if (SrcElVT.isFloatingPoint())
24520 if ((SrcVT != MVT::nxv4f32) && (SrcVT != MVT::nxv2f64) &&
24521 ((Opcode != AArch64ISD::SST1Q_PRED &&
24522 Opcode != AArch64ISD::SST1Q_INDEX_PRED) ||
24523 ((SrcVT != MVT::nxv8f16) && (SrcVT != MVT::nxv8bf16))))
24524 return SDValue();
24525
24526 // Depending on the addressing mode, this is either a pointer or a vector of
24527 // pointers (that fits into one register)
24528 SDValue Base = N->getOperand(4);
24529 // Depending on the addressing mode, this is either a single offset or a
24530 // vector of offsets (that fits into one register)
24531 SDValue Offset = N->getOperand(5);
24532
24533 // For "scalar + vector of indices", just scale the indices. This only
24534 // applies to non-temporal scatters because there's no instruction that takes
24535 // indices.
24536 if (Opcode == AArch64ISD::SSTNT1_INDEX_PRED) {
24537 Offset =
24538 getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
24539 Opcode = AArch64ISD::SSTNT1_PRED;
24540 } else if (Opcode == AArch64ISD::SST1Q_INDEX_PRED) {
24541 Offset =
24542 getScaledOffsetForBitWidth(DAG, Offset, DL, SrcElVT.getSizeInBits());
24543 Opcode = AArch64ISD::SST1Q_PRED;
24544 }
24545
24546 // In the case of non-temporal scatter stores there's only one SVE instruction
24547 // per data-size: "scalar + vector", i.e.
24548 // * stnt1{b|h|w|d} { z0.s }, p0/z, [z0.s, x0]
24549 // Since we do have intrinsics that allow the arguments to be in a different
24550 // order, we may need to swap them to match the spec.
24551 if ((Opcode == AArch64ISD::SSTNT1_PRED || Opcode == AArch64ISD::SST1Q_PRED) &&
24552 Offset.getValueType().isVector())
24553 std::swap(Base, Offset);
24554
24555 // SST1_IMM requires that the offset is an immediate that is:
24556 // * a multiple of #SizeInBytes,
24557 // * in the range [0, 31 x #SizeInBytes],
24558 // where #SizeInBytes is the size in bytes of the stored items. For
24559 // immediates outside that range and non-immediate scalar offsets use SST1 or
24560 // SST1_UXTW instead.
24561 if (Opcode == AArch64ISD::SST1_IMM_PRED) {
24562 if (!isValidImmForSVEVecImmAddrMode(Offset,
24563 SrcVT.getScalarSizeInBits() / 8)) {
24564 if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
24565 Opcode = AArch64ISD::SST1_UXTW_PRED;
24566 else
24567 Opcode = AArch64ISD::SST1_PRED;
24568
24569 std::swap(Base, Offset);
24570 }
24571 }
24572
24573 auto &TLI = DAG.getTargetLoweringInfo();
24574 if (!TLI.isTypeLegal(Base.getValueType()))
24575 return SDValue();
24576
24577 // Some scatter store variants allow unpacked offsets, but only as nxv2i32
24578 // vectors. These are implicitly sign (sxtw) or zero (zxtw) extended to
24579 // nxv2i64. Legalize accordingly.
24580 if (!OnlyPackedOffsets &&
24581 Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
24582 Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
24583
24584 if (!TLI.isTypeLegal(Offset.getValueType()))
24585 return SDValue();
24586
24587 // Source value type that is representable in hardware
24588 EVT HwSrcVt = getSVEContainerType(SrcVT);
24589
24590 // Keep the original type of the input data to store - this is needed to be
24591 // able to select the correct instruction, e.g. ST1B, ST1H, ST1W and ST1D. For
24592 // FP values we want the integer equivalent, so just use HwSrcVt.
24593 SDValue InputVT = DAG.getValueType(SrcVT);
24594 if (SrcVT.isFloatingPoint())
24595 InputVT = DAG.getValueType(HwSrcVt);
24596
24597 SDVTList VTs = DAG.getVTList(MVT::Other);
24598 SDValue SrcNew;
24599
24600 if (Src.getValueType().isFloatingPoint())
24601 SrcNew = DAG.getNode(ISD::BITCAST, DL, HwSrcVt, Src);
24602 else
24603 SrcNew = DAG.getNode(ISD::ANY_EXTEND, DL, HwSrcVt, Src);
24604
24605 SDValue Ops[] = {N->getOperand(0), // Chain
24606 SrcNew,
24607 N->getOperand(3), // Pg
24608 Base,
24609 Offset,
24610 InputVT};
24611
24612 return DAG.getNode(Opcode, DL, VTs, Ops);
24613 }
24614
24615 static SDValue performGatherLoadCombine(SDNode *N, SelectionDAG &DAG,
24616 unsigned Opcode,
24617 bool OnlyPackedOffsets = true) {
24618 const EVT RetVT = N->getValueType(0);
24619 assert(RetVT.isScalableVector() &&
24620 "Gather loads are only possible for SVE vectors");
24621
24622 SDLoc DL(N);
24623
24624 // Make sure that the loaded data will fit into an SVE register
24625 if (RetVT.getSizeInBits().getKnownMinValue() > AArch64::SVEBitsPerBlock)
24626 return SDValue();
24627
24628 // Depending on the addressing mode, this is either a pointer or a vector of
24629 // pointers (that fits into one register)
24630 SDValue Base = N->getOperand(3);
24631 // Depending on the addressing mode, this is either a single offset or a
24632 // vector of offsets (that fits into one register)
24633 SDValue Offset = N->getOperand(4);
24634
24635 // For "scalar + vector of indices", scale the indices to obtain unscaled
24636 // offsets. This applies to non-temporal and quadword gathers, which do not
24637 // have an addressing mode with scaled offset.
24638 if (Opcode == AArch64ISD::GLDNT1_INDEX_MERGE_ZERO) {
24639 Offset = getScaledOffsetForBitWidth(DAG, Offset, DL,
24640 RetVT.getScalarSizeInBits());
24641 Opcode = AArch64ISD::GLDNT1_MERGE_ZERO;
24642 } else if (Opcode == AArch64ISD::GLD1Q_INDEX_MERGE_ZERO) {
24643 Offset = getScaledOffsetForBitWidth(DAG, Offset, DL,
24644 RetVT.getScalarSizeInBits());
24645 Opcode = AArch64ISD::GLD1Q_MERGE_ZERO;
24646 }
24647
24648 // In the case of non-temporal gather loads and quadword gather loads there's
24649 // only one addressing mode: "vector + scalar", e.g.
24650 // ldnt1{b|h|w|d} { z0.s }, p0/z, [z0.s, x0]
24651 // Since we do have intrinsics that allow the arguments to be in a different
24652 // order, we may need to swap them to match the spec.
24653 if ((Opcode == AArch64ISD::GLDNT1_MERGE_ZERO ||
24654 Opcode == AArch64ISD::GLD1Q_MERGE_ZERO) &&
24655 Offset.getValueType().isVector())
24656 std::swap(Base, Offset);
24657
24658 // GLD{FF}1_IMM requires that the offset is an immediate that is:
24659 // * a multiple of #SizeInBytes,
24660 // * in the range [0, 31 x #SizeInBytes],
24661 // where #SizeInBytes is the size in bytes of the loaded items. For
24662 // immediates outside that range and non-immediate scalar offsets use
24663 // GLD1_MERGE_ZERO or GLD1_UXTW_MERGE_ZERO instead.
24664 if (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO ||
24665 Opcode == AArch64ISD::GLDFF1_IMM_MERGE_ZERO) {
24666 if (!isValidImmForSVEVecImmAddrMode(Offset,
24667 RetVT.getScalarSizeInBits() / 8)) {
24668 if (MVT::nxv4i32 == Base.getValueType().getSimpleVT().SimpleTy)
24669 Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
24670 ? AArch64ISD::GLD1_UXTW_MERGE_ZERO
24671 : AArch64ISD::GLDFF1_UXTW_MERGE_ZERO;
24672 else
24673 Opcode = (Opcode == AArch64ISD::GLD1_IMM_MERGE_ZERO)
24674 ? AArch64ISD::GLD1_MERGE_ZERO
24675 : AArch64ISD::GLDFF1_MERGE_ZERO;
24676
24677 std::swap(Base, Offset);
24678 }
24679 }
24680
24681 auto &TLI = DAG.getTargetLoweringInfo();
24682 if (!TLI.isTypeLegal(Base.getValueType()))
24683 return SDValue();
24684
24685 // Some gather load variants allow unpacked offsets, but only as nxv2i32
24686 // vectors. These are implicitly sign (sxtw) or zero (zxtw) extended to
24687 // nxv2i64. Legalize accordingly.
24688 if (!OnlyPackedOffsets &&
24689 Offset.getValueType().getSimpleVT().SimpleTy == MVT::nxv2i32)
24690 Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset).getValue(0);
24691
24692 // Return value type that is representable in hardware
24693 EVT HwRetVt = getSVEContainerType(RetVT);
24694
24695 // Keep the original output value type around - this is needed to be able to
24696 // select the correct instruction, e.g. LD1B, LD1H, LD1W and LD1D. For FP
24697 // values we want the integer equivalent, so just use HwRetVT.
24698 SDValue OutVT = DAG.getValueType(RetVT);
24699 if (RetVT.isFloatingPoint())
24700 OutVT = DAG.getValueType(HwRetVt);
24701
24702 SDVTList VTs = DAG.getVTList(HwRetVt, MVT::Other);
24703 SDValue Ops[] = {N->getOperand(0), // Chain
24704 N->getOperand(2), // Pg
24705 Base, Offset, OutVT};
24706
24707 SDValue Load = DAG.getNode(Opcode, DL, VTs, Ops);
24708 SDValue LoadChain = SDValue(Load.getNode(), 1);
24709
24710 if (RetVT.isInteger() && (RetVT != HwRetVt))
24711 Load = DAG.getNode(ISD::TRUNCATE, DL, RetVT, Load.getValue(0));
24712
24713 // If the original return value was FP, bitcast accordingly. Doing it here
24714 // means that we can avoid adding TableGen patterns for FPs.
24715 if (RetVT.isFloatingPoint())
24716 Load = DAG.getNode(ISD::BITCAST, DL, RetVT, Load.getValue(0));
24717
24718 return DAG.getMergeValues({Load, LoadChain}, DL);
24719 }
24720
24721 static SDValue
24722 performSignExtendInRegCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
24723 SelectionDAG &DAG) {
24724 SDLoc DL(N);
24725 SDValue Src = N->getOperand(0);
24726 unsigned Opc = Src->getOpcode();
24727
24728 // Sign extend of an unsigned unpack -> signed unpack
24729 if (Opc == AArch64ISD::UUNPKHI || Opc == AArch64ISD::UUNPKLO) {
24730
24731 unsigned SOpc = Opc == AArch64ISD::UUNPKHI ? AArch64ISD::SUNPKHI
24732 : AArch64ISD::SUNPKLO;
24733
24734 // Push the sign extend to the operand of the unpack
24735 // This is necessary where, for example, the operand of the unpack
24736 // is another unpack:
24737 // 4i32 sign_extend_inreg (4i32 uunpklo(8i16 uunpklo (16i8 opnd)), from 4i8)
24738 // ->
24739 // 4i32 sunpklo (8i16 sign_extend_inreg(8i16 uunpklo (16i8 opnd), from 8i8))
24740 // ->
24741 // 4i32 sunpklo(8i16 sunpklo(16i8 opnd))
24742 SDValue ExtOp = Src->getOperand(0);
24743 auto VT = cast<VTSDNode>(N->getOperand(1))->getVT();
24744 EVT EltTy = VT.getVectorElementType();
24745 (void)EltTy;
24746
24747 assert((EltTy == MVT::i8 || EltTy == MVT::i16 || EltTy == MVT::i32) &&
24748 "Sign extending from an invalid type");
24749
24750 EVT ExtVT = VT.getDoubleNumVectorElementsVT(*DAG.getContext());
24751
24752 SDValue Ext = DAG.getNode(ISD::SIGN_EXTEND_INREG, DL, ExtOp.getValueType(),
24753 ExtOp, DAG.getValueType(ExtVT));
24754
24755 return DAG.getNode(SOpc, DL, N->getValueType(0), Ext);
24756 }
24757
24758 if (DCI.isBeforeLegalizeOps())
24759 return SDValue();
24760
24761 if (!EnableCombineMGatherIntrinsics)
24762 return SDValue();
24763
24764 // SVE load nodes (e.g. AArch64ISD::GLD1) are straightforward candidates
24765 // for DAG Combine with SIGN_EXTEND_INREG. Bail out for all other nodes.
24766 unsigned NewOpc;
24767 unsigned MemVTOpNum = 4;
24768 switch (Opc) {
24769 case AArch64ISD::LD1_MERGE_ZERO:
24770 NewOpc = AArch64ISD::LD1S_MERGE_ZERO;
24771 MemVTOpNum = 3;
24772 break;
24773 case AArch64ISD::LDNF1_MERGE_ZERO:
24774 NewOpc = AArch64ISD::LDNF1S_MERGE_ZERO;
24775 MemVTOpNum = 3;
24776 break;
24777 case AArch64ISD::LDFF1_MERGE_ZERO:
24778 NewOpc = AArch64ISD::LDFF1S_MERGE_ZERO;
24779 MemVTOpNum = 3;
24780 break;
24781 case AArch64ISD::GLD1_MERGE_ZERO:
24782 NewOpc = AArch64ISD::GLD1S_MERGE_ZERO;
24783 break;
24784 case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
24785 NewOpc = AArch64ISD::GLD1S_SCALED_MERGE_ZERO;
24786 break;
24787 case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
24788 NewOpc = AArch64ISD::GLD1S_SXTW_MERGE_ZERO;
24789 break;
24790 case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
24791 NewOpc = AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO;
24792 break;
24793 case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
24794 NewOpc = AArch64ISD::GLD1S_UXTW_MERGE_ZERO;
24795 break;
24796 case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
24797 NewOpc = AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO;
24798 break;
24799 case AArch64ISD::GLD1_IMM_MERGE_ZERO:
24800 NewOpc = AArch64ISD::GLD1S_IMM_MERGE_ZERO;
24801 break;
24802 case AArch64ISD::GLDFF1_MERGE_ZERO:
24803 NewOpc = AArch64ISD::GLDFF1S_MERGE_ZERO;
24804 break;
24805 case AArch64ISD::GLDFF1_SCALED_MERGE_ZERO:
24806 NewOpc = AArch64ISD::GLDFF1S_SCALED_MERGE_ZERO;
24807 break;
24808 case AArch64ISD::GLDFF1_SXTW_MERGE_ZERO:
24809 NewOpc = AArch64ISD::GLDFF1S_SXTW_MERGE_ZERO;
24810 break;
24811 case AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO:
24812 NewOpc = AArch64ISD::GLDFF1S_SXTW_SCALED_MERGE_ZERO;
24813 break;
24814 case AArch64ISD::GLDFF1_UXTW_MERGE_ZERO:
24815 NewOpc = AArch64ISD::GLDFF1S_UXTW_MERGE_ZERO;
24816 break;
24817 case AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO:
24818 NewOpc = AArch64ISD::GLDFF1S_UXTW_SCALED_MERGE_ZERO;
24819 break;
24820 case AArch64ISD::GLDFF1_IMM_MERGE_ZERO:
24821 NewOpc = AArch64ISD::GLDFF1S_IMM_MERGE_ZERO;
24822 break;
24823 case AArch64ISD::GLDNT1_MERGE_ZERO:
24824 NewOpc = AArch64ISD::GLDNT1S_MERGE_ZERO;
24825 break;
24826 default:
24827 return SDValue();
24828 }
24829
24830 EVT SignExtSrcVT = cast<VTSDNode>(N->getOperand(1))->getVT();
24831 EVT SrcMemVT = cast<VTSDNode>(Src->getOperand(MemVTOpNum))->getVT();
24832
24833 if ((SignExtSrcVT != SrcMemVT) || !Src.hasOneUse())
24834 return SDValue();
24835
24836 EVT DstVT = N->getValueType(0);
24837 SDVTList VTs = DAG.getVTList(DstVT, MVT::Other);
24838
24839 SmallVector<SDValue, 5> Ops;
24840 for (unsigned I = 0; I < Src->getNumOperands(); ++I)
24841 Ops.push_back(Src->getOperand(I));
24842
24843 SDValue ExtLoad = DAG.getNode(NewOpc, SDLoc(N), VTs, Ops);
24844 DCI.CombineTo(N, ExtLoad);
24845 DCI.CombineTo(Src.getNode(), ExtLoad, ExtLoad.getValue(1));
24846
24847 // Return N so it doesn't get rechecked
24848 return SDValue(N, 0);
24849 }
24850
24851 /// Legalize the gather prefetch (scalar + vector addressing mode) when the
24852 /// offset vector is an unpacked 32-bit scalable vector. The other cases (Offset
24853 /// != nxv2i32) do not need legalization.
24854 static SDValue legalizeSVEGatherPrefetchOffsVec(SDNode *N, SelectionDAG &DAG) {
24855 const unsigned OffsetPos = 4;
24856 SDValue Offset = N->getOperand(OffsetPos);
24857
24858 // Not an unpacked vector, bail out.
24859 if (Offset.getValueType().getSimpleVT().SimpleTy != MVT::nxv2i32)
24860 return SDValue();
24861
24862 // Extend the unpacked offset vector to 64-bit lanes.
24863 SDLoc DL(N);
24864 Offset = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::nxv2i64, Offset);
24865 SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
24866 // Replace the offset operand with the 64-bit one.
24867 Ops[OffsetPos] = Offset;
24868
24869 return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
24870 }
24871
24872 /// Combines a node carrying the intrinsic
24873 /// `aarch64_sve_prf<T>_gather_scalar_offset` into a node that uses
24874 /// `aarch64_sve_prfb_gather_uxtw_index` when the scalar offset passed to
24875 /// `aarch64_sve_prf<T>_gather_scalar_offset` is not a valid immediate for the
24876 /// sve gather prefetch instruction with vector plus immediate addressing mode.
24877 static SDValue combineSVEPrefetchVecBaseImmOff(SDNode *N, SelectionDAG &DAG,
24878 unsigned ScalarSizeInBytes) {
24879 const unsigned ImmPos = 4, OffsetPos = 3;
24880 // No need to combine the node if the immediate is valid...
24881 if (isValidImmForSVEVecImmAddrMode(N->getOperand(ImmPos), ScalarSizeInBytes))
24882 return SDValue();
24883
24884 // ...otherwise swap the offset base with the offset...
24885 SmallVector<SDValue, 5> Ops(N->op_begin(), N->op_end());
24886 std::swap(Ops[ImmPos], Ops[OffsetPos]);
24887 // ...and remap the intrinsic `aarch64_sve_prf<T>_gather_scalar_offset` to
24888 // `aarch64_sve_prfb_gather_uxtw_index`.
24889 SDLoc DL(N);
24890 Ops[1] = DAG.getConstant(Intrinsic::aarch64_sve_prfb_gather_uxtw_index, DL,
24891 MVT::i64);
24892
24893 return DAG.getNode(N->getOpcode(), DL, DAG.getVTList(MVT::Other), Ops);
24894 }
24895
24896 // Return true if the vector operation can guarantee only the first lane of its
24897 // result contains data, with all bits in other lanes set to zero.
24898 static bool isLanes1toNKnownZero(SDValue Op) {
24899 switch (Op.getOpcode()) {
24900 default:
24901 return false;
24902 case AArch64ISD::ANDV_PRED:
24903 case AArch64ISD::EORV_PRED:
24904 case AArch64ISD::FADDA_PRED:
24905 case AArch64ISD::FADDV_PRED:
24906 case AArch64ISD::FMAXNMV_PRED:
24907 case AArch64ISD::FMAXV_PRED:
24908 case AArch64ISD::FMINNMV_PRED:
24909 case AArch64ISD::FMINV_PRED:
24910 case AArch64ISD::ORV_PRED:
24911 case AArch64ISD::SADDV_PRED:
24912 case AArch64ISD::SMAXV_PRED:
24913 case AArch64ISD::SMINV_PRED:
24914 case AArch64ISD::UADDV_PRED:
24915 case AArch64ISD::UMAXV_PRED:
24916 case AArch64ISD::UMINV_PRED:
24917 return true;
24918 }
24919 }
24920
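// Fold insert_vector_elt(zeroinitializer, extract_vector_elt(Vec, 0), 0) -> Vec
// when Vec is known to already have zeroes in all lanes other than lane 0
// (e.g. it is the result of a predicated SVE reduction), in which case the
// explicit zeroing is redundant.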
24921 static SDValue removeRedundantInsertVectorElt(SDNode *N) {
24922 assert(N->getOpcode() == ISD::INSERT_VECTOR_ELT && "Unexpected node!");
24923 SDValue InsertVec = N->getOperand(0);
24924 SDValue InsertElt = N->getOperand(1);
24925 SDValue InsertIdx = N->getOperand(2);
24926
24927 // We only care about inserts into the first element...
24928 if (!isNullConstant(InsertIdx))
24929 return SDValue();
24930 // ...of a zero'd vector...
24931 if (!ISD::isConstantSplatVectorAllZeros(InsertVec.getNode()))
24932 return SDValue();
24933 // ...where the inserted data was previously extracted...
24934 if (InsertElt.getOpcode() != ISD::EXTRACT_VECTOR_ELT)
24935 return SDValue();
24936
24937 SDValue ExtractVec = InsertElt.getOperand(0);
24938 SDValue ExtractIdx = InsertElt.getOperand(1);
24939
24940 // ...from the first element of a vector.
24941 if (!isNullConstant(ExtractIdx))
24942 return SDValue();
24943
24944 // If we get here we are effectively trying to zero lanes 1-N of a vector.
24945
24946 // Ensure there's no type conversion going on.
24947 if (N->getValueType(0) != ExtractVec.getValueType())
24948 return SDValue();
24949
24950 if (!isLanes1toNKnownZero(ExtractVec))
24951 return SDValue();
24952
24953 // The explicit zeroing is redundant.
24954 return ExtractVec;
24955 }
24956
24957 static SDValue
24958 performInsertVectorEltCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
24959 if (SDValue Res = removeRedundantInsertVectorElt(N))
24960 return Res;
24961
24962 return performPostLD1Combine(N, DCI, true);
24963 }
24964
24965 static SDValue performFPExtendCombine(SDNode *N, SelectionDAG &DAG,
24966 TargetLowering::DAGCombinerInfo &DCI,
24967 const AArch64Subtarget *Subtarget) {
24968 SDValue N0 = N->getOperand(0);
24969 EVT VT = N->getValueType(0);
24970
24971 // If this fp_extend's only use is an fp_round, don't fold it here; let the
// fp_round(fp_extend) pair be folded instead.
24972 if (N->hasOneUse() && N->use_begin()->getOpcode() == ISD::FP_ROUND)
24973 return SDValue();
24974
24975 auto hasValidElementTypeForFPExtLoad = [](EVT VT) {
24976 EVT EltVT = VT.getVectorElementType();
24977 return EltVT == MVT::f32 || EltVT == MVT::f64;
24978 };
24979
24980 // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
24981 // We purposefully don't care about legality of the nodes here as we know
24982 // they can be split down into something legal.
24983 if (DCI.isBeforeLegalizeOps() && ISD::isNormalLoad(N0.getNode()) &&
24984 N0.hasOneUse() && Subtarget->useSVEForFixedLengthVectors() &&
24985 VT.isFixedLengthVector() && hasValidElementTypeForFPExtLoad(VT) &&
24986 VT.getFixedSizeInBits() >= Subtarget->getMinSVEVectorSizeInBits()) {
24987 LoadSDNode *LN0 = cast<LoadSDNode>(N0);
24988 SDValue ExtLoad = DAG.getExtLoad(ISD::EXTLOAD, SDLoc(N), VT,
24989 LN0->getChain(), LN0->getBasePtr(),
24990 N0.getValueType(), LN0->getMemOperand());
24991 DCI.CombineTo(N, ExtLoad);
24992 DCI.CombineTo(
24993 N0.getNode(),
24994 DAG.getNode(ISD::FP_ROUND, SDLoc(N0), N0.getValueType(), ExtLoad,
24995 DAG.getIntPtrConstant(1, SDLoc(N0), /*isTarget=*/true)),
24996 ExtLoad.getValue(1));
24997 return SDValue(N, 0); // Return N so it doesn't get rechecked!
24998 }
24999
25000 return SDValue();
25001 }
25002
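// Expand AArch64ISD::BSP for scalable vectors when the subtarget has neither
// SVE2 nor SME (i.e. no native bitwise-select instruction is available):
//   bsp(mask, in1, in2) -> or(and(mask, in1), and(not(mask), in2))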
25003 static SDValue performBSPExpandForSVE(SDNode *N, SelectionDAG &DAG,
25004 const AArch64Subtarget *Subtarget) {
25005 EVT VT = N->getValueType(0);
25006
25007 // Don't expand for NEON, SVE2 or SME
25008 if (!VT.isScalableVector() || Subtarget->hasSVE2() || Subtarget->hasSME())
25009 return SDValue();
25010
25011 SDLoc DL(N);
25012
25013 SDValue Mask = N->getOperand(0);
25014 SDValue In1 = N->getOperand(1);
25015 SDValue In2 = N->getOperand(2);
25016
25017 SDValue InvMask = DAG.getNOT(DL, Mask, VT);
25018 SDValue Sel = DAG.getNode(ISD::AND, DL, VT, Mask, In1);
25019 SDValue SelInv = DAG.getNode(ISD::AND, DL, VT, InvMask, In2);
25020 return DAG.getNode(ISD::OR, DL, VT, Sel, SelInv);
25021 }
25022
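// Fold duplane128(insert_subvector(undef, bitcast(X), 0), 0), where X is a
// 128-bit fixed-length vector, into
// bitcast(duplane128(insert_subvector(undef, X, 0), 0)) so the DUPLANE128
// operates on the packed SVE container type of X's original element type.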
25023 static SDValue performDupLane128Combine(SDNode *N, SelectionDAG &DAG) {
25024 EVT VT = N->getValueType(0);
25025
25026 SDValue Insert = N->getOperand(0);
25027 if (Insert.getOpcode() != ISD::INSERT_SUBVECTOR)
25028 return SDValue();
25029
25030 if (!Insert.getOperand(0).isUndef())
25031 return SDValue();
25032
25033 uint64_t IdxInsert = Insert.getConstantOperandVal(2);
25034 uint64_t IdxDupLane = N->getConstantOperandVal(1);
25035 if (IdxInsert != 0 || IdxDupLane != 0)
25036 return SDValue();
25037
25038 SDValue Bitcast = Insert.getOperand(1);
25039 if (Bitcast.getOpcode() != ISD::BITCAST)
25040 return SDValue();
25041
25042 SDValue Subvec = Bitcast.getOperand(0);
25043 EVT SubvecVT = Subvec.getValueType();
25044 if (!SubvecVT.is128BitVector())
25045 return SDValue();
25046 EVT NewSubvecVT =
25047 getPackedSVEVectorVT(Subvec.getValueType().getVectorElementType());
25048
25049 SDLoc DL(N);
25050 SDValue NewInsert =
25051 DAG.getNode(ISD::INSERT_SUBVECTOR, DL, NewSubvecVT,
25052 DAG.getUNDEF(NewSubvecVT), Subvec, Insert->getOperand(2));
25053 SDValue NewDuplane128 = DAG.getNode(AArch64ISD::DUPLANE128, DL, NewSubvecVT,
25054 NewInsert, N->getOperand(1));
25055 return DAG.getNode(ISD::BITCAST, DL, VT, NewDuplane128);
25056 }
25057
25058 // Try to combine mull with uzp1.
25059 static SDValue tryCombineMULLWithUZP1(SDNode *N,
25060 TargetLowering::DAGCombinerInfo &DCI,
25061 SelectionDAG &DAG) {
25062 if (DCI.isBeforeLegalizeOps())
25063 return SDValue();
25064
25065 SDValue LHS = N->getOperand(0);
25066 SDValue RHS = N->getOperand(1);
25067
25068 SDValue ExtractHigh;
25069 SDValue ExtractLow;
25070 SDValue TruncHigh;
25071 SDValue TruncLow;
25072 SDLoc DL(N);
25073
25074 // Check the operands are trunc and extract_high.
25075 if (isEssentiallyExtractHighSubvector(LHS) &&
25076 RHS.getOpcode() == ISD::TRUNCATE) {
25077 TruncHigh = RHS;
25078 if (LHS.getOpcode() == ISD::BITCAST)
25079 ExtractHigh = LHS.getOperand(0);
25080 else
25081 ExtractHigh = LHS;
25082 } else if (isEssentiallyExtractHighSubvector(RHS) &&
25083 LHS.getOpcode() == ISD::TRUNCATE) {
25084 TruncHigh = LHS;
25085 if (RHS.getOpcode() == ISD::BITCAST)
25086 ExtractHigh = RHS.getOperand(0);
25087 else
25088 ExtractHigh = RHS;
25089 } else
25090 return SDValue();
25091
25092 // If the truncate's operand is a splat (e.g. AArch64ISD::DUP), do not combine
25093 // the op with uzp1; doing so regresses
25094 // test/CodeGen/AArch64/aarch64-smull.ll.
25095 SDValue TruncHighOp = TruncHigh.getOperand(0);
25096 EVT TruncHighOpVT = TruncHighOp.getValueType();
25097 if (TruncHighOp.getOpcode() == AArch64ISD::DUP ||
25098 DAG.isSplatValue(TruncHighOp, false))
25099 return SDValue();
25100
25101 // Check there is another extract_subvector (the low half) of the same source
25102 // vector. For example,
25103 //
25104 // t18: v4i16 = extract_subvector t2, Constant:i64<0>
25105 // t12: v4i16 = truncate t11
25106 // t31: v4i32 = AArch64ISD::SMULL t18, t12
25107 // t23: v4i16 = extract_subvector t2, Constant:i64<4>
25108 // t16: v4i16 = truncate t15
25109 // t30: v4i32 = AArch64ISD::SMULL t23, t16
25110 //
25111 // This dagcombine assumes the two extract_subvectors use the same source
25112 // vector in order to detect the mull pair. If they have different source
25113 // vectors, this combine does not apply.
25114 // TODO: Should also try to look through a bitcast.
25115 bool HasFoundMULLow = true;
25116 SDValue ExtractHighSrcVec = ExtractHigh.getOperand(0);
25117 if (ExtractHighSrcVec->use_size() != 2)
25118 HasFoundMULLow = false;
25119
25120 // Find ExtractLow.
25121 for (SDNode *User : ExtractHighSrcVec.getNode()->uses()) {
25122 if (User == ExtractHigh.getNode())
25123 continue;
25124
25125 if (User->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
25126 !isNullConstant(User->getOperand(1))) {
25127 HasFoundMULLow = false;
25128 break;
25129 }
25130
25131 ExtractLow.setNode(User);
25132 }
25133
25134 if (!ExtractLow || !ExtractLow->hasOneUse())
25135 HasFoundMULLow = false;
25136
25137 // Check ExtractLow's user.
25138 if (HasFoundMULLow) {
25139 SDNode *ExtractLowUser = *ExtractLow.getNode()->use_begin();
25140 if (ExtractLowUser->getOpcode() != N->getOpcode()) {
25141 HasFoundMULLow = false;
25142 } else {
25143 if (ExtractLowUser->getOperand(0) == ExtractLow) {
25144 if (ExtractLowUser->getOperand(1).getOpcode() == ISD::TRUNCATE)
25145 TruncLow = ExtractLowUser->getOperand(1);
25146 else
25147 HasFoundMULLow = false;
25148 } else {
25149 if (ExtractLowUser->getOperand(0).getOpcode() == ISD::TRUNCATE)
25150 TruncLow = ExtractLowUser->getOperand(0);
25151 else
25152 HasFoundMULLow = false;
25153 }
25154 }
25155 }
25156
25157 // If the truncate's operand is a splat (e.g. AArch64ISD::DUP), do not combine
25158 // the op with uzp1; doing so regresses
25159 // test/CodeGen/AArch64/aarch64-smull.ll.
25160 EVT TruncHighVT = TruncHigh.getValueType();
25161 EVT UZP1VT = TruncHighVT.getDoubleNumVectorElementsVT(*DAG.getContext());
25162 SDValue TruncLowOp =
25163 HasFoundMULLow ? TruncLow.getOperand(0) : DAG.getUNDEF(UZP1VT);
25164 EVT TruncLowOpVT = TruncLowOp.getValueType();
25165 if (HasFoundMULLow && (TruncLowOp.getOpcode() == AArch64ISD::DUP ||
25166 DAG.isSplatValue(TruncLowOp, false)))
25167 return SDValue();
25168
25169 // Create uzp1, extract_high and extract_low.
25170 if (TruncHighOpVT != UZP1VT)
25171 TruncHighOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncHighOp);
25172 if (TruncLowOpVT != UZP1VT)
25173 TruncLowOp = DAG.getNode(ISD::BITCAST, DL, UZP1VT, TruncLowOp);
25174
25175 SDValue UZP1 =
25176 DAG.getNode(AArch64ISD::UZP1, DL, UZP1VT, TruncLowOp, TruncHighOp);
25177 SDValue HighIdxCst =
25178 DAG.getConstant(TruncHighVT.getVectorNumElements(), DL, MVT::i64);
25179 SDValue NewTruncHigh =
25180 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncHighVT, UZP1, HighIdxCst);
25181 DAG.ReplaceAllUsesWith(TruncHigh, NewTruncHigh);
25182
25183 if (HasFoundMULLow) {
25184 EVT TruncLowVT = TruncLow.getValueType();
25185 SDValue NewTruncLow = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, TruncLowVT,
25186 UZP1, ExtractLow.getOperand(1));
25187 DAG.ReplaceAllUsesWith(TruncLow, NewTruncLow);
25188 }
25189
25190 return SDValue(N, 0);
25191 }
25192
25193 static SDValue performMULLCombine(SDNode *N,
25194 TargetLowering::DAGCombinerInfo &DCI,
25195 SelectionDAG &DAG) {
25196 if (SDValue Val =
25197 tryCombineLongOpWithDup(Intrinsic::not_intrinsic, N, DCI, DAG))
25198 return Val;
25199
25200 if (SDValue Val = tryCombineMULLWithUZP1(N, DCI, DAG))
25201 return Val;
25202
25203 return SDValue();
25204 }
25205
25206 static SDValue
25207 performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
25208 SelectionDAG &DAG) {
25209 // Perform the transform below.
25210 //
25211 // t34: v4i32 = AArch64ISD::UADDLV t2
25212 // t35: i32 = extract_vector_elt t34, Constant:i64<0>
25213 // t7: i64 = zero_extend t35
25214 // t20: v1i64 = scalar_to_vector t7
25215 // ==>
25216 // t34: v4i32 = AArch64ISD::UADDLV t2
25217 // t39: v2i32 = extract_subvector t34, Constant:i64<0>
25218 // t40: v1i64 = AArch64ISD::NVCAST t39
25219 if (DCI.isBeforeLegalizeOps())
25220 return SDValue();
25221
25222 EVT VT = N->getValueType(0);
25223 if (VT != MVT::v1i64)
25224 return SDValue();
25225
25226 SDValue ZEXT = N->getOperand(0);
25227 if (ZEXT.getOpcode() != ISD::ZERO_EXTEND || ZEXT.getValueType() != MVT::i64)
25228 return SDValue();
25229
25230 SDValue EXTRACT_VEC_ELT = ZEXT.getOperand(0);
25231 if (EXTRACT_VEC_ELT.getOpcode() != ISD::EXTRACT_VECTOR_ELT ||
25232 EXTRACT_VEC_ELT.getValueType() != MVT::i32)
25233 return SDValue();
25234
25235 if (!isNullConstant(EXTRACT_VEC_ELT.getOperand(1)))
25236 return SDValue();
25237
25238 SDValue UADDLV = EXTRACT_VEC_ELT.getOperand(0);
25239 if (UADDLV.getOpcode() != AArch64ISD::UADDLV ||
25240 UADDLV.getValueType() != MVT::v4i32 ||
25241 UADDLV.getOperand(0).getValueType() != MVT::v8i8)
25242 return SDValue();
25243
25244 // Generate the new sequence with AArch64ISD::NVCAST.
25245 SDLoc DL(N);
25246 SDValue EXTRACT_SUBVEC =
25247 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, MVT::v2i32, UADDLV,
25248 DAG.getConstant(0, DL, MVT::i64));
25249 SDValue NVCAST =
25250 DAG.getNode(AArch64ISD::NVCAST, DL, MVT::v1i64, EXTRACT_SUBVEC);
25251
25252 return NVCAST;
25253 }
25254
25255 SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
25256 DAGCombinerInfo &DCI) const {
25257 SelectionDAG &DAG = DCI.DAG;
25258 switch (N->getOpcode()) {
25259 default:
25260 LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
25261 break;
25262 case ISD::VECREDUCE_AND:
25263 case ISD::VECREDUCE_OR:
25264 case ISD::VECREDUCE_XOR:
25265 return performVecReduceBitwiseCombine(N, DCI, DAG);
25266 case ISD::ADD:
25267 case ISD::SUB:
25268 return performAddSubCombine(N, DCI);
25269 case ISD::BUILD_VECTOR:
25270 return performBuildVectorCombine(N, DCI, DAG);
25271 case ISD::TRUNCATE:
25272 return performTruncateCombine(N, DAG);
25273 case AArch64ISD::ANDS:
25274 return performFlagSettingCombine(N, DCI, ISD::AND);
25275 case AArch64ISD::ADC:
25276 if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ true))
25277 return R;
25278 return foldADCToCINC(N, DAG);
25279 case AArch64ISD::SBC:
25280 return foldOverflowCheck(N, DAG, /* IsAdd */ false);
25281 case AArch64ISD::ADCS:
25282 if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ true))
25283 return R;
25284 return performFlagSettingCombine(N, DCI, AArch64ISD::ADC);
25285 case AArch64ISD::SBCS:
25286 if (auto R = foldOverflowCheck(N, DAG, /* IsAdd */ false))
25287 return R;
25288 return performFlagSettingCombine(N, DCI, AArch64ISD::SBC);
25289 case AArch64ISD::BICi: {
25290 APInt DemandedBits =
25291 APInt::getAllOnes(N->getValueType(0).getScalarSizeInBits());
25292 APInt DemandedElts =
25293 APInt::getAllOnes(N->getValueType(0).getVectorNumElements());
25294
25295 if (DAG.getTargetLoweringInfo().SimplifyDemandedBits(
25296 SDValue(N, 0), DemandedBits, DemandedElts, DCI))
25297 return SDValue();
25298
25299 break;
25300 }
25301 case ISD::XOR:
25302 return performXorCombine(N, DAG, DCI, Subtarget);
25303 case ISD::MUL:
25304 return performMulCombine(N, DAG, DCI, Subtarget);
25305 case ISD::SINT_TO_FP:
25306 case ISD::UINT_TO_FP:
25307 return performIntToFpCombine(N, DAG, Subtarget);
25308 case ISD::FP_TO_SINT:
25309 case ISD::FP_TO_UINT:
25310 case ISD::FP_TO_SINT_SAT:
25311 case ISD::FP_TO_UINT_SAT:
25312 return performFpToIntCombine(N, DAG, DCI, Subtarget);
25313 case ISD::OR:
25314 return performORCombine(N, DCI, Subtarget, *this);
25315 case ISD::AND:
25316 return performANDCombine(N, DCI);
25317 case ISD::FADD:
25318 return performFADDCombine(N, DCI);
25319 case ISD::INTRINSIC_WO_CHAIN:
25320 return performIntrinsicCombine(N, DCI, Subtarget);
25321 case ISD::ANY_EXTEND:
25322 case ISD::ZERO_EXTEND:
25323 case ISD::SIGN_EXTEND:
25324 return performExtendCombine(N, DCI, DAG);
25325 case ISD::SIGN_EXTEND_INREG:
25326 return performSignExtendInRegCombine(N, DCI, DAG);
25327 case ISD::CONCAT_VECTORS:
25328 return performConcatVectorsCombine(N, DCI, DAG);
25329 case ISD::EXTRACT_SUBVECTOR:
25330 return performExtractSubvectorCombine(N, DCI, DAG);
25331 case ISD::INSERT_SUBVECTOR:
25332 return performInsertSubvectorCombine(N, DCI, DAG);
25333 case ISD::SELECT:
25334 return performSelectCombine(N, DCI);
25335 case ISD::VSELECT:
25336 return performVSelectCombine(N, DCI.DAG);
25337 case ISD::SETCC:
25338 return performSETCCCombine(N, DCI, DAG);
25339 case ISD::LOAD:
25340 return performLOADCombine(N, DCI, DAG, Subtarget);
25341 case ISD::STORE:
25342 return performSTORECombine(N, DCI, DAG, Subtarget);
25343 case ISD::MSTORE:
25344 return performMSTORECombine(N, DCI, DAG, Subtarget);
25345 case ISD::MGATHER:
25346 case ISD::MSCATTER:
25347 return performMaskedGatherScatterCombine(N, DCI, DAG);
25348 case ISD::FP_EXTEND:
25349 return performFPExtendCombine(N, DAG, DCI, Subtarget);
25350 case AArch64ISD::BRCOND:
25351 return performBRCONDCombine(N, DCI, DAG);
25352 case AArch64ISD::TBNZ:
25353 case AArch64ISD::TBZ:
25354 return performTBZCombine(N, DCI, DAG);
25355 case AArch64ISD::CSEL:
25356 return performCSELCombine(N, DCI, DAG);
25357 case AArch64ISD::DUP:
25358 case AArch64ISD::DUPLANE8:
25359 case AArch64ISD::DUPLANE16:
25360 case AArch64ISD::DUPLANE32:
25361 case AArch64ISD::DUPLANE64:
25362 return performDUPCombine(N, DCI);
25363 case AArch64ISD::DUPLANE128:
25364 return performDupLane128Combine(N, DAG);
25365 case AArch64ISD::NVCAST:
25366 return performNVCASTCombine(N, DAG);
25367 case AArch64ISD::SPLICE:
25368 return performSpliceCombine(N, DAG);
25369 case AArch64ISD::UUNPKLO:
25370 case AArch64ISD::UUNPKHI:
25371 return performUnpackCombine(N, DAG, Subtarget);
25372 case AArch64ISD::UZP1:
25373 case AArch64ISD::UZP2:
25374 return performUzpCombine(N, DAG, Subtarget);
25375 case AArch64ISD::SETCC_MERGE_ZERO:
25376 return performSetccMergeZeroCombine(N, DCI);
25377 case AArch64ISD::REINTERPRET_CAST:
25378 return performReinterpretCastCombine(N);
25379 case AArch64ISD::GLD1_MERGE_ZERO:
25380 case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
25381 case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
25382 case AArch64ISD::GLD1_SXTW_MERGE_ZERO:
25383 case AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO:
25384 case AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO:
25385 case AArch64ISD::GLD1_IMM_MERGE_ZERO:
25386 case AArch64ISD::GLD1S_MERGE_ZERO:
25387 case AArch64ISD::GLD1S_SCALED_MERGE_ZERO:
25388 case AArch64ISD::GLD1S_UXTW_MERGE_ZERO:
25389 case AArch64ISD::GLD1S_SXTW_MERGE_ZERO:
25390 case AArch64ISD::GLD1S_UXTW_SCALED_MERGE_ZERO:
25391 case AArch64ISD::GLD1S_SXTW_SCALED_MERGE_ZERO:
25392 case AArch64ISD::GLD1S_IMM_MERGE_ZERO:
25393 return performGLD1Combine(N, DAG);
25394 case AArch64ISD::VASHR:
25395 case AArch64ISD::VLSHR:
25396 return performVectorShiftCombine(N, *this, DCI);
25397 case AArch64ISD::SUNPKLO:
25398 return performSunpkloCombine(N, DAG);
25399 case AArch64ISD::BSP:
25400 return performBSPExpandForSVE(N, DAG, Subtarget);
25401 case ISD::INSERT_VECTOR_ELT:
25402 return performInsertVectorEltCombine(N, DCI);
25403 case ISD::EXTRACT_VECTOR_ELT:
25404 return performExtractVectorEltCombine(N, DCI, Subtarget);
25405 case ISD::VECREDUCE_ADD:
25406 return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
25407 case AArch64ISD::UADDV:
25408 return performUADDVCombine(N, DAG);
25409 case AArch64ISD::SMULL:
25410 case AArch64ISD::UMULL:
25411 case AArch64ISD::PMULL:
25412 return performMULLCombine(N, DCI, DAG);
25413 case ISD::INTRINSIC_VOID:
25414 case ISD::INTRINSIC_W_CHAIN:
25415 switch (N->getConstantOperandVal(1)) {
25416 case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
25417 return combineSVEPrefetchVecBaseImmOff(N, DAG, 1 /*=ScalarSizeInBytes*/);
25418 case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
25419 return combineSVEPrefetchVecBaseImmOff(N, DAG, 2 /*=ScalarSizeInBytes*/);
25420 case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
25421 return combineSVEPrefetchVecBaseImmOff(N, DAG, 4 /*=ScalarSizeInBytes*/);
25422 case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
25423 return combineSVEPrefetchVecBaseImmOff(N, DAG, 8 /*=ScalarSizeInBytes*/);
25424 case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
25425 case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
25426 case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
25427 case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
25428 case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
25429 case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
25430 case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
25431 case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
25432 return legalizeSVEGatherPrefetchOffsVec(N, DAG);
25433 case Intrinsic::aarch64_neon_ld2:
25434 case Intrinsic::aarch64_neon_ld3:
25435 case Intrinsic::aarch64_neon_ld4:
25436 case Intrinsic::aarch64_neon_ld1x2:
25437 case Intrinsic::aarch64_neon_ld1x3:
25438 case Intrinsic::aarch64_neon_ld1x4:
25439 case Intrinsic::aarch64_neon_ld2lane:
25440 case Intrinsic::aarch64_neon_ld3lane:
25441 case Intrinsic::aarch64_neon_ld4lane:
25442 case Intrinsic::aarch64_neon_ld2r:
25443 case Intrinsic::aarch64_neon_ld3r:
25444 case Intrinsic::aarch64_neon_ld4r:
25445 case Intrinsic::aarch64_neon_st2:
25446 case Intrinsic::aarch64_neon_st3:
25447 case Intrinsic::aarch64_neon_st4:
25448 case Intrinsic::aarch64_neon_st1x2:
25449 case Intrinsic::aarch64_neon_st1x3:
25450 case Intrinsic::aarch64_neon_st1x4:
25451 case Intrinsic::aarch64_neon_st2lane:
25452 case Intrinsic::aarch64_neon_st3lane:
25453 case Intrinsic::aarch64_neon_st4lane:
25454 return performNEONPostLDSTCombine(N, DCI, DAG);
25455 case Intrinsic::aarch64_sve_ldnt1:
25456 return performLDNT1Combine(N, DAG);
25457 case Intrinsic::aarch64_sve_ld1rq:
25458 return performLD1ReplicateCombine<AArch64ISD::LD1RQ_MERGE_ZERO>(N, DAG);
25459 case Intrinsic::aarch64_sve_ld1ro:
25460 return performLD1ReplicateCombine<AArch64ISD::LD1RO_MERGE_ZERO>(N, DAG);
25461 case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
25462 return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25463 case Intrinsic::aarch64_sve_ldnt1_gather:
25464 return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25465 case Intrinsic::aarch64_sve_ldnt1_gather_index:
25466 return performGatherLoadCombine(N, DAG,
25467 AArch64ISD::GLDNT1_INDEX_MERGE_ZERO);
25468 case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
25469 return performGatherLoadCombine(N, DAG, AArch64ISD::GLDNT1_MERGE_ZERO);
25470 case Intrinsic::aarch64_sve_ld1:
25471 return performLD1Combine(N, DAG, AArch64ISD::LD1_MERGE_ZERO);
25472 case Intrinsic::aarch64_sve_ldnf1:
25473 return performLD1Combine(N, DAG, AArch64ISD::LDNF1_MERGE_ZERO);
25474 case Intrinsic::aarch64_sve_ldff1:
25475 return performLD1Combine(N, DAG, AArch64ISD::LDFF1_MERGE_ZERO);
25476 case Intrinsic::aarch64_sve_st1:
25477 return performST1Combine(N, DAG);
25478 case Intrinsic::aarch64_sve_stnt1:
25479 return performSTNT1Combine(N, DAG);
25480 case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
25481 return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25482 case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
25483 return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25484 case Intrinsic::aarch64_sve_stnt1_scatter:
25485 return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_PRED);
25486 case Intrinsic::aarch64_sve_stnt1_scatter_index:
25487 return performScatterStoreCombine(N, DAG, AArch64ISD::SSTNT1_INDEX_PRED);
25488 case Intrinsic::aarch64_sve_ld1_gather:
25489 return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_MERGE_ZERO);
25490 case Intrinsic::aarch64_sve_ld1q_gather_scalar_offset:
25491 case Intrinsic::aarch64_sve_ld1q_gather_vector_offset:
25492 return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1Q_MERGE_ZERO);
25493 case Intrinsic::aarch64_sve_ld1q_gather_index:
25494 return performGatherLoadCombine(N, DAG,
25495 AArch64ISD::GLD1Q_INDEX_MERGE_ZERO);
25496 case Intrinsic::aarch64_sve_ld1_gather_index:
25497 return performGatherLoadCombine(N, DAG,
25498 AArch64ISD::GLD1_SCALED_MERGE_ZERO);
25499 case Intrinsic::aarch64_sve_ld1_gather_sxtw:
25500 return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_SXTW_MERGE_ZERO,
25501 /*OnlyPackedOffsets=*/false);
25502 case Intrinsic::aarch64_sve_ld1_gather_uxtw:
25503 return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_UXTW_MERGE_ZERO,
25504 /*OnlyPackedOffsets=*/false);
25505 case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
25506 return performGatherLoadCombine(N, DAG,
25507 AArch64ISD::GLD1_SXTW_SCALED_MERGE_ZERO,
25508 /*OnlyPackedOffsets=*/false);
25509 case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
25510 return performGatherLoadCombine(N, DAG,
25511 AArch64ISD::GLD1_UXTW_SCALED_MERGE_ZERO,
25512 /*OnlyPackedOffsets=*/false);
25513 case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
25514 return performGatherLoadCombine(N, DAG, AArch64ISD::GLD1_IMM_MERGE_ZERO);
25515 case Intrinsic::aarch64_sve_ldff1_gather:
25516 return performGatherLoadCombine(N, DAG, AArch64ISD::GLDFF1_MERGE_ZERO);
25517 case Intrinsic::aarch64_sve_ldff1_gather_index:
25518 return performGatherLoadCombine(N, DAG,
25519 AArch64ISD::GLDFF1_SCALED_MERGE_ZERO);
25520 case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
25521 return performGatherLoadCombine(N, DAG,
25522 AArch64ISD::GLDFF1_SXTW_MERGE_ZERO,
25523 /*OnlyPackedOffsets=*/false);
25524 case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
25525 return performGatherLoadCombine(N, DAG,
25526 AArch64ISD::GLDFF1_UXTW_MERGE_ZERO,
25527 /*OnlyPackedOffsets=*/false);
25528 case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
25529 return performGatherLoadCombine(N, DAG,
25530 AArch64ISD::GLDFF1_SXTW_SCALED_MERGE_ZERO,
25531 /*OnlyPackedOffsets=*/false);
25532 case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
25533 return performGatherLoadCombine(N, DAG,
25534 AArch64ISD::GLDFF1_UXTW_SCALED_MERGE_ZERO,
25535 /*OnlyPackedOffsets=*/false);
25536 case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
25537 return performGatherLoadCombine(N, DAG,
25538 AArch64ISD::GLDFF1_IMM_MERGE_ZERO);
25539 case Intrinsic::aarch64_sve_st1q_scatter_scalar_offset:
25540 case Intrinsic::aarch64_sve_st1q_scatter_vector_offset:
25541 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1Q_PRED);
25542 case Intrinsic::aarch64_sve_st1q_scatter_index:
25543 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1Q_INDEX_PRED);
25544 case Intrinsic::aarch64_sve_st1_scatter:
25545 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_PRED);
25546 case Intrinsic::aarch64_sve_st1_scatter_index:
25547 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SCALED_PRED);
25548 case Intrinsic::aarch64_sve_st1_scatter_sxtw:
25549 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_SXTW_PRED,
25550 /*OnlyPackedOffsets=*/false);
25551 case Intrinsic::aarch64_sve_st1_scatter_uxtw:
25552 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_UXTW_PRED,
25553 /*OnlyPackedOffsets=*/false);
25554 case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
25555 return performScatterStoreCombine(N, DAG,
25556 AArch64ISD::SST1_SXTW_SCALED_PRED,
25557 /*OnlyPackedOffsets=*/false);
25558 case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
25559 return performScatterStoreCombine(N, DAG,
25560 AArch64ISD::SST1_UXTW_SCALED_PRED,
25561 /*OnlyPackedOffsets=*/false);
25562 case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
25563 return performScatterStoreCombine(N, DAG, AArch64ISD::SST1_IMM_PRED);
25564 case Intrinsic::aarch64_rndr:
25565 case Intrinsic::aarch64_rndrrs: {
25566 unsigned IntrinsicID = N->getConstantOperandVal(1);
25567 auto Register =
25568 (IntrinsicID == Intrinsic::aarch64_rndr ? AArch64SysReg::RNDR
25569 : AArch64SysReg::RNDRRS);
25570 SDLoc DL(N);
25571 SDValue A = DAG.getNode(
25572 AArch64ISD::MRS, DL, DAG.getVTList(MVT::i64, MVT::Glue, MVT::Other),
25573 N->getOperand(0), DAG.getConstant(Register, DL, MVT::i64));
25574 SDValue B = DAG.getNode(
25575 AArch64ISD::CSINC, DL, MVT::i32, DAG.getConstant(0, DL, MVT::i32),
25576 DAG.getConstant(0, DL, MVT::i32),
25577 DAG.getConstant(AArch64CC::NE, DL, MVT::i32), A.getValue(1));
25578 return DAG.getMergeValues(
25579 {A, DAG.getZExtOrTrunc(B, DL, MVT::i1), A.getValue(2)}, DL);
25580 }
25581 case Intrinsic::aarch64_sme_ldr_zt:
25582 return DAG.getNode(AArch64ISD::RESTORE_ZT, SDLoc(N),
25583 DAG.getVTList(MVT::Other), N->getOperand(0),
25584 N->getOperand(2), N->getOperand(3));
25585 case Intrinsic::aarch64_sme_str_zt:
25586 return DAG.getNode(AArch64ISD::SAVE_ZT, SDLoc(N),
25587 DAG.getVTList(MVT::Other), N->getOperand(0),
25588 N->getOperand(2), N->getOperand(3));
25589 default:
25590 break;
25591 }
25592 break;
25593 case ISD::GlobalAddress:
25594 return performGlobalAddressCombine(N, DAG, Subtarget, getTargetMachine());
25595 case ISD::CTLZ:
25596 return performCTLZCombine(N, DAG, Subtarget);
25597 case ISD::SCALAR_TO_VECTOR:
25598 return performScalarToVectorCombine(N, DCI, DAG);
25599 }
25600 return SDValue();
25601 }
25602
25603 // Check if the return value is used only as a return value, as otherwise
25604 // we can't perform a tail call. In particular, we need to check for
25605 // target ISD nodes that are returns and any other "odd" constructs
25606 // that the generic analysis code won't necessarily catch.
25607 bool AArch64TargetLowering::isUsedByReturnOnly(SDNode *N,
25608 SDValue &Chain) const {
25609 if (N->getNumValues() != 1)
25610 return false;
25611 if (!N->hasNUsesOfValue(1, 0))
25612 return false;
25613
25614 SDValue TCChain = Chain;
25615 SDNode *Copy = *N->use_begin();
25616 if (Copy->getOpcode() == ISD::CopyToReg) {
25617 // If the copy has a glue operand, we conservatively assume it isn't safe to
25618 // perform a tail call.
25619 if (Copy->getOperand(Copy->getNumOperands() - 1).getValueType() ==
25620 MVT::Glue)
25621 return false;
25622 TCChain = Copy->getOperand(0);
25623 } else if (Copy->getOpcode() != ISD::FP_EXTEND)
25624 return false;
25625
25626 bool HasRet = false;
25627 for (SDNode *Node : Copy->uses()) {
25628 if (Node->getOpcode() != AArch64ISD::RET_GLUE)
25629 return false;
25630 HasRet = true;
25631 }
25632
25633 if (!HasRet)
25634 return false;
25635
25636 Chain = TCChain;
25637 return true;
25638 }
25639
25640 // Return whether an instruction can potentially be optimized to a tail
25641 // call. This will cause the optimizers to attempt to move, or duplicate,
25642 // return instructions to help enable tail call optimizations for this
25643 // instruction.
25644 bool AArch64TargetLowering::mayBeEmittedAsTailCall(const CallInst *CI) const {
25645 return CI->isTailCall();
25646 }
25647
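// Indexed addressing is only legal when the offset is a non-zero constant
// that fits in the signed 9-bit immediate used by the indexed load/store
// instructions.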
25648 bool AArch64TargetLowering::isIndexingLegal(MachineInstr &MI, Register Base,
25649 Register Offset, bool IsPre,
25650 MachineRegisterInfo &MRI) const {
25651 auto CstOffset = getIConstantVRegVal(Offset, MRI);
25652 if (!CstOffset || CstOffset->isZero())
25653 return false;
25654
25655 // All of the indexed addressing mode instructions take a signed 9 bit
25656 // immediate offset. Our CstOffset is a G_PTR_ADD offset so it already
25657 // encodes the sign/indexing direction.
25658 return isInt<9>(CstOffset->getSExtValue());
25659 }
25660
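// Shared helper for getPre/PostIndexedAddressParts: match Op as an ADD/SUB of
// a base pointer and a signed 9-bit constant offset. Bails out when the loaded
// value's only use is a scalable-vector splat, where a replicating load
// (ld1r*) is preferable to an indexed load.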
25661 bool AArch64TargetLowering::getIndexedAddressParts(SDNode *N, SDNode *Op,
25662 SDValue &Base,
25663 SDValue &Offset,
25664 SelectionDAG &DAG) const {
25665 if (Op->getOpcode() != ISD::ADD && Op->getOpcode() != ISD::SUB)
25666 return false;
25667
25668 // Non-null if there is exactly one user of the loaded value (ignoring chain).
25669 SDNode *ValOnlyUser = nullptr;
25670 for (SDNode::use_iterator UI = N->use_begin(), UE = N->use_end(); UI != UE;
25671 ++UI) {
25672 if (UI.getUse().getResNo() == 1)
25673 continue; // Ignore chain.
25674 if (ValOnlyUser == nullptr)
25675 ValOnlyUser = *UI;
25676 else {
25677 ValOnlyUser = nullptr; // Multiple non-chain uses, bail out.
25678 break;
25679 }
25680 }
25681
25682 auto IsUndefOrZero = [](SDValue V) {
25683 return V.isUndef() || isNullOrNullSplat(V, /*AllowUndefs*/ true);
25684 };
25685
25686 // If the only user of the value is a scalable vector splat, it is
25687 // preferable to do a replicating load (ld1r*).
25688 if (ValOnlyUser && ValOnlyUser->getValueType(0).isScalableVector() &&
25689 (ValOnlyUser->getOpcode() == ISD::SPLAT_VECTOR ||
25690 (ValOnlyUser->getOpcode() == AArch64ISD::DUP_MERGE_PASSTHRU &&
25691 IsUndefOrZero(ValOnlyUser->getOperand(2)))))
25692 return false;
25693
25694 Base = Op->getOperand(0);
25695 // All of the indexed addressing mode instructions take a signed
25696 // 9 bit immediate offset.
25697 if (ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(Op->getOperand(1))) {
25698 int64_t RHSC = RHS->getSExtValue();
25699 if (Op->getOpcode() == ISD::SUB)
25700 RHSC = -(uint64_t)RHSC;
25701 if (!isInt<9>(RHSC))
25702 return false;
25703 // Always emit pre-inc/post-inc addressing mode. Use negated constant offset
25704 // when dealing with subtraction.
25705 Offset = DAG.getConstant(RHSC, SDLoc(N), RHS->getValueType(0));
25706 return true;
25707 }
25708 return false;
25709 }
25710
25711 bool AArch64TargetLowering::getPreIndexedAddressParts(SDNode *N, SDValue &Base,
25712 SDValue &Offset,
25713 ISD::MemIndexedMode &AM,
25714 SelectionDAG &DAG) const {
25715 EVT VT;
25716 SDValue Ptr;
25717 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
25718 VT = LD->getMemoryVT();
25719 Ptr = LD->getBasePtr();
25720 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
25721 VT = ST->getMemoryVT();
25722 Ptr = ST->getBasePtr();
25723 } else
25724 return false;
25725
25726 if (!getIndexedAddressParts(N, Ptr.getNode(), Base, Offset, DAG))
25727 return false;
25728 AM = ISD::PRE_INC;
25729 return true;
25730 }
25731
25732 bool AArch64TargetLowering::getPostIndexedAddressParts(
25733 SDNode *N, SDNode *Op, SDValue &Base, SDValue &Offset,
25734 ISD::MemIndexedMode &AM, SelectionDAG &DAG) const {
25735 EVT VT;
25736 SDValue Ptr;
25737 if (LoadSDNode *LD = dyn_cast<LoadSDNode>(N)) {
25738 VT = LD->getMemoryVT();
25739 Ptr = LD->getBasePtr();
25740 } else if (StoreSDNode *ST = dyn_cast<StoreSDNode>(N)) {
25741 VT = ST->getMemoryVT();
25742 Ptr = ST->getBasePtr();
25743 } else
25744 return false;
25745
25746 if (!getIndexedAddressParts(N, Op, Base, Offset, DAG))
25747 return false;
25748 // Post-indexing updates the base, so it's not a valid transform
25749 // if that's not the same as the load's pointer.
25750 if (Ptr != Base)
25751 return false;
25752 AM = ISD::POST_INC;
25753 return true;
25754 }
25755
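// Replace a bitcast from an illegal <N x i1> vector to a scalar integer by
// materializing the predicate as a bitmask and zero-extending/truncating it
// to the result type.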
25756 static void replaceBoolVectorBitcast(SDNode *N,
25757 SmallVectorImpl<SDValue> &Results,
25758 SelectionDAG &DAG) {
25759 SDLoc DL(N);
25760 SDValue Op = N->getOperand(0);
25761 EVT VT = N->getValueType(0);
25762 [[maybe_unused]] EVT SrcVT = Op.getValueType();
25763 assert(SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
25764 "Must be bool vector.");
25765
25766 // Special handling for Clang's __builtin_convertvector. For vectors with <8
25767 // elements, it adds a vector concatenation with undef(s). If we encounter
25768 // this here, we can skip the concat.
25769 if (Op.getOpcode() == ISD::CONCAT_VECTORS && !Op.getOperand(0).isUndef()) {
25770 bool AllUndef = true;
25771 for (unsigned I = 1; I < Op.getNumOperands(); ++I)
25772 AllUndef &= Op.getOperand(I).isUndef();
25773
25774 if (AllUndef)
25775 Op = Op.getOperand(0);
25776 }
25777
25778 SDValue VectorBits = vectorToScalarBitmask(Op.getNode(), DAG);
25779 if (VectorBits)
25780 Results.push_back(DAG.getZExtOrTrunc(VectorBits, DL, VT));
25781 }
25782
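// Expand a bitcast from a legal scalar type to a non-legal small vector type
// (e.g. i32 -> v2i16): place the scalar in lane 0 of a wider legal vector
// (ExtendVT), bitcast it to CastVT and extract the low subvector holding the
// result.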
25783 static void CustomNonLegalBITCASTResults(SDNode *N,
25784 SmallVectorImpl<SDValue> &Results,
25785 SelectionDAG &DAG, EVT ExtendVT,
25786 EVT CastVT) {
25787 SDLoc DL(N);
25788 SDValue Op = N->getOperand(0);
25789 EVT VT = N->getValueType(0);
25790
25791 // Use SCALAR_TO_VECTOR for lane zero
25792 SDValue Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, ExtendVT, Op);
25793 SDValue CastVal = DAG.getNode(ISD::BITCAST, DL, CastVT, Vec);
25794 SDValue IdxZero = DAG.getVectorIdxConstant(0, DL);
25795 Results.push_back(
25796 DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, CastVal, IdxZero));
25797 }
25798
25799 void AArch64TargetLowering::ReplaceBITCASTResults(
25800 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
25801 SDLoc DL(N);
25802 SDValue Op = N->getOperand(0);
25803 EVT VT = N->getValueType(0);
25804 EVT SrcVT = Op.getValueType();
25805
25806 if (VT == MVT::v2i16 && SrcVT == MVT::i32) {
25807 CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v2i32, MVT::v4i16);
25808 return;
25809 }
25810
25811 if (VT == MVT::v4i8 && SrcVT == MVT::i32) {
25812 CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v2i32, MVT::v8i8);
25813 return;
25814 }
25815
25816 if (VT == MVT::v2i8 && SrcVT == MVT::i16) {
25817 CustomNonLegalBITCASTResults(N, Results, DAG, MVT::v4i16, MVT::v8i8);
25818 return;
25819 }
25820
25821 if (VT.isScalableVector() && !isTypeLegal(VT) && isTypeLegal(SrcVT)) {
25822 assert(!VT.isFloatingPoint() && SrcVT.isFloatingPoint() &&
25823 "Expected fp->int bitcast!");
25824
25825 // Bitcasting between unpacked vector types of different element counts is
25826 // not a NOP because the live elements are laid out differently.
25827 // 01234567
25828 // e.g. nxv2i32 = XX??XX??
25829 // nxv4f16 = X?X?X?X?
25830 if (VT.getVectorElementCount() != SrcVT.getVectorElementCount())
25831 return;
25832
25833 SDValue CastResult = getSVESafeBitCast(getSVEContainerType(VT), Op, DAG);
25834 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, CastResult));
25835 return;
25836 }
25837
25838 if (SrcVT.isVector() && SrcVT.getVectorElementType() == MVT::i1 &&
25839 !VT.isVector())
25840 return replaceBoolVectorBitcast(N, Results, DAG);
25841
25842 if (VT != MVT::i16 || (SrcVT != MVT::f16 && SrcVT != MVT::bf16))
25843 return;
25844
25845 Op = DAG.getTargetInsertSubreg(AArch64::hsub, DL, MVT::f32,
25846 DAG.getUNDEF(MVT::i32), Op);
25847 Op = DAG.getNode(ISD::BITCAST, DL, MVT::i32, Op);
25848 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, Op));
25849 }
25850
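// Replace a 256-bit add(X, vector_shuffle(X, <1,0,3,2,...>)) with a pair-wise
// ADDP on the two 128-bit halves of X, followed by a shuffle that duplicates
// each ADDP result into both lanes of its original pair. FP types require the
// reassociation flag (and FullFP16 for f16); bf16 is not handled.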
25851 static void ReplaceAddWithADDP(SDNode *N, SmallVectorImpl<SDValue> &Results,
25852 SelectionDAG &DAG,
25853 const AArch64Subtarget *Subtarget) {
25854 EVT VT = N->getValueType(0);
25855 if (!VT.is256BitVector() ||
25856 (VT.getScalarType().isFloatingPoint() &&
25857 !N->getFlags().hasAllowReassociation()) ||
25858 (VT.getScalarType() == MVT::f16 && !Subtarget->hasFullFP16()) ||
25859 VT.getScalarType() == MVT::bf16)
25860 return;
25861
25862 SDValue X = N->getOperand(0);
25863 auto *Shuf = dyn_cast<ShuffleVectorSDNode>(N->getOperand(1));
25864 if (!Shuf) {
25865 Shuf = dyn_cast<ShuffleVectorSDNode>(N->getOperand(0));
25866 X = N->getOperand(1);
25867 if (!Shuf)
25868 return;
25869 }
25870
25871 if (Shuf->getOperand(0) != X || !Shuf->getOperand(1)->isUndef())
25872 return;
25873
25874 // Check the mask is 1,0,3,2,5,4,...
25875 ArrayRef<int> Mask = Shuf->getMask();
25876 for (int I = 0, E = Mask.size(); I < E; I++)
25877 if (Mask[I] != (I % 2 == 0 ? I + 1 : I - 1))
25878 return;
25879
25880 SDLoc DL(N);
25881 auto LoHi = DAG.SplitVector(X, DL);
25882 assert(LoHi.first.getValueType() == LoHi.second.getValueType());
25883 SDValue Addp = DAG.getNode(AArch64ISD::ADDP, N, LoHi.first.getValueType(),
25884 LoHi.first, LoHi.second);
25885
25886 // Shuffle the elements back into order.
25887 SmallVector<int> NMask;
25888 for (unsigned I = 0, E = VT.getVectorNumElements() / 2; I < E; I++) {
25889 NMask.push_back(I);
25890 NMask.push_back(I);
25891 }
25892 Results.push_back(
25893 DAG.getVectorShuffle(VT, DL,
25894 DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Addp,
25895 DAG.getUNDEF(LoHi.first.getValueType())),
25896 DAG.getUNDEF(VT), NMask));
25897 }
25898
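// Custom-split an illegal vector reduction: split the input vector in half,
// combine the halves element-wise with InterOp, then reduce the result with
// the across-vector operation AcrossOp.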
25899 static void ReplaceReductionResults(SDNode *N,
25900 SmallVectorImpl<SDValue> &Results,
25901 SelectionDAG &DAG, unsigned InterOp,
25902 unsigned AcrossOp) {
25903 EVT LoVT, HiVT;
25904 SDValue Lo, Hi;
25905 SDLoc dl(N);
25906 std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
25907 std::tie(Lo, Hi) = DAG.SplitVectorOperand(N, 0);
25908 SDValue InterVal = DAG.getNode(InterOp, dl, LoVT, Lo, Hi);
25909 SDValue SplitVal = DAG.getNode(AcrossOp, dl, LoVT, InterVal);
25910 Results.push_back(SplitVal);
25911 }
25912
25913 void AArch64TargetLowering::ReplaceExtractSubVectorResults(
25914 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
25915 SDValue In = N->getOperand(0);
25916 EVT InVT = In.getValueType();
25917
25918 // Common code will handle these just fine.
25919 if (!InVT.isScalableVector() || !InVT.isInteger())
25920 return;
25921
25922 SDLoc DL(N);
25923 EVT VT = N->getValueType(0);
25924
25925 // The following checks bail if this is not a halving operation.
25926
25927 ElementCount ResEC = VT.getVectorElementCount();
25928
25929 if (InVT.getVectorElementCount() != (ResEC * 2))
25930 return;
25931
25932 auto *CIndex = dyn_cast<ConstantSDNode>(N->getOperand(1));
25933 if (!CIndex)
25934 return;
25935
25936 unsigned Index = CIndex->getZExtValue();
25937 if ((Index != 0) && (Index != ResEC.getKnownMinValue()))
25938 return;
25939
25940 unsigned Opcode = (Index == 0) ? AArch64ISD::UUNPKLO : AArch64ISD::UUNPKHI;
25941 EVT ExtendedHalfVT = VT.widenIntegerVectorElementType(*DAG.getContext());
25942
25943 SDValue Half = DAG.getNode(Opcode, DL, ExtendedHalfVT, N->getOperand(0));
25944 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, Half));
25945 }
25946
25947 // Create an even/odd pair of X registers holding integer value V.
25948 static SDValue createGPRPairNode(SelectionDAG &DAG, SDValue V) {
25949 SDLoc dl(V.getNode());
25950 auto [VLo, VHi] = DAG.SplitScalar(V, dl, MVT::i64, MVT::i64);
25951 if (DAG.getDataLayout().isBigEndian())
25952 std::swap(VLo, VHi);
25953 SDValue RegClass =
25954 DAG.getTargetConstant(AArch64::XSeqPairsClassRegClassID, dl, MVT::i32);
25955 SDValue SubReg0 = DAG.getTargetConstant(AArch64::sube64, dl, MVT::i32);
25956 SDValue SubReg1 = DAG.getTargetConstant(AArch64::subo64, dl, MVT::i32);
25957 const SDValue Ops[] = { RegClass, VLo, SubReg0, VHi, SubReg1 };
25958 return SDValue(
25959 DAG.getMachineNode(TargetOpcode::REG_SEQUENCE, dl, MVT::Untyped, Ops), 0);
25960 }
25961
25962 static void ReplaceCMP_SWAP_128Results(SDNode *N,
25963 SmallVectorImpl<SDValue> &Results,
25964 SelectionDAG &DAG,
25965 const AArch64Subtarget *Subtarget) {
25966 assert(N->getValueType(0) == MVT::i128 &&
25967 "AtomicCmpSwap on types less than 128 should be legal");
25968
25969 MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
25970 if (Subtarget->hasLSE() || Subtarget->outlineAtomics()) {
25971 // LSE has a 128-bit compare and swap (CASP), but i128 is not a legal type,
25972 // so lower it here, wrapped in REG_SEQUENCE and EXTRACT_SUBREG.
25973 SDValue Ops[] = {
25974 createGPRPairNode(DAG, N->getOperand(2)), // Compare value
25975 createGPRPairNode(DAG, N->getOperand(3)), // Store value
25976 N->getOperand(1), // Ptr
25977 N->getOperand(0), // Chain in
25978 };
25979
25980 unsigned Opcode;
25981 switch (MemOp->getMergedOrdering()) {
25982 case AtomicOrdering::Monotonic:
25983 Opcode = AArch64::CASPX;
25984 break;
25985 case AtomicOrdering::Acquire:
25986 Opcode = AArch64::CASPAX;
25987 break;
25988 case AtomicOrdering::Release:
25989 Opcode = AArch64::CASPLX;
25990 break;
25991 case AtomicOrdering::AcquireRelease:
25992 case AtomicOrdering::SequentiallyConsistent:
25993 Opcode = AArch64::CASPALX;
25994 break;
25995 default:
25996 llvm_unreachable("Unexpected ordering!");
25997 }
25998
25999 MachineSDNode *CmpSwap = DAG.getMachineNode(
26000 Opcode, SDLoc(N), DAG.getVTList(MVT::Untyped, MVT::Other), Ops);
26001 DAG.setNodeMemRefs(CmpSwap, {MemOp});
26002
26003 unsigned SubReg1 = AArch64::sube64, SubReg2 = AArch64::subo64;
26004 if (DAG.getDataLayout().isBigEndian())
26005 std::swap(SubReg1, SubReg2);
26006 SDValue Lo = DAG.getTargetExtractSubreg(SubReg1, SDLoc(N), MVT::i64,
26007 SDValue(CmpSwap, 0));
26008 SDValue Hi = DAG.getTargetExtractSubreg(SubReg2, SDLoc(N), MVT::i64,
26009 SDValue(CmpSwap, 0));
26010 Results.push_back(
26011 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, Lo, Hi));
26012 Results.push_back(SDValue(CmpSwap, 1)); // Chain out
26013 return;
26014 }
26015
26016 unsigned Opcode;
26017 switch (MemOp->getMergedOrdering()) {
26018 case AtomicOrdering::Monotonic:
26019 Opcode = AArch64::CMP_SWAP_128_MONOTONIC;
26020 break;
26021 case AtomicOrdering::Acquire:
26022 Opcode = AArch64::CMP_SWAP_128_ACQUIRE;
26023 break;
26024 case AtomicOrdering::Release:
26025 Opcode = AArch64::CMP_SWAP_128_RELEASE;
26026 break;
26027 case AtomicOrdering::AcquireRelease:
26028 case AtomicOrdering::SequentiallyConsistent:
26029 Opcode = AArch64::CMP_SWAP_128;
26030 break;
26031 default:
26032 llvm_unreachable("Unexpected ordering!");
26033 }
26034
26035 SDLoc DL(N);
26036 auto Desired = DAG.SplitScalar(N->getOperand(2), DL, MVT::i64, MVT::i64);
26037 auto New = DAG.SplitScalar(N->getOperand(3), DL, MVT::i64, MVT::i64);
26038 SDValue Ops[] = {N->getOperand(1), Desired.first, Desired.second,
26039 New.first, New.second, N->getOperand(0)};
26040 SDNode *CmpSwap = DAG.getMachineNode(
26041 Opcode, SDLoc(N), DAG.getVTList(MVT::i64, MVT::i64, MVT::i32, MVT::Other),
26042 Ops);
26043 DAG.setNodeMemRefs(cast<MachineSDNode>(CmpSwap), {MemOp});
26044
26045 Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
26046 SDValue(CmpSwap, 0), SDValue(CmpSwap, 1)));
26047 Results.push_back(SDValue(CmpSwap, 3));
26048 }
26049
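// Map a 128-bit atomic RMW ISD opcode and memory ordering to the matching
// LSE128 instruction. ATOMIC_LOAD_AND is handled via LDCLRP, which clears the
// bits set in the operand, so the caller must invert the operand first.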
26050 static unsigned getAtomicLoad128Opcode(unsigned ISDOpcode,
26051 AtomicOrdering Ordering) {
26052 // ATOMIC_LOAD_CLR only appears when lowering ATOMIC_LOAD_AND (see
26053 // LowerATOMIC_LOAD_AND). We can't take that approach with 128-bit, because
26054 // the type is not legal. Therefore we shouldn't expect to see a 128-bit
26055 // ATOMIC_LOAD_CLR at any point.
26056 assert(ISDOpcode != ISD::ATOMIC_LOAD_CLR &&
26057 "ATOMIC_LOAD_AND should be lowered to LDCLRP directly");
26058 assert(ISDOpcode != ISD::ATOMIC_LOAD_ADD && "There is no 128 bit LDADD");
26059 assert(ISDOpcode != ISD::ATOMIC_LOAD_SUB && "There is no 128 bit LDSUB");
26060
26061 if (ISDOpcode == ISD::ATOMIC_LOAD_AND) {
26062 // The operand will need to be XORed in a separate step.
26063 switch (Ordering) {
26064 case AtomicOrdering::Monotonic:
26065 return AArch64::LDCLRP;
26066 break;
26067 case AtomicOrdering::Acquire:
26068 return AArch64::LDCLRPA;
26069 break;
26070 case AtomicOrdering::Release:
26071 return AArch64::LDCLRPL;
26072 break;
26073 case AtomicOrdering::AcquireRelease:
26074 case AtomicOrdering::SequentiallyConsistent:
26075 return AArch64::LDCLRPAL;
26076 break;
26077 default:
26078 llvm_unreachable("Unexpected ordering!");
26079 }
26080 }
26081
26082 if (ISDOpcode == ISD::ATOMIC_LOAD_OR) {
26083 switch (Ordering) {
26084 case AtomicOrdering::Monotonic:
26085 return AArch64::LDSETP;
26086 break;
26087 case AtomicOrdering::Acquire:
26088 return AArch64::LDSETPA;
26089 break;
26090 case AtomicOrdering::Release:
26091 return AArch64::LDSETPL;
26092 break;
26093 case AtomicOrdering::AcquireRelease:
26094 case AtomicOrdering::SequentiallyConsistent:
26095 return AArch64::LDSETPAL;
26096 break;
26097 default:
26098 llvm_unreachable("Unexpected ordering!");
26099 }
26100 }
26101
26102 if (ISDOpcode == ISD::ATOMIC_SWAP) {
26103 switch (Ordering) {
26104 case AtomicOrdering::Monotonic:
26105 return AArch64::SWPP;
26106 break;
26107 case AtomicOrdering::Acquire:
26108 return AArch64::SWPPA;
26109 break;
26110 case AtomicOrdering::Release:
26111 return AArch64::SWPPL;
26112 break;
26113 case AtomicOrdering::AcquireRelease:
26114 case AtomicOrdering::SequentiallyConsistent:
26115 return AArch64::SWPPAL;
26116 break;
26117 default:
26118 llvm_unreachable("Unexpected ordering!");
26119 }
26120 }
26121
26122 llvm_unreachable("Unexpected ISDOpcode!");
26123 }
26124
26125 static void ReplaceATOMIC_LOAD_128Results(SDNode *N,
26126 SmallVectorImpl<SDValue> &Results,
26127 SelectionDAG &DAG,
26128 const AArch64Subtarget *Subtarget) {
26129 // LSE128 has 128-bit RMW ops, but i128 is not a legal type, so lower them
26130 // here. This follows the approach of the CMP_SWAP_XXX pseudo instructions
26131 // rather than the CASP instructions, because CASP has register classes for
26132 // the pairs of registers and therefore uses REG_SEQUENCE and EXTRACT_SUBREG
26133 // to present them as single operands. LSE128 instructions use the GPR64
26134 // register class (because the pair does not have to be sequential), like
26135 // CMP_SWAP_XXX, and therefore we use TRUNCATE and BUILD_PAIR.
26136
26137 assert(N->getValueType(0) == MVT::i128 &&
26138 "AtomicLoadXXX on types less than 128 should be legal");
26139
26140 if (!Subtarget->hasLSE128())
26141 return;
26142
26143 MachineMemOperand *MemOp = cast<MemSDNode>(N)->getMemOperand();
26144 const SDValue &Chain = N->getOperand(0);
26145 const SDValue &Ptr = N->getOperand(1);
26146 const SDValue &Val128 = N->getOperand(2);
26147 std::pair<SDValue, SDValue> Val2x64 =
26148 DAG.SplitScalar(Val128, SDLoc(Val128), MVT::i64, MVT::i64);
26149
26150 const unsigned ISDOpcode = N->getOpcode();
26151 const unsigned MachineOpcode =
26152 getAtomicLoad128Opcode(ISDOpcode, MemOp->getMergedOrdering());
26153
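// LDCLRP atomically clears the bits that are set in its operand (effectively
// Mem &= ~Val), so invert the value here to implement a plain atomic AND.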
26154 if (ISDOpcode == ISD::ATOMIC_LOAD_AND) {
26155 SDLoc dl(Val128);
26156 Val2x64.first =
26157 DAG.getNode(ISD::XOR, dl, MVT::i64,
26158 DAG.getConstant(-1ULL, dl, MVT::i64), Val2x64.first);
26159 Val2x64.second =
26160 DAG.getNode(ISD::XOR, dl, MVT::i64,
26161 DAG.getConstant(-1ULL, dl, MVT::i64), Val2x64.second);
26162 }
26163
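// Operand order for the LSE128 machine node: value lo, value hi, address,
// chain. The halves are swapped below for big-endian targets.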
26164 SDValue Ops[] = {Val2x64.first, Val2x64.second, Ptr, Chain};
26165 if (DAG.getDataLayout().isBigEndian())
26166 std::swap(Ops[0], Ops[1]);
26167
26168 MachineSDNode *AtomicInst =
26169 DAG.getMachineNode(MachineOpcode, SDLoc(N),
26170 DAG.getVTList(MVT::i64, MVT::i64, MVT::Other), Ops);
26171
26172 DAG.setNodeMemRefs(AtomicInst, {MemOp});
26173
26174 SDValue Lo = SDValue(AtomicInst, 0), Hi = SDValue(AtomicInst, 1);
26175 if (DAG.getDataLayout().isBigEndian())
26176 std::swap(Lo, Hi);
26177
26178 Results.push_back(DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128, Lo, Hi));
26179 Results.push_back(SDValue(AtomicInst, 2)); // Chain out
26180 }
26181
26182 void AArch64TargetLowering::ReplaceNodeResults(
26183 SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
26184 switch (N->getOpcode()) {
26185 default:
26186 llvm_unreachable("Don't know how to custom expand this");
26187 case ISD::BITCAST:
26188 ReplaceBITCASTResults(N, Results, DAG);
26189 return;
26190 case ISD::VECREDUCE_ADD:
26191 case ISD::VECREDUCE_SMAX:
26192 case ISD::VECREDUCE_SMIN:
26193 case ISD::VECREDUCE_UMAX:
26194 case ISD::VECREDUCE_UMIN:
26195 Results.push_back(LowerVECREDUCE(SDValue(N, 0), DAG));
26196 return;
26197 case ISD::ADD:
26198 case ISD::FADD:
26199 ReplaceAddWithADDP(N, Results, DAG, Subtarget);
26200 return;
26201
26202 case ISD::CTPOP:
26203 case ISD::PARITY:
26204 if (SDValue Result = LowerCTPOP_PARITY(SDValue(N, 0), DAG))
26205 Results.push_back(Result);
26206 return;
26207 case AArch64ISD::SADDV:
26208 ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::SADDV);
26209 return;
26210 case AArch64ISD::UADDV:
26211 ReplaceReductionResults(N, Results, DAG, ISD::ADD, AArch64ISD::UADDV);
26212 return;
26213 case AArch64ISD::SMINV:
26214 ReplaceReductionResults(N, Results, DAG, ISD::SMIN, AArch64ISD::SMINV);
26215 return;
26216 case AArch64ISD::UMINV:
26217 ReplaceReductionResults(N, Results, DAG, ISD::UMIN, AArch64ISD::UMINV);
26218 return;
26219 case AArch64ISD::SMAXV:
26220 ReplaceReductionResults(N, Results, DAG, ISD::SMAX, AArch64ISD::SMAXV);
26221 return;
26222 case AArch64ISD::UMAXV:
26223 ReplaceReductionResults(N, Results, DAG, ISD::UMAX, AArch64ISD::UMAXV);
26224 return;
26225 case ISD::MULHS:
26226 if (useSVEForFixedLengthVectorVT(SDValue(N, 0).getValueType()))
26227 Results.push_back(
26228 LowerToPredicatedOp(SDValue(N, 0), DAG, AArch64ISD::MULHS_PRED));
26229 return;
26230 case ISD::MULHU:
26231 if (useSVEForFixedLengthVectorVT(SDValue(N, 0).getValueType()))
26232 Results.push_back(
26233 LowerToPredicatedOp(SDValue(N, 0), DAG, AArch64ISD::MULHU_PRED));
26234 return;
26235 case ISD::FP_TO_UINT:
26236 case ISD::FP_TO_SINT:
26237 case ISD::STRICT_FP_TO_SINT:
26238 case ISD::STRICT_FP_TO_UINT:
26239 assert(N->getValueType(0) == MVT::i128 && "unexpected illegal conversion");
26240 // Let normal code take care of it by not adding anything to Results.
26241 return;
26242 case ISD::ATOMIC_CMP_SWAP:
26243 ReplaceCMP_SWAP_128Results(N, Results, DAG, Subtarget);
26244 return;
26245 case ISD::ATOMIC_LOAD_CLR:
26246 assert(N->getValueType(0) != MVT::i128 &&
26247 "128-bit ATOMIC_LOAD_AND should be lowered directly to LDCLRP");
26248 break;
26249 case ISD::ATOMIC_LOAD_AND:
26250 case ISD::ATOMIC_LOAD_OR:
26251 case ISD::ATOMIC_SWAP: {
26252 assert(cast<AtomicSDNode>(N)->getVal().getValueType() == MVT::i128 &&
26253 "Expected 128-bit atomicrmw.");
26254 // These need custom type legalisation so we go directly to instruction.
26255 ReplaceATOMIC_LOAD_128Results(N, Results, DAG, Subtarget);
26256 return;
26257 }
26258 case ISD::ATOMIC_LOAD:
26259 case ISD::LOAD: {
26260 MemSDNode *LoadNode = cast<MemSDNode>(N);
26261 EVT MemVT = LoadNode->getMemoryVT();
26262 // Handle lowering 256-bit non-temporal loads into LDNP for little-endian
26263 // targets.
26264 if (LoadNode->isNonTemporal() && Subtarget->isLittleEndian() &&
26265 MemVT.getSizeInBits() == 256u &&
26266 (MemVT.getScalarSizeInBits() == 8u ||
26267 MemVT.getScalarSizeInBits() == 16u ||
26268 MemVT.getScalarSizeInBits() == 32u ||
26269 MemVT.getScalarSizeInBits() == 64u)) {
26270
26271 SDValue Result = DAG.getMemIntrinsicNode(
26272 AArch64ISD::LDNP, SDLoc(N),
26273 DAG.getVTList({MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
26274 MemVT.getHalfNumVectorElementsVT(*DAG.getContext()),
26275 MVT::Other}),
26276 {LoadNode->getChain(), LoadNode->getBasePtr()},
26277 LoadNode->getMemoryVT(), LoadNode->getMemOperand());
26278
26279 SDValue Pair = DAG.getNode(ISD::CONCAT_VECTORS, SDLoc(N), MemVT,
26280 Result.getValue(0), Result.getValue(1));
26281 Results.append({Pair, Result.getValue(2) /* Chain */});
26282 return;
26283 }
26284
26285 if ((!LoadNode->isVolatile() && !LoadNode->isAtomic()) ||
26286 LoadNode->getMemoryVT() != MVT::i128) {
26287 // Loads that are neither volatile nor atomic, or whose memory type is not
26288 // i128, are left for later optimization (e.g. AArch64's load/store optimizer).
26289 return;
26290 }
26291
26292 if (SDValue(N, 0).getValueType() == MVT::i128) {
26293 auto *AN = dyn_cast<AtomicSDNode>(LoadNode);
26294 bool isLoadAcquire =
26295 AN && AN->getSuccessOrdering() == AtomicOrdering::Acquire;
26296 unsigned Opcode = isLoadAcquire ? AArch64ISD::LDIAPP : AArch64ISD::LDP;
26297
26298 if (isLoadAcquire)
26299 assert(Subtarget->hasFeature(AArch64::FeatureRCPC3));
26300
26301 SDValue Result = DAG.getMemIntrinsicNode(
26302 Opcode, SDLoc(N), DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}),
26303 {LoadNode->getChain(), LoadNode->getBasePtr()},
26304 LoadNode->getMemoryVT(), LoadNode->getMemOperand());
26305
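// The two i64 results come back in register order; on big-endian targets the
// first result holds the most-significant half, so swap when rebuilding i128.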
26306 unsigned FirstRes = DAG.getDataLayout().isBigEndian() ? 1 : 0;
26307
26308 SDValue Pair =
26309 DAG.getNode(ISD::BUILD_PAIR, SDLoc(N), MVT::i128,
26310 Result.getValue(FirstRes), Result.getValue(1 - FirstRes));
26311 Results.append({Pair, Result.getValue(2) /* Chain */});
26312 }
26313 return;
26314 }
26315 case ISD::EXTRACT_SUBVECTOR:
26316 ReplaceExtractSubVectorResults(N, Results, DAG);
26317 return;
26318 case ISD::INSERT_SUBVECTOR:
26319 case ISD::CONCAT_VECTORS:
26320 // Custom lowering has been requested for INSERT_SUBVECTOR and
26321 // CONCAT_VECTORS -- but delegate to common code for result type
26322 // legalisation
26323 return;
26324 case ISD::INTRINSIC_WO_CHAIN: {
26325 EVT VT = N->getValueType(0);
26326
26327 Intrinsic::ID IntID =
26328 static_cast<Intrinsic::ID>(N->getConstantOperandVal(0));
26329 switch (IntID) {
26330 default:
26331 return;
26332 case Intrinsic::aarch64_sve_clasta_n: {
26333 assert((VT == MVT::i8 || VT == MVT::i16) &&
26334 "custom lowering for unexpected type");
26335 SDLoc DL(N);
26336 auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
26337 auto V = DAG.getNode(AArch64ISD::CLASTA_N, DL, MVT::i32,
26338 N->getOperand(1), Op2, N->getOperand(3));
26339 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26340 return;
26341 }
26342 case Intrinsic::aarch64_sve_clastb_n: {
26343 assert((VT == MVT::i8 || VT == MVT::i16) &&
26344 "custom lowering for unexpected type");
26345 SDLoc DL(N);
26346 auto Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, N->getOperand(2));
26347 auto V = DAG.getNode(AArch64ISD::CLASTB_N, DL, MVT::i32,
26348 N->getOperand(1), Op2, N->getOperand(3));
26349 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26350 return;
26351 }
26352 case Intrinsic::aarch64_sve_lasta: {
26353 assert((VT == MVT::i8 || VT == MVT::i16) &&
26354 "custom lowering for unexpected type");
26355 SDLoc DL(N);
26356 auto V = DAG.getNode(AArch64ISD::LASTA, DL, MVT::i32,
26357 N->getOperand(1), N->getOperand(2));
26358 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26359 return;
26360 }
26361 case Intrinsic::aarch64_sve_lastb: {
26362 assert((VT == MVT::i8 || VT == MVT::i16) &&
26363 "custom lowering for unexpected type");
26364 SDLoc DL(N);
26365 auto V = DAG.getNode(AArch64ISD::LASTB, DL, MVT::i32,
26366 N->getOperand(1), N->getOperand(2));
26367 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26368 return;
26369 }
26370 case Intrinsic::get_active_lane_mask: {
26371 if (!VT.isFixedLengthVector() || VT.getVectorElementType() != MVT::i1)
26372 return;
26373
26374 // NOTE: Only trivial type promotion is supported.
26375 EVT NewVT = getTypeToTransformTo(*DAG.getContext(), VT);
26376 if (NewVT.getVectorNumElements() != VT.getVectorNumElements())
26377 return;
26378
26379 SDLoc DL(N);
26380 auto V = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, NewVT, N->ops());
26381 Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, VT, V));
26382 return;
26383 }
26384 }
26385 }
26386 case ISD::READ_REGISTER: {
26387 SDLoc DL(N);
26388 assert(N->getValueType(0) == MVT::i128 &&
26389 "READ_REGISTER custom lowering is only for 128-bit sysregs");
26390 SDValue Chain = N->getOperand(0);
26391 SDValue SysRegName = N->getOperand(1);
26392
26393 SDValue Result = DAG.getNode(
26394 AArch64ISD::MRRS, DL, DAG.getVTList({MVT::i64, MVT::i64, MVT::Other}),
26395 Chain, SysRegName);
26396
26397 // System registers are not affected by endianness: Result.getValue(0) always
26398 // contains the lower half of the 128-bit system register value.
26399 SDValue Pair = DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i128,
26400 Result.getValue(0), Result.getValue(1));
26401 Results.push_back(Pair);
26402 Results.push_back(Result.getValue(2)); // Chain
26403 return;
26404 }
26405 }
26406 }
26407
26408 bool AArch64TargetLowering::useLoadStackGuardNode() const {
26409 if (Subtarget->isTargetAndroid() || Subtarget->isTargetFuchsia())
26410 return TargetLowering::useLoadStackGuardNode();
26411 return true;
26412 }
26413
26414 unsigned AArch64TargetLowering::combineRepeatedFPDivisors() const {
26415 // Combine multiple FDIVs with the same divisor into multiple FMULs by the
26416 // reciprocal if there are three or more FDIVs.
26417 return 3;
26418 }
26419
26420 TargetLoweringBase::LegalizeTypeAction
26421 AArch64TargetLowering::getPreferredVectorAction(MVT VT) const {
26422 // During type legalization, we prefer to widen v1i8, v1i16 and v1i32 to v8i8,
26423 // v4i16 and v2i32 rather than promoting them.
26424 if (VT == MVT::v1i8 || VT == MVT::v1i16 || VT == MVT::v1i32 ||
26425 VT == MVT::v1f32)
26426 return TypeWidenVector;
26427
26428 return TargetLoweringBase::getPreferredVectorAction(VT);
26429 }
26430
26431 // In v8.4a, ldp and stp instructions are guaranteed to be single-copy atomic
26432 // provided the address is 16-byte aligned.
26433 bool AArch64TargetLowering::isOpSuitableForLDPSTP(const Instruction *I) const {
26434 if (!Subtarget->hasLSE2())
26435 return false;
26436
26437 if (auto LI = dyn_cast<LoadInst>(I))
26438 return LI->getType()->getPrimitiveSizeInBits() == 128 &&
26439 LI->getAlign() >= Align(16);
26440
26441 if (auto SI = dyn_cast<StoreInst>(I))
26442 return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26443 SI->getAlign() >= Align(16);
26444
26445 return false;
26446 }
26447
26448 bool AArch64TargetLowering::isOpSuitableForLSE128(const Instruction *I) const {
26449 if (!Subtarget->hasLSE128())
26450 return false;
26451
26452 // Only use SWPP for stores where LSE2 would require a fence. Unlike STP, SWPP
26453 // will clobber the two registers.
26454 if (const auto *SI = dyn_cast<StoreInst>(I))
26455 return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26456 SI->getAlign() >= Align(16) &&
26457 (SI->getOrdering() == AtomicOrdering::Release ||
26458 SI->getOrdering() == AtomicOrdering::SequentiallyConsistent);
26459
26460 if (const auto *RMW = dyn_cast<AtomicRMWInst>(I))
26461 return RMW->getValOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26462 RMW->getAlign() >= Align(16) &&
26463 (RMW->getOperation() == AtomicRMWInst::Xchg ||
26464 RMW->getOperation() == AtomicRMWInst::And ||
26465 RMW->getOperation() == AtomicRMWInst::Or);
26466
26467 return false;
26468 }
26469
26470 bool AArch64TargetLowering::isOpSuitableForRCPC3(const Instruction *I) const {
26471 if (!Subtarget->hasLSE2() || !Subtarget->hasRCPC3())
26472 return false;
26473
26474 if (auto LI = dyn_cast<LoadInst>(I))
26475 return LI->getType()->getPrimitiveSizeInBits() == 128 &&
26476 LI->getAlign() >= Align(16) &&
26477 LI->getOrdering() == AtomicOrdering::Acquire;
26478
26479 if (auto SI = dyn_cast<StoreInst>(I))
26480 return SI->getValueOperand()->getType()->getPrimitiveSizeInBits() == 128 &&
26481 SI->getAlign() >= Align(16) &&
26482 SI->getOrdering() == AtomicOrdering::Release;
26483
26484 return false;
26485 }
26486
26487 bool AArch64TargetLowering::shouldInsertFencesForAtomic(
26488 const Instruction *I) const {
26489 if (isOpSuitableForRCPC3(I))
26490 return false;
26491 if (isOpSuitableForLSE128(I))
26492 return false;
26493 if (isOpSuitableForLDPSTP(I))
26494 return true;
26495 return false;
26496 }
26497
26498 bool AArch64TargetLowering::shouldInsertTrailingFenceForAtomicStore(
26499 const Instruction *I) const {
26500 // Store-Release instructions only provide seq_cst guarantees when paired with
26501 // Load-Acquire instructions. MSVC CRT does not use these instructions to
26502 // implement seq_cst loads and stores, so we need additional explicit fences
26503 // after memory writes.
26504 if (!Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26505 return false;
26506
26507 switch (I->getOpcode()) {
26508 default:
26509 return false;
26510 case Instruction::AtomicCmpXchg:
26511 return cast<AtomicCmpXchgInst>(I)->getSuccessOrdering() ==
26512 AtomicOrdering::SequentiallyConsistent;
26513 case Instruction::AtomicRMW:
26514 return cast<AtomicRMWInst>(I)->getOrdering() ==
26515 AtomicOrdering::SequentiallyConsistent;
26516 case Instruction::Store:
26517 return cast<StoreInst>(I)->getOrdering() ==
26518 AtomicOrdering::SequentiallyConsistent;
26519 }
26520 }
26521
26522 // Loads and stores smaller than 128 bits are already atomic; ones above that
26523 // are doomed anyway, so defer to the default libcall and blame the OS when
26524 // things go wrong.
26525 TargetLoweringBase::AtomicExpansionKind
26526 AArch64TargetLowering::shouldExpandAtomicStoreInIR(StoreInst *SI) const {
26527 unsigned Size = SI->getValueOperand()->getType()->getPrimitiveSizeInBits();
26528 if (Size != 128)
26529 return AtomicExpansionKind::None;
26530 if (isOpSuitableForRCPC3(SI))
26531 return AtomicExpansionKind::None;
26532 if (isOpSuitableForLSE128(SI))
26533 return AtomicExpansionKind::Expand;
26534 if (isOpSuitableForLDPSTP(SI))
26535 return AtomicExpansionKind::None;
26536 return AtomicExpansionKind::Expand;
26537 }
26538
26539 // Loads and stores smaller than 128 bits are already atomic; ones above that
26540 // are doomed anyway, so defer to the default libcall and blame the OS when
26541 // things go wrong.
26542 TargetLowering::AtomicExpansionKind
26543 AArch64TargetLowering::shouldExpandAtomicLoadInIR(LoadInst *LI) const {
26544 unsigned Size = LI->getType()->getPrimitiveSizeInBits();
26545
26546 if (Size != 128)
26547 return AtomicExpansionKind::None;
26548 if (isOpSuitableForRCPC3(LI))
26549 return AtomicExpansionKind::None;
26550 // No LSE128 loads
26551 if (isOpSuitableForLDPSTP(LI))
26552 return AtomicExpansionKind::None;
26553
26554 // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26555 // implement atomicrmw without spilling. If the target address is also on the
26556 // stack and close enough to the spill slot, this can lead to a situation
26557 // where the monitor always gets cleared and the atomic operation can never
26558 // succeed. So at -O0 lower this operation to a CAS loop.
26559 if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
26560 return AtomicExpansionKind::CmpXChg;
26561
26562 // Using CAS for an atomic load has a better chance of succeeding under high
26563 // contention situations. So use it if available.
26564 return Subtarget->hasLSE() ? AtomicExpansionKind::CmpXChg
26565 : AtomicExpansionKind::LLSC;
26566 }
26567
26568 // The "default" for integer RMW operations is to expand to an LL/SC loop.
26569 // However, with the LSE instructions (or outline-atomics mode, which provides
26570 // library routines in place of the LSE-instructions), we can directly emit many
26571 // operations instead.
26572 //
26573 // Floating-point operations are always emitted to a cmpxchg loop, because they
26574 // may trigger a trap which aborts an LLSC sequence.
26575 TargetLowering::AtomicExpansionKind
26576 AArch64TargetLowering::shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const {
26577 unsigned Size = AI->getType()->getPrimitiveSizeInBits();
26578 assert(Size <= 128 && "AtomicExpandPass should've handled larger sizes.");
26579
26580 if (AI->isFloatingPointOperation())
26581 return AtomicExpansionKind::CmpXChg;
26582
26583 bool CanUseLSE128 = Subtarget->hasLSE128() && Size == 128 &&
26584 (AI->getOperation() == AtomicRMWInst::Xchg ||
26585 AI->getOperation() == AtomicRMWInst::Or ||
26586 AI->getOperation() == AtomicRMWInst::And);
26587 if (CanUseLSE128)
26588 return AtomicExpansionKind::None;
26589
26590 // Nand is not supported in LSE.
26591 // Leave 128 bits to LLSC or CmpXChg.
26592 if (AI->getOperation() != AtomicRMWInst::Nand && Size < 128) {
26593 if (Subtarget->hasLSE())
26594 return AtomicExpansionKind::None;
26595 if (Subtarget->outlineAtomics()) {
26596 // [U]Min/[U]Max RMW atomics are used in __sync_fetch_ libcalls so far.
26597 // Don't outline them unless
26598 // (1) high level <atomic> support approved:
26599 // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p0493r1.pdf
26600 // (2) low level libgcc and compiler-rt support implemented by:
26601 // min/max outline atomics helpers
26602 if (AI->getOperation() != AtomicRMWInst::Min &&
26603 AI->getOperation() != AtomicRMWInst::Max &&
26604 AI->getOperation() != AtomicRMWInst::UMin &&
26605 AI->getOperation() != AtomicRMWInst::UMax) {
26606 return AtomicExpansionKind::None;
26607 }
26608 }
26609 }
26610
26611 // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26612 // implement atomicrmw without spilling. If the target address is also on the
26613 // stack and close enough to the spill slot, this can lead to a situation
26614 // where the monitor always gets cleared and the atomic operation can never
26615 // succeed. So at -O0 lower this operation to a CAS loop. Also worthwhile if
26616 // we have a single CAS instruction that can replace the loop.
26617 if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None ||
26618 Subtarget->hasLSE())
26619 return AtomicExpansionKind::CmpXChg;
26620
26621 return AtomicExpansionKind::LLSC;
26622 }
26623
26624 TargetLowering::AtomicExpansionKind
26625 AArch64TargetLowering::shouldExpandAtomicCmpXchgInIR(
26626 AtomicCmpXchgInst *AI) const {
26627 // If subtarget has LSE, leave cmpxchg intact for codegen.
26628 if (Subtarget->hasLSE() || Subtarget->outlineAtomics())
26629 return AtomicExpansionKind::None;
26630 // At -O0, fast-regalloc cannot cope with the live vregs necessary to
26631 // implement cmpxchg without spilling. If the address being exchanged is also
26632 // on the stack and close enough to the spill slot, this can lead to a
26633 // situation where the monitor always gets cleared and the atomic operation
26634 // can never succeed. So at -O0 we need a late-expanded pseudo-inst instead.
26635 if (getTargetMachine().getOptLevel() == CodeGenOptLevel::None)
26636 return AtomicExpansionKind::None;
26637
26638 // 128-bit atomic cmpxchg is weird; AtomicExpand doesn't know how to expand
26639 // it.
26640 unsigned Size = AI->getCompareOperand()->getType()->getPrimitiveSizeInBits();
26641 if (Size > 64)
26642 return AtomicExpansionKind::None;
26643
26644 return AtomicExpansionKind::LLSC;
26645 }
26646
26647 Value *AArch64TargetLowering::emitLoadLinked(IRBuilderBase &Builder,
26648 Type *ValueTy, Value *Addr,
26649 AtomicOrdering Ord) const {
26650 Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26651 bool IsAcquire = isAcquireOrStronger(Ord);
26652
26653 // Since i128 isn't legal and intrinsics don't get type-lowered, the ldrexd
26654 // intrinsic must return {i64, i64} and we have to recombine them into a
26655 // single i128 here.
26656 if (ValueTy->getPrimitiveSizeInBits() == 128) {
26657 Intrinsic::ID Int =
26658 IsAcquire ? Intrinsic::aarch64_ldaxp : Intrinsic::aarch64_ldxp;
26659 Function *Ldxr = Intrinsic::getDeclaration(M, Int);
26660
26661 Value *LoHi = Builder.CreateCall(Ldxr, Addr, "lohi");
26662
26663 Value *Lo = Builder.CreateExtractValue(LoHi, 0, "lo");
26664 Value *Hi = Builder.CreateExtractValue(LoHi, 1, "hi");
26665 Lo = Builder.CreateZExt(Lo, ValueTy, "lo64");
26666 Hi = Builder.CreateZExt(Hi, ValueTy, "hi64");
26667 return Builder.CreateOr(
26668 Lo, Builder.CreateShl(Hi, ConstantInt::get(ValueTy, 64)), "val64");
26669 }
26670
26671 Type *Tys[] = { Addr->getType() };
26672 Intrinsic::ID Int =
26673 IsAcquire ? Intrinsic::aarch64_ldaxr : Intrinsic::aarch64_ldxr;
26674 Function *Ldxr = Intrinsic::getDeclaration(M, Int, Tys);
26675
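// The scalar ldxr/ldaxr intrinsics take only a pointer; the accessed type is
// conveyed via the elementtype parameter attribute added below.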
26676 const DataLayout &DL = M->getDataLayout();
26677 IntegerType *IntEltTy = Builder.getIntNTy(DL.getTypeSizeInBits(ValueTy));
26678 CallInst *CI = Builder.CreateCall(Ldxr, Addr);
26679 CI->addParamAttr(
26680 0, Attribute::get(Builder.getContext(), Attribute::ElementType, ValueTy));
26681 Value *Trunc = Builder.CreateTrunc(CI, IntEltTy);
26682
26683 return Builder.CreateBitCast(Trunc, ValueTy);
26684 }
26685
26686 void AArch64TargetLowering::emitAtomicCmpXchgNoStoreLLBalance(
26687 IRBuilderBase &Builder) const {
26688 Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26689 Builder.CreateCall(Intrinsic::getDeclaration(M, Intrinsic::aarch64_clrex));
26690 }
26691
26692 Value *AArch64TargetLowering::emitStoreConditional(IRBuilderBase &Builder,
26693 Value *Val, Value *Addr,
26694 AtomicOrdering Ord) const {
26695 Module *M = Builder.GetInsertBlock()->getParent()->getParent();
26696 bool IsRelease = isReleaseOrStronger(Ord);
26697
26698 // Since the intrinsics must have legal type, the i128 intrinsics take two
26699 // parameters: "i64, i64". We must marshal Val into the appropriate form
26700 // before the call.
26701 if (Val->getType()->getPrimitiveSizeInBits() == 128) {
26702 Intrinsic::ID Int =
26703 IsRelease ? Intrinsic::aarch64_stlxp : Intrinsic::aarch64_stxp;
26704 Function *Stxr = Intrinsic::getDeclaration(M, Int);
26705 Type *Int64Ty = Type::getInt64Ty(M->getContext());
26706
26707 Value *Lo = Builder.CreateTrunc(Val, Int64Ty, "lo");
26708 Value *Hi = Builder.CreateTrunc(Builder.CreateLShr(Val, 64), Int64Ty, "hi");
26709 return Builder.CreateCall(Stxr, {Lo, Hi, Addr});
26710 }
26711
26712 Intrinsic::ID Int =
26713 IsRelease ? Intrinsic::aarch64_stlxr : Intrinsic::aarch64_stxr;
26714 Type *Tys[] = { Addr->getType() };
26715 Function *Stxr = Intrinsic::getDeclaration(M, Int, Tys);
26716
26717 const DataLayout &DL = M->getDataLayout();
26718 IntegerType *IntValTy = Builder.getIntNTy(DL.getTypeSizeInBits(Val->getType()));
26719 Val = Builder.CreateBitCast(Val, IntValTy);
26720
26721 CallInst *CI = Builder.CreateCall(
26722 Stxr, {Builder.CreateZExtOrBitCast(
26723 Val, Stxr->getFunctionType()->getParamType(0)),
26724 Addr});
26725 CI->addParamAttr(1, Attribute::get(Builder.getContext(),
26726 Attribute::ElementType, Val->getType()));
26727 return CI;
26728 }
26729
26730 bool AArch64TargetLowering::functionArgumentNeedsConsecutiveRegisters(
26731 Type *Ty, CallingConv::ID CallConv, bool isVarArg,
26732 const DataLayout &DL) const {
26733 if (!Ty->isArrayTy()) {
26734 const TypeSize &TySize = Ty->getPrimitiveSizeInBits();
26735 return TySize.isScalable() && TySize.getKnownMinValue() > 128;
26736 }
26737
26738 // All non-aggregate members of the type must have the same type.
26739 SmallVector<EVT> ValueVTs;
26740 ComputeValueVTs(*this, DL, Ty, ValueVTs);
26741 return all_equal(ValueVTs);
26742 }
26743
26744 bool AArch64TargetLowering::shouldNormalizeToSelectSequence(LLVMContext &,
26745 EVT) const {
26746 return false;
26747 }
26748
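// Form a pointer at a fixed byte offset from the thread pointer (TPIDR_EL0),
// e.g. to address a reserved TLS slot.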
26749 static Value *UseTlsOffset(IRBuilderBase &IRB, unsigned Offset) {
26750 Module *M = IRB.GetInsertBlock()->getParent()->getParent();
26751 Function *ThreadPointerFunc =
26752 Intrinsic::getDeclaration(M, Intrinsic::thread_pointer);
26753 return IRB.CreatePointerCast(
26754 IRB.CreateConstGEP1_32(IRB.getInt8Ty(), IRB.CreateCall(ThreadPointerFunc),
26755 Offset),
26756 IRB.getPtrTy(0));
26757 }
26758
26759 Value *AArch64TargetLowering::getIRStackGuard(IRBuilderBase &IRB) const {
26760 // Android provides a fixed TLS slot for the stack cookie. See the definition
26761 // of TLS_SLOT_STACK_GUARD in
26762 // https://android.googlesource.com/platform/bionic/+/main/libc/platform/bionic/tls_defines.h
26763 if (Subtarget->isTargetAndroid())
26764 return UseTlsOffset(IRB, 0x28);
26765
26766 // Fuchsia is similar.
26767 // <zircon/tls.h> defines ZX_TLS_STACK_GUARD_OFFSET with this value.
26768 if (Subtarget->isTargetFuchsia())
26769 return UseTlsOffset(IRB, -0x10);
26770
26771 return TargetLowering::getIRStackGuard(IRB);
26772 }
26773
26774 void AArch64TargetLowering::insertSSPDeclarations(Module &M) const {
26775 // MSVC CRT provides functionalities for stack protection.
26776 if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment()) {
26777 // MSVC CRT has a global variable holding security cookie.
26778 M.getOrInsertGlobal("__security_cookie",
26779 PointerType::getUnqual(M.getContext()));
26780
26781 // MSVC CRT has a function to validate security cookie.
26782 FunctionCallee SecurityCheckCookie =
26783 M.getOrInsertFunction(Subtarget->getSecurityCheckCookieName(),
26784 Type::getVoidTy(M.getContext()),
26785 PointerType::getUnqual(M.getContext()));
26786 if (Function *F = dyn_cast<Function>(SecurityCheckCookie.getCallee())) {
26787 F->setCallingConv(CallingConv::Win64);
26788 F->addParamAttr(0, Attribute::AttrKind::InReg);
26789 }
26790 return;
26791 }
26792 TargetLowering::insertSSPDeclarations(M);
26793 }
26794
26795 Value *AArch64TargetLowering::getSDagStackGuard(const Module &M) const {
26796 // MSVC CRT has a global variable holding security cookie.
26797 if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26798 return M.getGlobalVariable("__security_cookie");
26799 return TargetLowering::getSDagStackGuard(M);
26800 }
26801
26802 Function *AArch64TargetLowering::getSSPStackGuardCheck(const Module &M) const {
26803 // MSVC CRT has a function to validate security cookie.
26804 if (Subtarget->getTargetTriple().isWindowsMSVCEnvironment())
26805 return M.getFunction(Subtarget->getSecurityCheckCookieName());
26806 return TargetLowering::getSSPStackGuardCheck(M);
26807 }
26808
26809 Value *
26810 AArch64TargetLowering::getSafeStackPointerLocation(IRBuilderBase &IRB) const {
26811 // Android provides a fixed TLS slot for the SafeStack pointer. See the
26812 // definition of TLS_SLOT_SAFESTACK in
26813 // https://android.googlesource.com/platform/bionic/+/master/libc/private/bionic_tls.h
26814 if (Subtarget->isTargetAndroid())
26815 return UseTlsOffset(IRB, 0x48);
26816
26817 // Fuchsia is similar.
26818 // <zircon/tls.h> defines ZX_TLS_UNSAFE_SP_OFFSET with this value.
26819 if (Subtarget->isTargetFuchsia())
26820 return UseTlsOffset(IRB, -0x8);
26821
26822 return TargetLowering::getSafeStackPointerLocation(IRB);
26823 }
26824
26825 bool AArch64TargetLowering::isMaskAndCmp0FoldingBeneficial(
26826 const Instruction &AndI) const {
26827 // Only sink 'and' mask to cmp use block if it is masking a single bit, since
26828 // this is likely to fold the and/cmp/br into a single tbz instruction. It
26829 // may be beneficial to sink in other cases, but we would have to check that
26830 // the cmp would not get folded into the br to form a cbz for these to be
26831 // beneficial.
26832 ConstantInt* Mask = dyn_cast<ConstantInt>(AndI.getOperand(1));
26833 if (!Mask)
26834 return false;
26835 return Mask->getValue().isPowerOf2();
26836 }
26837
26838 bool AArch64TargetLowering::
26839 shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
26840 SDValue X, ConstantSDNode *XC, ConstantSDNode *CC, SDValue Y,
26841 unsigned OldShiftOpcode, unsigned NewShiftOpcode,
26842 SelectionDAG &DAG) const {
26843 // Does baseline recommend not to perform the fold by default?
26844 if (!TargetLowering::shouldProduceAndByConstByHoistingConstFromShiftsLHSOfAnd(
26845 X, XC, CC, Y, OldShiftOpcode, NewShiftOpcode, DAG))
26846 return false;
26847 // Else, if this is a vector shift, prefer 'shl'.
26848 return X.getValueType().isScalarInteger() || NewShiftOpcode == ISD::SHL;
26849 }
26850
26851 TargetLowering::ShiftLegalizationStrategy
26852 AArch64TargetLowering::preferredShiftLegalizationStrategy(
26853 SelectionDAG &DAG, SDNode *N, unsigned int ExpansionFactor) const {
26854 if (DAG.getMachineFunction().getFunction().hasMinSize() &&
26855 !Subtarget->isTargetWindows() && !Subtarget->isTargetDarwin())
26856 return ShiftLegalizationStrategy::LowerToLibcall;
26857 return TargetLowering::preferredShiftLegalizationStrategy(DAG, N,
26858 ExpansionFactor);
26859 }
26860
26861 void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const {
26862 // Update IsSplitCSR in AArch64FunctionInfo.
26863 AArch64FunctionInfo *AFI = Entry->getParent()->getInfo<AArch64FunctionInfo>();
26864 AFI->setIsSplitCSR(true);
26865 }
26866
26867 void AArch64TargetLowering::insertCopiesSplitCSR(
26868 MachineBasicBlock *Entry,
26869 const SmallVectorImpl<MachineBasicBlock *> &Exits) const {
26870 const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo();
26871 const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent());
26872 if (!IStart)
26873 return;
26874
26875 const TargetInstrInfo *TII = Subtarget->getInstrInfo();
26876 MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo();
26877 MachineBasicBlock::iterator MBBI = Entry->begin();
26878 for (const MCPhysReg *I = IStart; *I; ++I) {
26879 const TargetRegisterClass *RC = nullptr;
26880 if (AArch64::GPR64RegClass.contains(*I))
26881 RC = &AArch64::GPR64RegClass;
26882 else if (AArch64::FPR64RegClass.contains(*I))
26883 RC = &AArch64::FPR64RegClass;
26884 else
26885 llvm_unreachable("Unexpected register class in CSRsViaCopy!");
26886
26887 Register NewVR = MRI->createVirtualRegister(RC);
26888 // Create copy from CSR to a virtual register.
26889 // FIXME: this currently does not emit CFI pseudo-instructions, it works
26890 // fine for CXX_FAST_TLS since the C++-style TLS access functions should be
26891 // nounwind. If we want to generalize this later, we may need to emit
26892 // CFI pseudo-instructions.
26893 assert(Entry->getParent()->getFunction().hasFnAttribute(
26894 Attribute::NoUnwind) &&
26895 "Function should be nounwind in insertCopiesSplitCSR!");
26896 Entry->addLiveIn(*I);
26897 BuildMI(*Entry, MBBI, DebugLoc(), TII->get(TargetOpcode::COPY), NewVR)
26898 .addReg(*I);
26899
26900 // Insert the copy-back instructions right before the terminator.
26901 for (auto *Exit : Exits)
26902 BuildMI(*Exit, Exit->getFirstTerminator(), DebugLoc(),
26903 TII->get(TargetOpcode::COPY), *I)
26904 .addReg(NewVR);
26905 }
26906 }
26907
26908 bool AArch64TargetLowering::isIntDivCheap(EVT VT, AttributeList Attr) const {
26909 // Integer division on AArch64 is expensive. However, when aggressively
26910 // optimizing for code size, we prefer to use a div instruction, as it is
26911 // usually smaller than the alternative sequence.
26912 // The exception to this is vector division. Since AArch64 doesn't have vector
26913 // integer division, leaving the division as-is is a loss even in terms of
26914 // size, because it will have to be scalarized, while the alternative code
26915 // sequence can be performed in vector form.
26916 bool OptSize = Attr.hasFnAttr(Attribute::MinSize);
26917 return OptSize && !VT.isVector();
26918 }
26919
26920 bool AArch64TargetLowering::preferIncOfAddToSubOfNot(EVT VT) const {
26921 // We want inc-of-add for scalars and sub-of-not for vectors.
26922 return VT.isScalarInteger();
26923 }
26924
26925 bool AArch64TargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,
26926 EVT VT) const {
26927 // v8f16 without fp16 need to be extended to v8f32, which is more difficult to
26928 // legalize.
26929 if (FPVT == MVT::v8f16 && !Subtarget->hasFullFP16())
26930 return false;
26931 if (FPVT == MVT::v8bf16)
26932 return false;
26933 return TargetLowering::shouldConvertFpToSat(Op, FPVT, VT);
26934 }
26935
26936 MachineInstr *
26937 AArch64TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
26938 MachineBasicBlock::instr_iterator &MBBI,
26939 const TargetInstrInfo *TII) const {
26940 assert(MBBI->isCall() && MBBI->getCFIType() &&
26941 "Invalid call instruction for a KCFI check");
26942
26943 switch (MBBI->getOpcode()) {
26944 case AArch64::BLR:
26945 case AArch64::BLRNoIP:
26946 case AArch64::TCRETURNri:
26947 case AArch64::TCRETURNrix16x17:
26948 case AArch64::TCRETURNrix17:
26949 case AArch64::TCRETURNrinotx16:
26950 break;
26951 default:
26952 llvm_unreachable("Unexpected CFI call opcode");
26953 }
26954
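// Emit a KCFI_CHECK pseudo carrying the call target register and the expected
// type hash; it is expanded into the actual check sequence late in codegen.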
26955 MachineOperand &Target = MBBI->getOperand(0);
26956 assert(Target.isReg() && "Invalid target operand for an indirect call");
26957 Target.setIsRenamable(false);
26958
26959 return BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(AArch64::KCFI_CHECK))
26960 .addReg(Target.getReg())
26961 .addImm(MBBI->getCFIType())
26962 .getInstr();
26963 }
26964
26965 bool AArch64TargetLowering::enableAggressiveFMAFusion(EVT VT) const {
26966 return Subtarget->hasAggressiveFMA() && VT.isFloatingPoint();
26967 }
26968
26969 unsigned
26970 AArch64TargetLowering::getVaListSizeInBits(const DataLayout &DL) const {
26971 if (Subtarget->isTargetDarwin() || Subtarget->isTargetWindows())
26972 return getPointerTy(DL).getSizeInBits();
26973
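// Otherwise the AAPCS64 va_list is { void *__stack, *__gr_top, *__vr_top;
// int __gr_offs, __vr_offs }: three pointers plus two 32-bit offsets.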
26974 return 3 * getPointerTy(DL).getSizeInBits() + 2 * 32;
26975 }
26976
26977 void AArch64TargetLowering::finalizeLowering(MachineFunction &MF) const {
26978 MachineFrameInfo &MFI = MF.getFrameInfo();
26979 // If we have any vulnerable SVE stack objects then the stack protector
26980 // needs to be placed at the top of the SVE stack area, as the SVE locals
26981 // are placed above the other locals, so we allocate it as if it were a
26982 // scalable vector.
26983 // FIXME: It may be worthwhile having a specific interface for this rather
26984 // than doing it here in finalizeLowering.
26985 if (MFI.hasStackProtectorIndex()) {
26986 for (unsigned int i = 0, e = MFI.getObjectIndexEnd(); i != e; ++i) {
26987 if (MFI.getStackID(i) == TargetStackID::ScalableVector &&
26988 MFI.getObjectSSPLayout(i) != MachineFrameInfo::SSPLK_None) {
26989 MFI.setStackID(MFI.getStackProtectorIndex(),
26990 TargetStackID::ScalableVector);
26991 MFI.setObjectAlignment(MFI.getStackProtectorIndex(), Align(16));
26992 break;
26993 }
26994 }
26995 }
26996 MFI.computeMaxCallFrameSize(MF);
26997 TargetLoweringBase::finalizeLowering(MF);
26998 }
26999
27000 // Unlike X86, we let frame lowering assign offsets to all catch objects.
27001 bool AArch64TargetLowering::needsFixedCatchObjects() const {
27002 return false;
27003 }
27004
27005 bool AArch64TargetLowering::shouldLocalize(
27006 const MachineInstr &MI, const TargetTransformInfo *TTI) const {
27007 auto &MF = *MI.getMF();
27008 auto &MRI = MF.getRegInfo();
27009 auto maxUses = [](unsigned RematCost) {
27010 // A cost of 1 means remats are basically free.
27011 if (RematCost == 1)
27012 return std::numeric_limits<unsigned>::max();
27013 if (RematCost == 2)
27014 return 2U;
27015
27016 // Remat is too expensive, only sink if there's one user.
27017 if (RematCost > 2)
27018 return 1U;
27019 llvm_unreachable("Unexpected remat cost");
27020 };
27021
27022 unsigned Opc = MI.getOpcode();
27023 switch (Opc) {
27024 case TargetOpcode::G_GLOBAL_VALUE: {
27025 // On Darwin, TLS global vars get selected into function calls, which
27026 // we don't want localized, as they can get moved into the middle of
27027 // another call sequence.
27028 const GlobalValue &GV = *MI.getOperand(1).getGlobal();
27029 if (GV.isThreadLocal() && Subtarget->isTargetMachO())
27030 return false;
27031 return true; // Always localize G_GLOBAL_VALUE to avoid high reg pressure.
27032 }
27033 case TargetOpcode::G_FCONSTANT:
27034 case TargetOpcode::G_CONSTANT: {
27035 const ConstantInt *CI;
27036 unsigned AdditionalCost = 0;
27037
27038 if (Opc == TargetOpcode::G_CONSTANT)
27039 CI = MI.getOperand(1).getCImm();
27040 else {
27041 LLT Ty = MRI.getType(MI.getOperand(0).getReg());
27042 // We try to estimate the cost of 32/64-bit fpimms, as they'll likely be
27043 // materialized as integers.
27044 if (Ty.getScalarSizeInBits() != 32 && Ty.getScalarSizeInBits() != 64)
27045 break;
27046 auto APF = MI.getOperand(1).getFPImm()->getValueAPF();
27047 bool OptForSize =
27048 MF.getFunction().hasOptSize() || MF.getFunction().hasMinSize();
27049 if (isFPImmLegal(APF, EVT::getFloatingPointVT(Ty.getScalarSizeInBits()),
27050 OptForSize))
27051 return true; // Constant should be cheap.
27052 CI =
27053 ConstantInt::get(MF.getFunction().getContext(), APF.bitcastToAPInt());
27054 // FP materialization also costs an extra move, from gpr to fpr.
27055 AdditionalCost = 1;
27056 }
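// Weigh the materialisation cost of the immediate against its number of uses
// to decide whether rematerialising it per use (localizing) is worthwhile.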
27057 APInt Imm = CI->getValue();
27058 InstructionCost Cost = TTI->getIntImmCost(
27059 Imm, CI->getType(), TargetTransformInfo::TCK_CodeSize);
27060 assert(Cost.isValid() && "Expected a valid imm cost");
27061
27062 unsigned RematCost = *Cost.getValue();
27063 RematCost += AdditionalCost;
27064 Register Reg = MI.getOperand(0).getReg();
27065 unsigned MaxUses = maxUses(RematCost);
27066 // Don't pass UINT_MAX sentinel value to hasAtMostUserInstrs().
27067 if (MaxUses == std::numeric_limits<unsigned>::max())
27068 --MaxUses;
27069 return MRI.hasAtMostUserInstrs(Reg, MaxUses);
27070 }
27071 // If we legalized G_GLOBAL_VALUE into ADRP + G_ADD_LOW, mark both as being
27072 // localizable.
27073 case AArch64::ADRP:
27074 case AArch64::G_ADD_LOW:
27075 // Need to localize G_PTR_ADD so that G_GLOBAL_VALUE can be localized too.
27076 case TargetOpcode::G_PTR_ADD:
27077 return true;
27078 default:
27079 break;
27080 }
27081 return TargetLoweringBase::shouldLocalize(MI, TTI);
27082 }
27083
27084 bool AArch64TargetLowering::fallBackToDAGISel(const Instruction &Inst) const {
27085 // Fallback for scalable vectors.
27086 // Note that if EnableSVEGISel is true, we allow scalable vector types for
27087 // all instructions, regardless of whether they are actually supported.
27088 if (!EnableSVEGISel) {
27089 if (Inst.getType()->isScalableTy()) {
27090 return true;
27091 }
27092
27093 for (unsigned i = 0; i < Inst.getNumOperands(); ++i)
27094 if (Inst.getOperand(i)->getType()->isScalableTy())
27095 return true;
27096
27097 if (const AllocaInst *AI = dyn_cast<AllocaInst>(&Inst)) {
27098 if (AI->getAllocatedType()->isScalableTy())
27099 return true;
27100 }
27101 }
27102
27103 // Checks to allow the use of SME instructions
27104 if (auto *Base = dyn_cast<CallBase>(&Inst)) {
27105 auto CallerAttrs = SMEAttrs(*Inst.getFunction());
27106 auto CalleeAttrs = SMEAttrs(*Base);
27107 if (CallerAttrs.requiresSMChange(CalleeAttrs) ||
27108 CallerAttrs.requiresLazySave(CalleeAttrs) ||
27109 CallerAttrs.requiresPreservingZT0(CalleeAttrs))
27110 return true;
27111 }
27112 return false;
27113 }
27114
27115 // Return the largest legal scalable vector type that matches VT's element type.
27116 static EVT getContainerForFixedLengthVector(SelectionDAG &DAG, EVT VT) {
27117 assert(VT.isFixedLengthVector() &&
27118 DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27119 "Expected legal fixed length vector!");
27120 switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
27121 default:
27122 llvm_unreachable("unexpected element type for SVE container");
27123 case MVT::i8:
27124 return EVT(MVT::nxv16i8);
27125 case MVT::i16:
27126 return EVT(MVT::nxv8i16);
27127 case MVT::i32:
27128 return EVT(MVT::nxv4i32);
27129 case MVT::i64:
27130 return EVT(MVT::nxv2i64);
27131 case MVT::bf16:
27132 return EVT(MVT::nxv8bf16);
27133 case MVT::f16:
27134 return EVT(MVT::nxv8f16);
27135 case MVT::f32:
27136 return EVT(MVT::nxv4f32);
27137 case MVT::f64:
27138 return EVT(MVT::nxv2f64);
27139 }
27140 }
27141
27142 // Return a PTRUE with active lanes corresponding to the extent of VT.
27143 static SDValue getPredicateForFixedLengthVector(SelectionDAG &DAG, SDLoc &DL,
27144 EVT VT) {
27145 assert(VT.isFixedLengthVector() &&
27146 DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27147 "Expected legal fixed length vector!");
27148
27149 std::optional<unsigned> PgPattern =
27150 getSVEPredPatternFromNumElements(VT.getVectorNumElements());
27151 assert(PgPattern && "Unexpected element count for SVE predicate");
27152
27153 // For vectors that are exactly getMaxSVEVectorSizeInBits big, we can use
27154 // AArch64SVEPredPattern::all, which can enable the use of unpredicated
27155 // variants of instructions when available.
27156 const auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
27157 unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
27158 unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
27159 if (MaxSVESize && MinSVESize == MaxSVESize &&
27160 MaxSVESize == VT.getSizeInBits())
27161 PgPattern = AArch64SVEPredPattern::all;
27162
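// Pick the predicate type whose element granularity matches the data element
// width.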
27163 MVT MaskVT;
27164 switch (VT.getVectorElementType().getSimpleVT().SimpleTy) {
27165 default:
27166 llvm_unreachable("unexpected element type for SVE predicate");
27167 case MVT::i8:
27168 MaskVT = MVT::nxv16i1;
27169 break;
27170 case MVT::i16:
27171 case MVT::f16:
27172 case MVT::bf16:
27173 MaskVT = MVT::nxv8i1;
27174 break;
27175 case MVT::i32:
27176 case MVT::f32:
27177 MaskVT = MVT::nxv4i1;
27178 break;
27179 case MVT::i64:
27180 case MVT::f64:
27181 MaskVT = MVT::nxv2i1;
27182 break;
27183 }
27184
27185 return getPTrue(DAG, DL, MaskVT, *PgPattern);
27186 }
27187
27188 static SDValue getPredicateForScalableVector(SelectionDAG &DAG, SDLoc &DL,
27189 EVT VT) {
27190 assert(VT.isScalableVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
27191 "Expected legal scalable vector!");
27192 auto PredTy = VT.changeVectorElementType(MVT::i1);
27193 return getPTrue(DAG, DL, PredTy, AArch64SVEPredPattern::all);
27194 }
27195
27196 static SDValue getPredicateForVector(SelectionDAG &DAG, SDLoc &DL, EVT VT) {
27197 if (VT.isFixedLengthVector())
27198 return getPredicateForFixedLengthVector(DAG, DL, VT);
27199
27200 return getPredicateForScalableVector(DAG, DL, VT);
27201 }
27202
27203 // Grow V to consume an entire SVE register.
27204 static SDValue convertToScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
27205 assert(VT.isScalableVector() &&
27206 "Expected to convert into a scalable vector!");
27207 assert(V.getValueType().isFixedLengthVector() &&
27208 "Expected a fixed length vector operand!");
27209 SDLoc DL(V);
27210 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27211 return DAG.getNode(ISD::INSERT_SUBVECTOR, DL, VT, DAG.getUNDEF(VT), V, Zero);
27212 }
27213
27214 // Shrink V so it's just big enough to maintain a VT's worth of data.
27215 static SDValue convertFromScalableVector(SelectionDAG &DAG, EVT VT, SDValue V) {
27216 assert(VT.isFixedLengthVector() &&
27217 "Expected to convert into a fixed length vector!");
27218 assert(V.getValueType().isScalableVector() &&
27219 "Expected a scalable vector operand!");
27220 SDLoc DL(V);
27221 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27222 return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, V, Zero);
27223 }
27224
27225 // Convert all fixed length vector loads larger than NEON to masked_loads.
27226 SDValue AArch64TargetLowering::LowerFixedLengthVectorLoadToSVE(
27227 SDValue Op, SelectionDAG &DAG) const {
27228 auto Load = cast<LoadSDNode>(Op);
27229
27230 SDLoc DL(Op);
27231 EVT VT = Op.getValueType();
27232 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27233 EVT LoadVT = ContainerVT;
27234 EVT MemVT = Load->getMemoryVT();
27235
27236 auto Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27237
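// Floating-point vectors are loaded as integers and converted back below;
// this single path also covers FP extending loads.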
27238 if (VT.isFloatingPoint()) {
27239 LoadVT = ContainerVT.changeTypeToInteger();
27240 MemVT = MemVT.changeTypeToInteger();
27241 }
27242
27243 SDValue NewLoad = DAG.getMaskedLoad(
27244 LoadVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(), Pg,
27245 DAG.getUNDEF(LoadVT), MemVT, Load->getMemOperand(),
27246 Load->getAddressingMode(), Load->getExtensionType());
27247
27248 SDValue Result = NewLoad;
27249 if (VT.isFloatingPoint() && Load->getExtensionType() == ISD::EXTLOAD) {
27250 EVT ExtendVT = ContainerVT.changeVectorElementType(
27251 Load->getMemoryVT().getVectorElementType());
27252
27253 Result = getSVESafeBitCast(ExtendVT, Result, DAG);
27254 Result = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT,
27255 Pg, Result, DAG.getUNDEF(ContainerVT));
27256 } else if (VT.isFloatingPoint()) {
27257 Result = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Result);
27258 }
27259
27260 Result = convertFromScalableVector(DAG, VT, Result);
27261 SDValue MergedValues[2] = {Result, NewLoad.getValue(1)};
27262 return DAG.getMergeValues(MergedValues, DL);
27263 }
27264
27265 static SDValue convertFixedMaskToScalableVector(SDValue Mask,
27266 SelectionDAG &DAG) {
27267 SDLoc DL(Mask);
27268 EVT InVT = Mask.getValueType();
27269 EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27270
27271 auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
27272
27273 if (ISD::isBuildVectorAllOnes(Mask.getNode()))
27274 return Pg;
27275
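// Otherwise move the mask into an SVE container and compare it against zero
// to materialise it as a predicate.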
27276 auto Op1 = convertToScalableVector(DAG, ContainerVT, Mask);
27277 auto Op2 = DAG.getConstant(0, DL, ContainerVT);
27278
27279 return DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, Pg.getValueType(),
27280 {Pg, Op1, Op2, DAG.getCondCode(ISD::SETNE)});
27281 }
27282
27283 // Convert all fixed length vector masked loads larger than NEON to SVE masked loads.
27284 SDValue AArch64TargetLowering::LowerFixedLengthVectorMLoadToSVE(
27285 SDValue Op, SelectionDAG &DAG) const {
27286 auto Load = cast<MaskedLoadSDNode>(Op);
27287
27288 SDLoc DL(Op);
27289 EVT VT = Op.getValueType();
27290 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27291
27292 SDValue Mask = Load->getMask();
27293 // If this is an extending load and the mask type is not the same as the
27294 // load's type, then we have to extend the mask type.
27295 if (VT.getScalarSizeInBits() > Mask.getValueType().getScalarSizeInBits()) {
27296 assert(Load->getExtensionType() != ISD::NON_EXTLOAD &&
27297 "Incorrect mask type");
27298 Mask = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Mask);
27299 }
27300 Mask = convertFixedMaskToScalableVector(Mask, DAG);
27301
27302 SDValue PassThru;
27303 bool IsPassThruZeroOrUndef = false;
27304
27305 if (Load->getPassThru()->isUndef()) {
27306 PassThru = DAG.getUNDEF(ContainerVT);
27307 IsPassThruZeroOrUndef = true;
27308 } else {
27309 if (ContainerVT.isInteger())
27310 PassThru = DAG.getConstant(0, DL, ContainerVT);
27311 else
27312 PassThru = DAG.getConstantFP(0, DL, ContainerVT);
27313 if (isZerosVector(Load->getPassThru().getNode()))
27314 IsPassThruZeroOrUndef = true;
27315 }
27316
27317 SDValue NewLoad = DAG.getMaskedLoad(
27318 ContainerVT, DL, Load->getChain(), Load->getBasePtr(), Load->getOffset(),
27319 Mask, PassThru, Load->getMemoryVT(), Load->getMemOperand(),
27320 Load->getAddressingMode(), Load->getExtensionType());
27321
27322 SDValue Result = NewLoad;
27323 if (!IsPassThruZeroOrUndef) {
27324 SDValue OldPassThru =
27325 convertToScalableVector(DAG, ContainerVT, Load->getPassThru());
27326 Result = DAG.getSelect(DL, ContainerVT, Mask, Result, OldPassThru);
27327 }
27328
27329 Result = convertFromScalableVector(DAG, VT, Result);
27330 SDValue MergedValues[2] = {Result, NewLoad.getValue(1)};
27331 return DAG.getMergeValues(MergedValues, DL);
27332 }
27333
27334 // Convert all fixed length vector stores larger than NEON to masked_stores.
27335 SDValue AArch64TargetLowering::LowerFixedLengthVectorStoreToSVE(
27336 SDValue Op, SelectionDAG &DAG) const {
27337 auto Store = cast<StoreSDNode>(Op);
27338
27339 SDLoc DL(Op);
27340 EVT VT = Store->getValue().getValueType();
27341 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27342 EVT MemVT = Store->getMemoryVT();
27343
27344 auto Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27345 auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
27346
27347 if (VT.isFloatingPoint() && Store->isTruncatingStore()) {
27348 EVT TruncVT = ContainerVT.changeVectorElementType(
27349 Store->getMemoryVT().getVectorElementType());
27350 MemVT = MemVT.changeTypeToInteger();
27351 NewValue = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, TruncVT, Pg,
27352 NewValue, DAG.getTargetConstant(0, DL, MVT::i64),
27353 DAG.getUNDEF(TruncVT));
27354 NewValue =
27355 getSVESafeBitCast(ContainerVT.changeTypeToInteger(), NewValue, DAG);
27356 } else if (VT.isFloatingPoint()) {
27357 MemVT = MemVT.changeTypeToInteger();
27358 NewValue =
27359 getSVESafeBitCast(ContainerVT.changeTypeToInteger(), NewValue, DAG);
27360 }
27361
27362 return DAG.getMaskedStore(Store->getChain(), DL, NewValue,
27363 Store->getBasePtr(), Store->getOffset(), Pg, MemVT,
27364 Store->getMemOperand(), Store->getAddressingMode(),
27365 Store->isTruncatingStore());
27366 }
27367
27368 SDValue AArch64TargetLowering::LowerFixedLengthVectorMStoreToSVE(
27369 SDValue Op, SelectionDAG &DAG) const {
27370 auto *Store = cast<MaskedStoreSDNode>(Op);
27371
27372 SDLoc DL(Op);
27373 EVT VT = Store->getValue().getValueType();
27374 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27375
27376 auto NewValue = convertToScalableVector(DAG, ContainerVT, Store->getValue());
27377 SDValue Mask = convertFixedMaskToScalableVector(Store->getMask(), DAG);
27378
27379 return DAG.getMaskedStore(
27380 Store->getChain(), DL, NewValue, Store->getBasePtr(), Store->getOffset(),
27381 Mask, Store->getMemoryVT(), Store->getMemOperand(),
27382 Store->getAddressingMode(), Store->isTruncatingStore());
27383 }
27384
27385 SDValue AArch64TargetLowering::LowerFixedLengthVectorIntDivideToSVE(
27386 SDValue Op, SelectionDAG &DAG) const {
27387 SDLoc dl(Op);
27388 EVT VT = Op.getValueType();
27389 EVT EltVT = VT.getVectorElementType();
27390
27391 bool Signed = Op.getOpcode() == ISD::SDIV;
27392 unsigned PredOpcode = Signed ? AArch64ISD::SDIV_PRED : AArch64ISD::UDIV_PRED;
27393
27394 bool Negated;
27395 uint64_t SplatVal;
27396 if (Signed && isPow2Splat(Op.getOperand(1), SplatVal, Negated)) {
27397 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27398 SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
27399 SDValue Op2 = DAG.getTargetConstant(Log2_64(SplatVal), dl, MVT::i32);
27400
27401 SDValue Pg = getPredicateForFixedLengthVector(DAG, dl, VT);
27402 SDValue Res =
27403 DAG.getNode(AArch64ISD::SRAD_MERGE_OP1, dl, ContainerVT, Pg, Op1, Op2);
27404 if (Negated)
27405 Res = DAG.getNode(ISD::SUB, dl, ContainerVT,
27406 DAG.getConstant(0, dl, ContainerVT), Res);
27407
27408 return convertFromScalableVector(DAG, VT, Res);
27409 }
27410
27411 // Scalable vector i32/i64 DIV is supported.
27412 if (EltVT == MVT::i32 || EltVT == MVT::i64)
27413 return LowerToPredicatedOp(Op, DAG, PredOpcode);
27414
27415 // Scalable vector i8/i16 DIV is not supported. Promote it to i32.
27416 EVT HalfVT = VT.getHalfNumVectorElementsVT(*DAG.getContext());
27417 EVT PromVT = HalfVT.widenIntegerVectorElementType(*DAG.getContext());
27418 unsigned ExtendOpcode = Signed ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND;
27419
27420 // If the wider type is legal: extend, op, and truncate.
27421 EVT WideVT = VT.widenIntegerVectorElementType(*DAG.getContext());
27422 if (DAG.getTargetLoweringInfo().isTypeLegal(WideVT)) {
27423 SDValue Op0 = DAG.getNode(ExtendOpcode, dl, WideVT, Op.getOperand(0));
27424 SDValue Op1 = DAG.getNode(ExtendOpcode, dl, WideVT, Op.getOperand(1));
27425 SDValue Div = DAG.getNode(Op.getOpcode(), dl, WideVT, Op0, Op1);
27426 return DAG.getNode(ISD::TRUNCATE, dl, VT, Div);
27427 }
27428
27429 auto HalveAndExtendVector = [&DAG, &dl, &HalfVT, &PromVT,
27430 &ExtendOpcode](SDValue Op) {
27431 SDValue IdxZero = DAG.getConstant(0, dl, MVT::i64);
27432 SDValue IdxHalf =
27433 DAG.getConstant(HalfVT.getVectorNumElements(), dl, MVT::i64);
27434 SDValue Lo = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, HalfVT, Op, IdxZero);
27435 SDValue Hi = DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, HalfVT, Op, IdxHalf);
27436 return std::pair<SDValue, SDValue>(
27437 {DAG.getNode(ExtendOpcode, dl, PromVT, Lo),
27438 DAG.getNode(ExtendOpcode, dl, PromVT, Hi)});
27439 };
27440
  // If the wider type is not legal: split, extend, op, trunc and concat.
27442 auto [Op0LoExt, Op0HiExt] = HalveAndExtendVector(Op.getOperand(0));
27443 auto [Op1LoExt, Op1HiExt] = HalveAndExtendVector(Op.getOperand(1));
27444 SDValue Lo = DAG.getNode(Op.getOpcode(), dl, PromVT, Op0LoExt, Op1LoExt);
27445 SDValue Hi = DAG.getNode(Op.getOpcode(), dl, PromVT, Op0HiExt, Op1HiExt);
27446 SDValue LoTrunc = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Lo);
27447 SDValue HiTrunc = DAG.getNode(ISD::TRUNCATE, dl, HalfVT, Hi);
27448 return DAG.getNode(ISD::CONCAT_VECTORS, dl, VT, {LoTrunc, HiTrunc});
27449 }
27450
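// Lower a fixed-length integer extend by repeatedly unpacking the low half of
// the scalable container until the desired element width is reached.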
SDValue AArch64TargetLowering::LowerFixedLengthVectorIntExtendToSVE(
    SDValue Op, SelectionDAG &DAG) const {
27453 EVT VT = Op.getValueType();
27454 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27455
27456 SDLoc DL(Op);
27457 SDValue Val = Op.getOperand(0);
27458 EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
27459 Val = convertToScalableVector(DAG, ContainerVT, Val);
27460
27461 bool Signed = Op.getOpcode() == ISD::SIGN_EXTEND;
27462 unsigned ExtendOpc = Signed ? AArch64ISD::SUNPKLO : AArch64ISD::UUNPKLO;
27463
27464 // Repeatedly unpack Val until the result is of the desired element type.
27465 switch (ContainerVT.getSimpleVT().SimpleTy) {
27466 default:
27467 llvm_unreachable("unimplemented container type");
27468 case MVT::nxv16i8:
27469 Val = DAG.getNode(ExtendOpc, DL, MVT::nxv8i16, Val);
27470 if (VT.getVectorElementType() == MVT::i16)
27471 break;
27472 [[fallthrough]];
27473 case MVT::nxv8i16:
27474 Val = DAG.getNode(ExtendOpc, DL, MVT::nxv4i32, Val);
27475 if (VT.getVectorElementType() == MVT::i32)
27476 break;
27477 [[fallthrough]];
27478 case MVT::nxv4i32:
27479 Val = DAG.getNode(ExtendOpc, DL, MVT::nxv2i64, Val);
27480 assert(VT.getVectorElementType() == MVT::i64 && "Unexpected element type!");
27481 break;
27482 }
27483
27484 return convertFromScalableVector(DAG, VT, Val);
27485 }
27486
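// Lower a fixed-length integer truncate by repeatedly bitcasting to a narrower
// element type and using UZP1 to select the even sub-elements.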
SDValue AArch64TargetLowering::LowerFixedLengthVectorTruncateToSVE(
    SDValue Op, SelectionDAG &DAG) const {
27489 EVT VT = Op.getValueType();
27490 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27491
27492 SDLoc DL(Op);
27493 SDValue Val = Op.getOperand(0);
27494 EVT ContainerVT = getContainerForFixedLengthVector(DAG, Val.getValueType());
27495 Val = convertToScalableVector(DAG, ContainerVT, Val);
27496
27497 // Repeatedly truncate Val until the result is of the desired element type.
27498 switch (ContainerVT.getSimpleVT().SimpleTy) {
27499 default:
27500 llvm_unreachable("unimplemented container type");
27501 case MVT::nxv2i64:
27502 Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv4i32, Val);
27503 Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv4i32, Val, Val);
27504 if (VT.getVectorElementType() == MVT::i32)
27505 break;
27506 [[fallthrough]];
27507 case MVT::nxv4i32:
27508 Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv8i16, Val);
27509 Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv8i16, Val, Val);
27510 if (VT.getVectorElementType() == MVT::i16)
27511 break;
27512 [[fallthrough]];
27513 case MVT::nxv8i16:
27514 Val = DAG.getNode(ISD::BITCAST, DL, MVT::nxv16i8, Val);
27515 Val = DAG.getNode(AArch64ISD::UZP1, DL, MVT::nxv16i8, Val, Val);
27516 assert(VT.getVectorElementType() == MVT::i8 && "Unexpected element type!");
27517 break;
27518 }
27519
27520 return convertFromScalableVector(DAG, VT, Val);
27521 }
27522
SDValue AArch64TargetLowering::LowerFixedLengthExtractVectorElt(
    SDValue Op, SelectionDAG &DAG) const {
27525 EVT VT = Op.getValueType();
27526 EVT InVT = Op.getOperand(0).getValueType();
27527 assert(InVT.isFixedLengthVector() && "Expected fixed length vector type!");
27528
27529 SDLoc DL(Op);
27530 EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27531 SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
27532
27533 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Op.getOperand(1));
27534 }
27535
SDValue AArch64TargetLowering::LowerFixedLengthInsertVectorElt(
    SDValue Op, SelectionDAG &DAG) const {
27538 EVT VT = Op.getValueType();
27539 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27540
27541 SDLoc DL(Op);
27542 EVT InVT = Op.getOperand(0).getValueType();
27543 EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27544 SDValue Op0 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(0));
27545
27546 auto ScalableRes = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT, Op0,
27547 Op.getOperand(1), Op.getOperand(2));
27548
27549 return convertFromScalableVector(DAG, VT, ScalableRes);
27550 }
27551
27552 // Convert vector operation 'Op' to an equivalent predicated operation whereby
27553 // the original operation's type is used to construct a suitable predicate.
27554 // NOTE: The results for inactive lanes are undefined.
SDValue AArch64TargetLowering::LowerToPredicatedOp(SDValue Op,
                                                   SelectionDAG &DAG,
                                                   unsigned NewOp) const {
27558 EVT VT = Op.getValueType();
27559 SDLoc DL(Op);
27560 auto Pg = getPredicateForVector(DAG, DL, VT);
27561
27562 if (VT.isFixedLengthVector()) {
27563 assert(isTypeLegal(VT) && "Expected only legal fixed-width types");
27564 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27565
27566 // Create list of operands by converting existing ones to scalable types.
27567 SmallVector<SDValue, 4> Operands = {Pg};
27568 for (const SDValue &V : Op->op_values()) {
27569 if (isa<CondCodeSDNode>(V)) {
27570 Operands.push_back(V);
27571 continue;
27572 }
27573
27574 if (const VTSDNode *VTNode = dyn_cast<VTSDNode>(V)) {
27575 EVT VTArg = VTNode->getVT().getVectorElementType();
27576 EVT NewVTArg = ContainerVT.changeVectorElementType(VTArg);
27577 Operands.push_back(DAG.getValueType(NewVTArg));
27578 continue;
27579 }
27580
27581 assert(isTypeLegal(V.getValueType()) &&
27582 "Expected only legal fixed-width types");
27583 Operands.push_back(convertToScalableVector(DAG, ContainerVT, V));
27584 }
27585
27586 if (isMergePassthruOpcode(NewOp))
27587 Operands.push_back(DAG.getUNDEF(ContainerVT));
27588
27589 auto ScalableRes = DAG.getNode(NewOp, DL, ContainerVT, Operands);
27590 return convertFromScalableVector(DAG, VT, ScalableRes);
27591 }
27592
27593 assert(VT.isScalableVector() && "Only expect to lower scalable vector op!");
27594
27595 SmallVector<SDValue, 4> Operands = {Pg};
27596 for (const SDValue &V : Op->op_values()) {
27597 assert((!V.getValueType().isVector() ||
27598 V.getValueType().isScalableVector()) &&
27599 "Only scalable vectors are supported!");
27600 Operands.push_back(V);
27601 }
27602
27603 if (isMergePassthruOpcode(NewOp))
27604 Operands.push_back(DAG.getUNDEF(VT));
27605
27606 return DAG.getNode(NewOp, DL, VT, Operands, Op->getFlags());
27607 }
27608
27609 // If a fixed length vector operation has no side effects when applied to
27610 // undefined elements, we can safely use scalable vectors to perform the same
27611 // operation without needing to worry about predication.
SDValue AArch64TargetLowering::LowerToScalableOp(SDValue Op,
                                                 SelectionDAG &DAG) const {
27614 EVT VT = Op.getValueType();
27615 assert(VT.isFixedLengthVector() && isTypeLegal(VT) &&
27616 "Only expected to lower fixed length vector operation!");
27617 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27618
27619 // Create list of operands by converting existing ones to scalable types.
27620 SmallVector<SDValue, 4> Ops;
27621 for (const SDValue &V : Op->op_values()) {
27622 assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
27623
27624 // Pass through non-vector operands.
27625 if (!V.getValueType().isVector()) {
27626 Ops.push_back(V);
27627 continue;
27628 }
27629
27630 // "cast" fixed length vector to a scalable vector.
27631 assert(V.getValueType().isFixedLengthVector() &&
27632 isTypeLegal(V.getValueType()) &&
27633 "Only fixed length vectors are supported!");
27634 Ops.push_back(convertToScalableVector(DAG, ContainerVT, V));
27635 }
27636
27637 auto ScalableRes = DAG.getNode(Op.getOpcode(), SDLoc(Op), ContainerVT, Ops);
27638 return convertFromScalableVector(DAG, VT, ScalableRes);
27639 }
27640
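// Lower a strictly-ordered floating-point add reduction by seeding lane 0 of a
// scalable vector with the accumulator, performing a predicated FADDA and
// extracting lane 0 of the result.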
SDValue AArch64TargetLowering::LowerVECREDUCE_SEQ_FADD(SDValue ScalarOp,
                                                       SelectionDAG &DAG) const {
27643 SDLoc DL(ScalarOp);
27644 SDValue AccOp = ScalarOp.getOperand(0);
27645 SDValue VecOp = ScalarOp.getOperand(1);
27646 EVT SrcVT = VecOp.getValueType();
27647 EVT ResVT = SrcVT.getVectorElementType();
27648
27649 EVT ContainerVT = SrcVT;
27650 if (SrcVT.isFixedLengthVector()) {
27651 ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
27652 VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
27653 }
27654
27655 SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
27656 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27657
27658 // Convert operands to Scalable.
27659 AccOp = DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, ContainerVT,
27660 DAG.getUNDEF(ContainerVT), AccOp, Zero);
27661
27662 // Perform reduction.
27663 SDValue Rdx = DAG.getNode(AArch64ISD::FADDA_PRED, DL, ContainerVT,
27664 Pg, AccOp, VecOp);
27665
27666 return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Rdx, Zero);
27667 }
27668
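// Lower VECREDUCE_OR/AND/XOR of scalable predicate (i1) vectors. OR and AND
// reductions become PTEST-based flag checks, while XOR reductions count the
// active lanes with CNTP and truncate the count to the result type.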
SDValue AArch64TargetLowering::LowerPredReductionToSVE(SDValue ReduceOp,
                                                       SelectionDAG &DAG) const {
27671 SDLoc DL(ReduceOp);
27672 SDValue Op = ReduceOp.getOperand(0);
27673 EVT OpVT = Op.getValueType();
27674 EVT VT = ReduceOp.getValueType();
27675
27676 if (!OpVT.isScalableVector() || OpVT.getVectorElementType() != MVT::i1)
27677 return SDValue();
27678
27679 SDValue Pg = getPredicateForVector(DAG, DL, OpVT);
27680
27681 switch (ReduceOp.getOpcode()) {
27682 default:
27683 return SDValue();
27684 case ISD::VECREDUCE_OR:
27685 if (isAllActivePredicate(DAG, Pg) && OpVT == MVT::nxv16i1)
27686 // The predicate can be 'Op' because
27687 // vecreduce_or(Op & <all true>) <=> vecreduce_or(Op).
27688 return getPTest(DAG, VT, Op, Op, AArch64CC::ANY_ACTIVE);
27689 else
27690 return getPTest(DAG, VT, Pg, Op, AArch64CC::ANY_ACTIVE);
27691 case ISD::VECREDUCE_AND: {
27692 Op = DAG.getNode(ISD::XOR, DL, OpVT, Op, Pg);
27693 return getPTest(DAG, VT, Pg, Op, AArch64CC::NONE_ACTIVE);
27694 }
27695 case ISD::VECREDUCE_XOR: {
27696 SDValue ID =
27697 DAG.getTargetConstant(Intrinsic::aarch64_sve_cntp, DL, MVT::i64);
27698 if (OpVT == MVT::nxv1i1) {
27699 // Emulate a CNTP on .Q using .D and a different governing predicate.
27700 Pg = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Pg);
27701 Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, MVT::nxv2i1, Op);
27702 }
27703 SDValue Cntp =
27704 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64, ID, Pg, Op);
27705 return DAG.getAnyExtOrTrunc(Cntp, DL, VT);
27706 }
27707 }
27708
27709 return SDValue();
27710 }
27711
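// Lower a vector reduction to a predicated SVE reduction node, converting
// fixed-length sources to scalable containers and extracting lane 0 of the
// scalable result.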
SDValue AArch64TargetLowering::LowerReductionToSVE(unsigned Opcode,
                                                   SDValue ScalarOp,
                                                   SelectionDAG &DAG) const {
27715 SDLoc DL(ScalarOp);
27716 SDValue VecOp = ScalarOp.getOperand(0);
27717 EVT SrcVT = VecOp.getValueType();
27718
27719 if (useSVEForFixedLengthVectorVT(
27720 SrcVT,
27721 /*OverrideNEON=*/Subtarget->useSVEForFixedLengthVectors())) {
27722 EVT ContainerVT = getContainerForFixedLengthVector(DAG, SrcVT);
27723 VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
27724 }
27725
27726 // UADDV always returns an i64 result.
27727 EVT ResVT = (Opcode == AArch64ISD::UADDV_PRED) ? MVT::i64 :
27728 SrcVT.getVectorElementType();
27729 EVT RdxVT = SrcVT;
27730 if (SrcVT.isFixedLengthVector() || Opcode == AArch64ISD::UADDV_PRED)
27731 RdxVT = getPackedSVEVectorVT(ResVT);
27732
27733 SDValue Pg = getPredicateForVector(DAG, DL, SrcVT);
27734 SDValue Rdx = DAG.getNode(Opcode, DL, RdxVT, Pg, VecOp);
27735 SDValue Res = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT,
27736 Rdx, DAG.getConstant(0, DL, MVT::i64));
27737
27738 // The VEC_REDUCE nodes expect an element size result.
27739 if (ResVT != ScalarOp.getValueType())
27740 Res = DAG.getAnyExtOrTrunc(Res, DL, ScalarOp.getValueType());
27741
27742 return Res;
27743 }
27744
SDValue
AArch64TargetLowering::LowerFixedLengthVectorSelectToSVE(SDValue Op,
                                                         SelectionDAG &DAG) const {
27748 EVT VT = Op.getValueType();
27749 SDLoc DL(Op);
27750
27751 EVT InVT = Op.getOperand(1).getValueType();
27752 EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27753 SDValue Op1 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(1));
27754 SDValue Op2 = convertToScalableVector(DAG, ContainerVT, Op->getOperand(2));
27755
  // Convert the mask to a predicate (NOTE: We don't need to worry about
  // inactive lanes since VSELECT is safe when given undefined elements).
27758 EVT MaskVT = Op.getOperand(0).getValueType();
27759 EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskVT);
27760 auto Mask = convertToScalableVector(DAG, MaskContainerVT, Op.getOperand(0));
27761 Mask = DAG.getNode(ISD::TRUNCATE, DL,
27762 MaskContainerVT.changeVectorElementType(MVT::i1), Mask);
27763
27764 auto ScalableRes = DAG.getNode(ISD::VSELECT, DL, ContainerVT,
27765 Mask, Op1, Op2);
27766
27767 return convertFromScalableVector(DAG, VT, ScalableRes);
27768 }
27769
SDValue AArch64TargetLowering::LowerFixedLengthVectorSetccToSVE(
    SDValue Op, SelectionDAG &DAG) const {
27772 SDLoc DL(Op);
27773 EVT InVT = Op.getOperand(0).getValueType();
27774 EVT ContainerVT = getContainerForFixedLengthVector(DAG, InVT);
27775
27776 assert(InVT.isFixedLengthVector() && isTypeLegal(InVT) &&
27777 "Only expected to lower fixed length vector operation!");
27778 assert(Op.getValueType() == InVT.changeTypeToInteger() &&
27779 "Expected integer result of the same bit length as the inputs!");
27780
27781 auto Op1 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(0));
27782 auto Op2 = convertToScalableVector(DAG, ContainerVT, Op.getOperand(1));
27783 auto Pg = getPredicateForFixedLengthVector(DAG, DL, InVT);
27784
27785 EVT CmpVT = Pg.getValueType();
27786 auto Cmp = DAG.getNode(AArch64ISD::SETCC_MERGE_ZERO, DL, CmpVT,
27787 {Pg, Op1, Op2, Op.getOperand(2)});
27788
27789 EVT PromoteVT = ContainerVT.changeTypeToInteger();
27790 auto Promote = DAG.getBoolExtOrTrunc(Cmp, DL, PromoteVT, InVT);
27791 return convertFromScalableVector(DAG, Op.getValueType(), Promote);
27792 }
27793
SDValue
AArch64TargetLowering::LowerFixedLengthBitcastToSVE(SDValue Op,
                                                    SelectionDAG &DAG) const {
27797 SDLoc DL(Op);
27798 auto SrcOp = Op.getOperand(0);
27799 EVT VT = Op.getValueType();
27800 EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
27801 EVT ContainerSrcVT =
27802 getContainerForFixedLengthVector(DAG, SrcOp.getValueType());
27803
27804 SrcOp = convertToScalableVector(DAG, ContainerSrcVT, SrcOp);
27805 Op = DAG.getNode(ISD::BITCAST, DL, ContainerDstVT, SrcOp);
27806 return convertFromScalableVector(DAG, VT, Op);
27807 }
27808
SDValue AArch64TargetLowering::LowerFixedLengthConcatVectorsToSVE(
    SDValue Op, SelectionDAG &DAG) const {
27811 SDLoc DL(Op);
27812 unsigned NumOperands = Op->getNumOperands();
27813
27814 assert(NumOperands > 1 && isPowerOf2_32(NumOperands) &&
27815 "Unexpected number of operands in CONCAT_VECTORS");
27816
27817 auto SrcOp1 = Op.getOperand(0);
27818 auto SrcOp2 = Op.getOperand(1);
27819 EVT VT = Op.getValueType();
27820 EVT SrcVT = SrcOp1.getValueType();
27821
27822 if (NumOperands > 2) {
27823 SmallVector<SDValue, 4> Ops;
27824 EVT PairVT = SrcVT.getDoubleNumVectorElementsVT(*DAG.getContext());
27825 for (unsigned I = 0; I < NumOperands; I += 2)
27826 Ops.push_back(DAG.getNode(ISD::CONCAT_VECTORS, DL, PairVT,
27827 Op->getOperand(I), Op->getOperand(I + 1)));
27828
27829 return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, Ops);
27830 }
27831
27832 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27833
27834 SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
27835 SrcOp1 = convertToScalableVector(DAG, ContainerVT, SrcOp1);
27836 SrcOp2 = convertToScalableVector(DAG, ContainerVT, SrcOp2);
27837
27838 Op = DAG.getNode(AArch64ISD::SPLICE, DL, ContainerVT, Pg, SrcOp1, SrcOp2);
27839
27840 return convertFromScalableVector(DAG, VT, Op);
27841 }
27842
SDValue
AArch64TargetLowering::LowerFixedLengthFPExtendToSVE(SDValue Op,
                                                     SelectionDAG &DAG) const {
27846 EVT VT = Op.getValueType();
27847 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27848
27849 SDLoc DL(Op);
27850 SDValue Val = Op.getOperand(0);
27851 SDValue Pg = getPredicateForVector(DAG, DL, VT);
27852 EVT SrcVT = Val.getValueType();
27853 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
27854 EVT ExtendVT = ContainerVT.changeVectorElementType(
27855 SrcVT.getVectorElementType());
27856
27857 Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
27858 Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT.changeTypeToInteger(), Val);
27859
27860 Val = convertToScalableVector(DAG, ContainerVT.changeTypeToInteger(), Val);
27861 Val = getSVESafeBitCast(ExtendVT, Val, DAG);
27862 Val = DAG.getNode(AArch64ISD::FP_EXTEND_MERGE_PASSTHRU, DL, ContainerVT,
27863 Pg, Val, DAG.getUNDEF(ContainerVT));
27864
27865 return convertFromScalableVector(DAG, VT, Val);
27866 }
27867
SDValue
AArch64TargetLowering::LowerFixedLengthFPRoundToSVE(SDValue Op,
                                                    SelectionDAG &DAG) const {
27871 EVT VT = Op.getValueType();
27872 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27873
27874 SDLoc DL(Op);
27875 SDValue Val = Op.getOperand(0);
27876 EVT SrcVT = Val.getValueType();
27877 EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
27878 EVT RoundVT = ContainerSrcVT.changeVectorElementType(
27879 VT.getVectorElementType());
27880 SDValue Pg = getPredicateForVector(DAG, DL, RoundVT);
27881
27882 Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
27883 Val = DAG.getNode(AArch64ISD::FP_ROUND_MERGE_PASSTHRU, DL, RoundVT, Pg, Val,
27884 Op.getOperand(1), DAG.getUNDEF(RoundVT));
27885 Val = getSVESafeBitCast(ContainerSrcVT.changeTypeToInteger(), Val, DAG);
27886 Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
27887
27888 Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
27889 return DAG.getNode(ISD::BITCAST, DL, VT, Val);
27890 }
27891
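// Lower fixed-length [SU]INT_TO_FP. When the result is at least as wide as the
// source, the input is extended and converted within the destination
// container; otherwise the conversion is done at the source width and the
// result truncated.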
SDValue
AArch64TargetLowering::LowerFixedLengthIntToFPToSVE(SDValue Op,
                                                    SelectionDAG &DAG) const {
27895 EVT VT = Op.getValueType();
27896 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
27897
27898 bool IsSigned = Op.getOpcode() == ISD::SINT_TO_FP;
27899 unsigned Opcode = IsSigned ? AArch64ISD::SINT_TO_FP_MERGE_PASSTHRU
27900 : AArch64ISD::UINT_TO_FP_MERGE_PASSTHRU;
27901
27902 SDLoc DL(Op);
27903 SDValue Val = Op.getOperand(0);
27904 EVT SrcVT = Val.getValueType();
27905 EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
27906 EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
27907
27908 if (VT.bitsGE(SrcVT)) {
27909 SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
27910
27911 Val = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
27912 VT.changeTypeToInteger(), Val);
27913
27914 // Safe to use a larger than specified operand because by promoting the
27915 // value nothing has changed from an arithmetic point of view.
27916 Val =
27917 convertToScalableVector(DAG, ContainerDstVT.changeTypeToInteger(), Val);
27918 Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
27919 DAG.getUNDEF(ContainerDstVT));
27920 return convertFromScalableVector(DAG, VT, Val);
27921 } else {
27922 EVT CvtVT = ContainerSrcVT.changeVectorElementType(
27923 ContainerDstVT.getVectorElementType());
27924 SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
27925
27926 Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
27927 Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
27928 Val = getSVESafeBitCast(ContainerSrcVT, Val, DAG);
27929 Val = convertFromScalableVector(DAG, SrcVT, Val);
27930
27931 Val = DAG.getNode(ISD::TRUNCATE, DL, VT.changeTypeToInteger(), Val);
27932 return DAG.getNode(ISD::BITCAST, DL, VT, Val);
27933 }
27934 }
27935
SDValue
AArch64TargetLowering::LowerVECTOR_DEINTERLEAVE(SDValue Op,
                                                SelectionDAG &DAG) const {
27939 SDLoc DL(Op);
27940 EVT OpVT = Op.getValueType();
27941 assert(OpVT.isScalableVector() &&
27942 "Expected scalable vector in LowerVECTOR_DEINTERLEAVE.");
27943 SDValue Even = DAG.getNode(AArch64ISD::UZP1, DL, OpVT, Op.getOperand(0),
27944 Op.getOperand(1));
27945 SDValue Odd = DAG.getNode(AArch64ISD::UZP2, DL, OpVT, Op.getOperand(0),
27946 Op.getOperand(1));
27947 return DAG.getMergeValues({Even, Odd}, DL);
27948 }
27949
SDValue AArch64TargetLowering::LowerVECTOR_INTERLEAVE(SDValue Op,
                                                      SelectionDAG &DAG) const {
27952 SDLoc DL(Op);
27953 EVT OpVT = Op.getValueType();
27954 assert(OpVT.isScalableVector() &&
27955 "Expected scalable vector in LowerVECTOR_INTERLEAVE.");
27956
27957 SDValue Lo = DAG.getNode(AArch64ISD::ZIP1, DL, OpVT, Op.getOperand(0),
27958 Op.getOperand(1));
27959 SDValue Hi = DAG.getNode(AArch64ISD::ZIP2, DL, OpVT, Op.getOperand(0),
27960 Op.getOperand(1));
27961 return DAG.getMergeValues({Lo, Hi}, DL);
27962 }
27963
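// Lower a masked histogram update: gather the current bucket values, compute
// per-lane match counts with HISTCNT, scale them by the increment and scatter
// the updated buckets back to memory.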
SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
                                                     SelectionDAG &DAG) const {
27966 // FIXME: Maybe share some code with LowerMGather/Scatter?
27967 MaskedHistogramSDNode *HG = cast<MaskedHistogramSDNode>(Op);
27968 SDLoc DL(HG);
27969 SDValue Chain = HG->getChain();
27970 SDValue Inc = HG->getInc();
27971 SDValue Mask = HG->getMask();
27972 SDValue Ptr = HG->getBasePtr();
27973 SDValue Index = HG->getIndex();
27974 SDValue Scale = HG->getScale();
27975 SDValue IntID = HG->getIntID();
27976
27977 // The Intrinsic ID determines the type of update operation.
27978 [[maybe_unused]] ConstantSDNode *CID = cast<ConstantSDNode>(IntID.getNode());
27979 // Right now, we only support 'add' as an update.
27980 assert(CID->getZExtValue() == Intrinsic::experimental_vector_histogram_add &&
27981 "Unexpected histogram update operation");
27982
27983 EVT IncVT = Inc.getValueType();
27984 EVT IndexVT = Index.getValueType();
27985 EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
27986 IndexVT.getVectorElementCount());
27987 SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
27988 SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
27989 SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
27990 SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};
27991
27992 MachineMemOperand *MMO = HG->getMemOperand();
27993 // Create an MMO for the gather, without load|store flags.
27994 MachineMemOperand *GMMO = DAG.getMachineFunction().getMachineMemOperand(
27995 MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
27996 MMO->getAlign(), MMO->getAAInfo());
27997 ISD::MemIndexType IndexType = HG->getIndexType();
27998 SDValue Gather =
27999 DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
28000 GMMO, IndexType, ISD::NON_EXTLOAD);
28001
28002 SDValue GChain = Gather.getValue(1);
28003
28004 // Perform the histcnt, multiply by inc, add to bucket data.
28005 SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
28006 SDValue HistCnt =
28007 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
28008 SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
28009 SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
28010
28011 // Create an MMO for the scatter, without load|store flags.
28012 MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
28013 MMO->getPointerInfo(), MachineMemOperand::MOStore, MMO->getSize(),
28014 MMO->getAlign(), MMO->getAAInfo());
28015
28016 SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
28017 SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
28018 ScatterOps, SMMO, IndexType, false);
28019 return Scatter;
28020 }
28021
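// Lower fixed-length FP_TO_[SU]INT. Wider results are converted after the
// input has been extended into the destination container; narrower results
// are converted at the source width and then truncated.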
SDValue
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
                                                    SelectionDAG &DAG) const {
28025 EVT VT = Op.getValueType();
28026 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
28027
28028 bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT;
28029 unsigned Opcode = IsSigned ? AArch64ISD::FCVTZS_MERGE_PASSTHRU
28030 : AArch64ISD::FCVTZU_MERGE_PASSTHRU;
28031
28032 SDLoc DL(Op);
28033 SDValue Val = Op.getOperand(0);
28034 EVT SrcVT = Val.getValueType();
28035 EVT ContainerDstVT = getContainerForFixedLengthVector(DAG, VT);
28036 EVT ContainerSrcVT = getContainerForFixedLengthVector(DAG, SrcVT);
28037
28038 if (VT.bitsGT(SrcVT)) {
28039 EVT CvtVT = ContainerDstVT.changeVectorElementType(
28040 ContainerSrcVT.getVectorElementType());
28041 SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, VT);
28042
28043 Val = DAG.getNode(ISD::BITCAST, DL, SrcVT.changeTypeToInteger(), Val);
28044 Val = DAG.getNode(ISD::ANY_EXTEND, DL, VT, Val);
28045
28046 Val = convertToScalableVector(DAG, ContainerDstVT, Val);
28047 Val = getSVESafeBitCast(CvtVT, Val, DAG);
28048 Val = DAG.getNode(Opcode, DL, ContainerDstVT, Pg, Val,
28049 DAG.getUNDEF(ContainerDstVT));
28050 return convertFromScalableVector(DAG, VT, Val);
28051 } else {
28052 EVT CvtVT = ContainerSrcVT.changeTypeToInteger();
28053 SDValue Pg = getPredicateForFixedLengthVector(DAG, DL, SrcVT);
28054
28055 // Safe to use a larger than specified result since an fp_to_int where the
28056 // result doesn't fit into the destination is undefined.
28057 Val = convertToScalableVector(DAG, ContainerSrcVT, Val);
28058 Val = DAG.getNode(Opcode, DL, CvtVT, Pg, Val, DAG.getUNDEF(CvtVT));
28059 Val = convertFromScalableVector(DAG, SrcVT.changeTypeToInteger(), Val);
28060
28061 return DAG.getNode(ISD::TRUNCATE, DL, VT, Val);
28062 }
28063 }
28064
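// Lower a fixed-length shuffle to an SVE TBL (or SVE2 TBL2) by materialising
// the shuffle mask as a vector of indices, adjusting the indices for the
// runtime vector length when the exact SVE register size is unknown.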
static SDValue GenerateFixedLengthSVETBL(SDValue Op, SDValue Op1, SDValue Op2,
                                         ArrayRef<int> ShuffleMask, EVT VT,
                                         EVT ContainerVT, SelectionDAG &DAG) {
28068 auto &Subtarget = DAG.getSubtarget<AArch64Subtarget>();
28069 SDLoc DL(Op);
28070 unsigned MinSVESize = Subtarget.getMinSVEVectorSizeInBits();
28071 unsigned MaxSVESize = Subtarget.getMaxSVEVectorSizeInBits();
28072 bool IsSingleOp =
28073 ShuffleVectorInst::isSingleSourceMask(ShuffleMask, ShuffleMask.size());
28074
28075 if (!Subtarget.isNeonAvailable() && !MinSVESize)
28076 MinSVESize = 128;
28077
  // Ignore the two-operand case if there is no SVE2 or the index values
  // cannot all be represented.
28080 if (!IsSingleOp && !Subtarget.hasSVE2())
28081 return SDValue();
28082
28083 EVT VTOp1 = Op.getOperand(0).getValueType();
28084 unsigned BitsPerElt = VTOp1.getVectorElementType().getSizeInBits();
28085 unsigned IndexLen = MinSVESize / BitsPerElt;
28086 unsigned ElementsPerVectorReg = VTOp1.getVectorNumElements();
28087 uint64_t MaxOffset = APInt(BitsPerElt, -1, false).getZExtValue();
28088 EVT MaskEltType = VTOp1.getVectorElementType().changeTypeToInteger();
28089 EVT MaskType = EVT::getVectorVT(*DAG.getContext(), MaskEltType, IndexLen);
28090 bool MinMaxEqual = (MinSVESize == MaxSVESize);
28091 assert(ElementsPerVectorReg <= IndexLen && ShuffleMask.size() <= IndexLen &&
28092 "Incorrectly legalised shuffle operation");
28093
28094 SmallVector<SDValue, 8> TBLMask;
28095 // If MinSVESize is not equal to MaxSVESize then we need to know which
28096 // TBL mask element needs adjustment.
28097 SmallVector<SDValue, 8> AddRuntimeVLMask;
28098
  // Bail out for 8-bit element types, because with a 2048-bit SVE register
  // size 8 bits is only sufficient to index into the first source vector.
28101 if (!IsSingleOp && !MinMaxEqual && BitsPerElt == 8)
28102 return SDValue();
28103
28104 for (int Index : ShuffleMask) {
    // Handle poison index values.
28106 if (Index < 0)
28107 Index = 0;
    // If the mask refers to elements in the second operand, then we have to
    // offset the index by the number of elements in a vector. If this number
    // is not known at compile time, we need to maintain a mask with 'VL'
    // values to add at runtime.
28112 if ((unsigned)Index >= ElementsPerVectorReg) {
28113 if (MinMaxEqual) {
28114 Index += IndexLen - ElementsPerVectorReg;
28115 } else {
28116 Index = Index - ElementsPerVectorReg;
28117 AddRuntimeVLMask.push_back(DAG.getConstant(1, DL, MVT::i64));
28118 }
28119 } else if (!MinMaxEqual)
28120 AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
    // For 8-bit elements and 1024-bit SVE registers, where MaxOffset equals
    // 255, this might point to the last element in the second operand of the
    // shufflevector, so we reject this transform.
28124 if ((unsigned)Index >= MaxOffset)
28125 return SDValue();
28126 TBLMask.push_back(DAG.getConstant(Index, DL, MVT::i64));
28127 }
28128
  // Choosing an out-of-range index leads to the lane being zeroed, whereas
  // choosing a zero value would duplicate the first lane for those elements.
  // For i8 elements an out-of-range index could still be a valid index with a
  // 2048-bit vector register size.
28133 for (unsigned i = 0; i < IndexLen - ElementsPerVectorReg; ++i) {
28134 TBLMask.push_back(DAG.getConstant((int)MaxOffset, DL, MVT::i64));
28135 if (!MinMaxEqual)
28136 AddRuntimeVLMask.push_back(DAG.getConstant(0, DL, MVT::i64));
28137 }
28138
28139 EVT MaskContainerVT = getContainerForFixedLengthVector(DAG, MaskType);
28140 SDValue VecMask =
28141 DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
28142 SDValue SVEMask = convertToScalableVector(DAG, MaskContainerVT, VecMask);
28143
28144 SDValue Shuffle;
28145 if (IsSingleOp)
28146 Shuffle =
28147 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
28148 DAG.getConstant(Intrinsic::aarch64_sve_tbl, DL, MVT::i32),
28149 Op1, SVEMask);
28150 else if (Subtarget.hasSVE2()) {
28151 if (!MinMaxEqual) {
28152 unsigned MinNumElts = AArch64::SVEBitsPerBlock / BitsPerElt;
28153 SDValue VScale = (BitsPerElt == 64)
28154 ? DAG.getVScale(DL, MVT::i64, APInt(64, MinNumElts))
28155 : DAG.getVScale(DL, MVT::i32, APInt(32, MinNumElts));
28156 SDValue VecMask =
28157 DAG.getBuildVector(MaskType, DL, ArrayRef(TBLMask.data(), IndexLen));
28158 SDValue MulByMask = DAG.getNode(
28159 ISD::MUL, DL, MaskType,
28160 DAG.getNode(ISD::SPLAT_VECTOR, DL, MaskType, VScale),
28161 DAG.getBuildVector(MaskType, DL,
28162 ArrayRef(AddRuntimeVLMask.data(), IndexLen)));
28163 SDValue UpdatedVecMask =
28164 DAG.getNode(ISD::ADD, DL, MaskType, VecMask, MulByMask);
28165 SVEMask = convertToScalableVector(
28166 DAG, getContainerForFixedLengthVector(DAG, MaskType), UpdatedVecMask);
28167 }
28168 Shuffle =
28169 DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ContainerVT,
28170 DAG.getConstant(Intrinsic::aarch64_sve_tbl2, DL, MVT::i32),
28171 Op1, Op2, SVEMask);
28172 }
28173 Shuffle = convertFromScalableVector(DAG, VT, Shuffle);
28174 return DAG.getNode(ISD::BITCAST, DL, Op.getValueType(), Shuffle);
28175 }
28176
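// Lower a fixed-length VECTOR_SHUFFLE to SVE, preferring splats, EXT/INSR,
// REV and ZIP/UZP/TRN patterns before falling back to a TBL-based shuffle.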
SDValue AArch64TargetLowering::LowerFixedLengthVECTOR_SHUFFLEToSVE(
    SDValue Op, SelectionDAG &DAG) const {
28179 EVT VT = Op.getValueType();
28180 assert(VT.isFixedLengthVector() && "Expected fixed length vector type!");
28181
28182 auto *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
28183 auto ShuffleMask = SVN->getMask();
28184
28185 SDLoc DL(Op);
28186 SDValue Op1 = Op.getOperand(0);
28187 SDValue Op2 = Op.getOperand(1);
28188
28189 EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT);
28190 Op1 = convertToScalableVector(DAG, ContainerVT, Op1);
28191 Op2 = convertToScalableVector(DAG, ContainerVT, Op2);
28192
28193 auto MinLegalExtractEltScalarTy = [](EVT ScalarTy) -> EVT {
28194 if (ScalarTy == MVT::i8 || ScalarTy == MVT::i16)
28195 return MVT::i32;
28196 return ScalarTy;
28197 };
28198
28199 if (SVN->isSplat()) {
28200 unsigned Lane = std::max(0, SVN->getSplatIndex());
28201 EVT ScalarTy = MinLegalExtractEltScalarTy(VT.getVectorElementType());
28202 SDValue SplatEl = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1,
28203 DAG.getConstant(Lane, DL, MVT::i64));
28204 Op = DAG.getNode(ISD::SPLAT_VECTOR, DL, ContainerVT, SplatEl);
28205 return convertFromScalableVector(DAG, VT, Op);
28206 }
28207
28208 bool ReverseEXT = false;
28209 unsigned Imm;
28210 if (isEXTMask(ShuffleMask, VT, ReverseEXT, Imm) &&
28211 Imm == VT.getVectorNumElements() - 1) {
28212 if (ReverseEXT)
28213 std::swap(Op1, Op2);
28214 EVT ScalarTy = MinLegalExtractEltScalarTy(VT.getVectorElementType());
28215 SDValue Scalar = DAG.getNode(
28216 ISD::EXTRACT_VECTOR_ELT, DL, ScalarTy, Op1,
28217 DAG.getConstant(VT.getVectorNumElements() - 1, DL, MVT::i64));
28218 Op = DAG.getNode(AArch64ISD::INSR, DL, ContainerVT, Op2, Scalar);
28219 return convertFromScalableVector(DAG, VT, Op);
28220 }
28221
28222 unsigned EltSize = VT.getScalarSizeInBits();
28223 for (unsigned LaneSize : {64U, 32U, 16U}) {
28224 if (isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), LaneSize)) {
28225 EVT NewVT =
28226 getPackedSVEVectorVT(EVT::getIntegerVT(*DAG.getContext(), LaneSize));
28227 unsigned RevOp;
28228 if (EltSize == 8)
28229 RevOp = AArch64ISD::BSWAP_MERGE_PASSTHRU;
28230 else if (EltSize == 16)
28231 RevOp = AArch64ISD::REVH_MERGE_PASSTHRU;
28232 else
28233 RevOp = AArch64ISD::REVW_MERGE_PASSTHRU;
28234
28235 Op = DAG.getNode(ISD::BITCAST, DL, NewVT, Op1);
28236 Op = LowerToPredicatedOp(Op, DAG, RevOp);
28237 Op = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Op);
28238 return convertFromScalableVector(DAG, VT, Op);
28239 }
28240 }
28241
28242 if (Subtarget->hasSVE2p1() && EltSize == 64 &&
28243 isREVMask(ShuffleMask, EltSize, VT.getVectorNumElements(), 128)) {
28244 if (!VT.isFloatingPoint())
28245 return LowerToPredicatedOp(Op, DAG, AArch64ISD::REVD_MERGE_PASSTHRU);
28246
28247 EVT NewVT = getPackedSVEVectorVT(EVT::getIntegerVT(*DAG.getContext(), 64));
28248 Op = DAG.getNode(ISD::BITCAST, DL, NewVT, Op1);
28249 Op = LowerToPredicatedOp(Op, DAG, AArch64ISD::REVD_MERGE_PASSTHRU);
28250 Op = DAG.getNode(ISD::BITCAST, DL, ContainerVT, Op);
28251 return convertFromScalableVector(DAG, VT, Op);
28252 }
28253
28254 unsigned WhichResult;
28255 if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
28256 WhichResult == 0)
28257 return convertFromScalableVector(
28258 DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op2));
28259
28260 if (isTRNMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
28261 unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
28262 return convertFromScalableVector(
28263 DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
28264 }
28265
28266 if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult == 0)
28267 return convertFromScalableVector(
28268 DAG, VT, DAG.getNode(AArch64ISD::ZIP1, DL, ContainerVT, Op1, Op1));
28269
28270 if (isTRN_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
28271 unsigned Opc = (WhichResult == 0) ? AArch64ISD::TRN1 : AArch64ISD::TRN2;
28272 return convertFromScalableVector(
28273 DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
28274 }
28275
28276 // Functions like isZIPMask return true when a ISD::VECTOR_SHUFFLE's mask
28277 // represents the same logical operation as performed by a ZIP instruction. In
28278 // isolation these functions do not mean the ISD::VECTOR_SHUFFLE is exactly
28279 // equivalent to an AArch64 instruction. There's the extra component of
28280 // ISD::VECTOR_SHUFFLE's value type to consider. Prior to SVE these functions
28281 // only operated on 64/128bit vector types that have a direct mapping to a
28282 // target register and so an exact mapping is implied.
28283 // However, when using SVE for fixed length vectors, most legal vector types
28284 // are actually sub-vectors of a larger SVE register. When mapping
28285 // ISD::VECTOR_SHUFFLE to an SVE instruction care must be taken to consider
28286 // how the mask's indices translate. Specifically, when the mapping requires
28287 // an exact meaning for a specific vector index (e.g. Index X is the last
28288 // vector element in the register) then such mappings are often only safe when
  // the exact SVE register size is known. The main exception to this is when
28290 // indices are logically relative to the first element of either
28291 // ISD::VECTOR_SHUFFLE operand because these relative indices don't change
28292 // when converting from fixed-length to scalable vector types (i.e. the start
28293 // of a fixed length vector is always the start of a scalable vector).
28294 unsigned MinSVESize = Subtarget->getMinSVEVectorSizeInBits();
28295 unsigned MaxSVESize = Subtarget->getMaxSVEVectorSizeInBits();
28296 if (MinSVESize == MaxSVESize && MaxSVESize == VT.getSizeInBits()) {
28297 if (ShuffleVectorInst::isReverseMask(ShuffleMask, ShuffleMask.size()) &&
28298 Op2.isUndef()) {
28299 Op = DAG.getNode(ISD::VECTOR_REVERSE, DL, ContainerVT, Op1);
28300 return convertFromScalableVector(DAG, VT, Op);
28301 }
28302
28303 if (isZIPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult) &&
28304 WhichResult != 0)
28305 return convertFromScalableVector(
28306 DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op2));
28307
28308 if (isUZPMask(ShuffleMask, VT.getVectorNumElements(), WhichResult)) {
28309 unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
28310 return convertFromScalableVector(
28311 DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op2));
28312 }
28313
28314 if (isZIP_v_undef_Mask(ShuffleMask, VT, WhichResult) && WhichResult != 0)
28315 return convertFromScalableVector(
28316 DAG, VT, DAG.getNode(AArch64ISD::ZIP2, DL, ContainerVT, Op1, Op1));
28317
28318 if (isUZP_v_undef_Mask(ShuffleMask, VT, WhichResult)) {
28319 unsigned Opc = (WhichResult == 0) ? AArch64ISD::UZP1 : AArch64ISD::UZP2;
28320 return convertFromScalableVector(
28321 DAG, VT, DAG.getNode(Opc, DL, ContainerVT, Op1, Op1));
28322 }
28323 }
28324
  // Avoid producing a TBL instruction if we don't know the minimum SVE
  // register size, unless NEON is not available and we can assume the minimum
  // SVE register size is 128 bits.
28328 if (MinSVESize || !Subtarget->isNeonAvailable())
28329 return GenerateFixedLengthSVETBL(Op, Op1, Op2, ShuffleMask, VT, ContainerVT,
28330 DAG);
28331
28332 return SDValue();
28333 }
28334
SDValue AArch64TargetLowering::getSVESafeBitCast(EVT VT, SDValue Op,
                                                 SelectionDAG &DAG) const {
28337 SDLoc DL(Op);
28338 EVT InVT = Op.getValueType();
28339
28340 assert(VT.isScalableVector() && isTypeLegal(VT) &&
28341 InVT.isScalableVector() && isTypeLegal(InVT) &&
28342 "Only expect to cast between legal scalable vector types!");
28343 assert(VT.getVectorElementType() != MVT::i1 &&
28344 InVT.getVectorElementType() != MVT::i1 &&
28345 "For predicate bitcasts, use getSVEPredicateBitCast");
28346
28347 if (InVT == VT)
28348 return Op;
28349
28350 EVT PackedVT = getPackedSVEVectorVT(VT.getVectorElementType());
28351 EVT PackedInVT = getPackedSVEVectorVT(InVT.getVectorElementType());
28352
28353 // Safe bitcasting between unpacked vector types of different element counts
28354 // is currently unsupported because the following is missing the necessary
28355 // work to ensure the result's elements live where they're supposed to within
28356 // an SVE register.
28357 // 01234567
28358 // e.g. nxv2i32 = XX??XX??
28359 // nxv4f16 = X?X?X?X?
28360 assert((VT.getVectorElementCount() == InVT.getVectorElementCount() ||
28361 VT == PackedVT || InVT == PackedInVT) &&
28362 "Unexpected bitcast!");
28363
28364 // Pack input if required.
28365 if (InVT != PackedInVT)
28366 Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, PackedInVT, Op);
28367
28368 Op = DAG.getNode(ISD::BITCAST, DL, PackedVT, Op);
28369
28370 // Unpack result if required.
28371 if (VT != PackedVT)
28372 Op = DAG.getNode(AArch64ISD::REINTERPRET_CAST, DL, VT, Op);
28373
28374 return Op;
28375 }
28376
bool AArch64TargetLowering::isAllActivePredicate(SelectionDAG &DAG,
                                                 SDValue N) const {
28379 return ::isAllActivePredicate(DAG, N);
28380 }
28381
EVT AArch64TargetLowering::getPromotedVTForPredicate(EVT VT) const {
28383 return ::getPromotedVTForPredicate(VT);
28384 }
28385
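// Target-specific demanded-bits simplification: remove VSHL/VLSHR pairs whose
// cleared bits are never demanded, fold BICi when the bits it clears are
// already known to be zero, and bound the known bits of the SVE count
// intrinsics.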
bool AArch64TargetLowering::SimplifyDemandedBitsForTargetNode(
    SDValue Op, const APInt &OriginalDemandedBits,
    const APInt &OriginalDemandedElts, KnownBits &Known, TargetLoweringOpt &TLO,
    unsigned Depth) const {
28390
28391 unsigned Opc = Op.getOpcode();
28392 switch (Opc) {
28393 case AArch64ISD::VSHL: {
28394 // Match (VSHL (VLSHR Val X) X)
28395 SDValue ShiftL = Op;
28396 SDValue ShiftR = Op->getOperand(0);
28397 if (ShiftR->getOpcode() != AArch64ISD::VLSHR)
28398 return false;
28399
28400 if (!ShiftL.hasOneUse() || !ShiftR.hasOneUse())
28401 return false;
28402
28403 unsigned ShiftLBits = ShiftL->getConstantOperandVal(1);
28404 unsigned ShiftRBits = ShiftR->getConstantOperandVal(1);
28405
28406 // Other cases can be handled as well, but this is not
28407 // implemented.
28408 if (ShiftRBits != ShiftLBits)
28409 return false;
28410
28411 unsigned ScalarSize = Op.getScalarValueSizeInBits();
28412 assert(ScalarSize > ShiftLBits && "Invalid shift imm");
28413
28414 APInt ZeroBits = APInt::getLowBitsSet(ScalarSize, ShiftLBits);
28415 APInt UnusedBits = ~OriginalDemandedBits;
28416
28417 if ((ZeroBits & UnusedBits) != ZeroBits)
28418 return false;
28419
28420 // All bits that are zeroed by (VSHL (VLSHR Val X) X) are not
28421 // used - simplify to just Val.
28422 return TLO.CombineTo(Op, ShiftR->getOperand(0));
28423 }
28424 case AArch64ISD::BICi: {
28425 // Fold BICi if all destination bits already known to be zeroed
28426 SDValue Op0 = Op.getOperand(0);
28427 KnownBits KnownOp0 =
28428 TLO.DAG.computeKnownBits(Op0, OriginalDemandedElts, Depth + 1);
28429 // Op0 &= ~(ConstantOperandVal(1) << ConstantOperandVal(2))
28430 uint64_t BitsToClear = Op->getConstantOperandVal(1)
28431 << Op->getConstantOperandVal(2);
28432 APInt AlreadyZeroedBitsToClear = BitsToClear & KnownOp0.Zero;
28433 if (APInt(Known.getBitWidth(), BitsToClear)
28434 .isSubsetOf(AlreadyZeroedBitsToClear))
28435 return TLO.CombineTo(Op, Op0);
28436
28437 Known = KnownOp0 &
28438 KnownBits::makeConstant(APInt(Known.getBitWidth(), ~BitsToClear));
28439
28440 return false;
28441 }
28442 case ISD::INTRINSIC_WO_CHAIN: {
28443 if (auto ElementSize = IsSVECntIntrinsic(Op)) {
28444 unsigned MaxSVEVectorSizeInBits = Subtarget->getMaxSVEVectorSizeInBits();
28445 if (!MaxSVEVectorSizeInBits)
28446 MaxSVEVectorSizeInBits = AArch64::SVEMaxBitsPerVector;
28447 unsigned MaxElements = MaxSVEVectorSizeInBits / *ElementSize;
28448 // The SVE count intrinsics don't support the multiplier immediate so we
28449 // don't have to account for that here. The value returned may be slightly
28450 // over the true required bits, as this is based on the "ALL" pattern. The
28451 // other patterns are also exposed by these intrinsics, but they all
28452 // return a value that's strictly less than "ALL".
28453 unsigned RequiredBits = llvm::bit_width(MaxElements);
28454 unsigned BitWidth = Known.Zero.getBitWidth();
28455 if (RequiredBits < BitWidth)
28456 Known.Zero.setHighBits(BitWidth - RequiredBits);
28457 return false;
28458 }
28459 }
28460 }
28461
28462 return TargetLowering::SimplifyDemandedBitsForTargetNode(
28463 Op, OriginalDemandedBits, OriginalDemandedElts, Known, TLO, Depth);
28464 }
28465
bool AArch64TargetLowering::isTargetCanonicalConstantNode(SDValue Op) const {
28467 return Op.getOpcode() == AArch64ISD::DUP ||
28468 Op.getOpcode() == AArch64ISD::MOVI ||
28469 (Op.getOpcode() == ISD::EXTRACT_SUBVECTOR &&
28470 Op.getOperand(0).getOpcode() == AArch64ISD::DUP) ||
28471 TargetLowering::isTargetCanonicalConstantNode(Op);
28472 }
28473
bool AArch64TargetLowering::isComplexDeinterleavingSupported() const {
28475 return Subtarget->hasSVE() || Subtarget->hasSVE2() ||
28476 Subtarget->hasComplxNum();
28477 }
28478
bool AArch64TargetLowering::isComplexDeinterleavingOperationSupported(
    ComplexDeinterleavingOperation Operation, Type *Ty) const {
28481 auto *VTy = dyn_cast<VectorType>(Ty);
28482 if (!VTy)
28483 return false;
28484
28485 // If the vector is scalable, SVE is enabled, implying support for complex
28486 // numbers. Otherwise, we need to ensure complex number support is available
28487 if (!VTy->isScalableTy() && !Subtarget->hasComplxNum())
28488 return false;
28489
28490 auto *ScalarTy = VTy->getScalarType();
28491 unsigned NumElements = VTy->getElementCount().getKnownMinValue();
28492
  // We can only process vectors that have a bit size of 128 or higher (with an
  // additional 64 bits for Neon). Additionally, these vectors must have a
  // power-of-2 size, as we later split them into the smallest supported size
  // and merge them back together after applying the complex operation.
28497 unsigned VTyWidth = VTy->getScalarSizeInBits() * NumElements;
28498 if ((VTyWidth < 128 && (VTy->isScalableTy() || VTyWidth != 64)) ||
28499 !llvm::isPowerOf2_32(VTyWidth))
28500 return false;
28501
28502 if (ScalarTy->isIntegerTy() && Subtarget->hasSVE2() && VTy->isScalableTy()) {
28503 unsigned ScalarWidth = ScalarTy->getScalarSizeInBits();
28504 return 8 <= ScalarWidth && ScalarWidth <= 64;
28505 }
28506
28507 return (ScalarTy->isHalfTy() && Subtarget->hasFullFP16()) ||
28508 ScalarTy->isFloatTy() || ScalarTy->isDoubleTy();
28509 }
28510
Value *AArch64TargetLowering::createComplexDeinterleavingIR(
    IRBuilderBase &B, ComplexDeinterleavingOperation OperationType,
    ComplexDeinterleavingRotation Rotation, Value *InputA, Value *InputB,
    Value *Accumulator) const {
28515 VectorType *Ty = cast<VectorType>(InputA->getType());
28516 bool IsScalable = Ty->isScalableTy();
28517 bool IsInt = Ty->getElementType()->isIntegerTy();
28518
28519 unsigned TyWidth =
28520 Ty->getScalarSizeInBits() * Ty->getElementCount().getKnownMinValue();
28521
28522 assert(((TyWidth >= 128 && llvm::isPowerOf2_32(TyWidth)) || TyWidth == 64) &&
28523 "Vector type must be either 64 or a power of 2 that is at least 128");
28524
28525 if (TyWidth > 128) {
28526 int Stride = Ty->getElementCount().getKnownMinValue() / 2;
28527 auto *HalfTy = VectorType::getHalfElementsVectorType(Ty);
28528 auto *LowerSplitA = B.CreateExtractVector(HalfTy, InputA, B.getInt64(0));
28529 auto *LowerSplitB = B.CreateExtractVector(HalfTy, InputB, B.getInt64(0));
28530 auto *UpperSplitA =
28531 B.CreateExtractVector(HalfTy, InputA, B.getInt64(Stride));
28532 auto *UpperSplitB =
28533 B.CreateExtractVector(HalfTy, InputB, B.getInt64(Stride));
28534 Value *LowerSplitAcc = nullptr;
28535 Value *UpperSplitAcc = nullptr;
28536 if (Accumulator) {
28537 LowerSplitAcc = B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(0));
28538 UpperSplitAcc =
28539 B.CreateExtractVector(HalfTy, Accumulator, B.getInt64(Stride));
28540 }
28541 auto *LowerSplitInt = createComplexDeinterleavingIR(
28542 B, OperationType, Rotation, LowerSplitA, LowerSplitB, LowerSplitAcc);
28543 auto *UpperSplitInt = createComplexDeinterleavingIR(
28544 B, OperationType, Rotation, UpperSplitA, UpperSplitB, UpperSplitAcc);
28545
28546 auto *Result = B.CreateInsertVector(Ty, PoisonValue::get(Ty), LowerSplitInt,
28547 B.getInt64(0));
28548 return B.CreateInsertVector(Ty, Result, UpperSplitInt, B.getInt64(Stride));
28549 }
28550
28551 if (OperationType == ComplexDeinterleavingOperation::CMulPartial) {
28552 if (Accumulator == nullptr)
28553 Accumulator = Constant::getNullValue(Ty);
28554
28555 if (IsScalable) {
28556 if (IsInt)
28557 return B.CreateIntrinsic(
28558 Intrinsic::aarch64_sve_cmla_x, Ty,
28559 {Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
28560
28561 auto *Mask = B.getAllOnesMask(Ty->getElementCount());
28562 return B.CreateIntrinsic(
28563 Intrinsic::aarch64_sve_fcmla, Ty,
28564 {Mask, Accumulator, InputA, InputB, B.getInt32((int)Rotation * 90)});
28565 }
28566
28567 Intrinsic::ID IdMap[4] = {Intrinsic::aarch64_neon_vcmla_rot0,
28568 Intrinsic::aarch64_neon_vcmla_rot90,
28569 Intrinsic::aarch64_neon_vcmla_rot180,
28570 Intrinsic::aarch64_neon_vcmla_rot270};
28571
28572
28573 return B.CreateIntrinsic(IdMap[(int)Rotation], Ty,
28574 {Accumulator, InputA, InputB});
28575 }
28576
28577 if (OperationType == ComplexDeinterleavingOperation::CAdd) {
28578 if (IsScalable) {
28579 if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
28580 Rotation == ComplexDeinterleavingRotation::Rotation_270) {
28581 if (IsInt)
28582 return B.CreateIntrinsic(
28583 Intrinsic::aarch64_sve_cadd_x, Ty,
28584 {InputA, InputB, B.getInt32((int)Rotation * 90)});
28585
28586 auto *Mask = B.getAllOnesMask(Ty->getElementCount());
28587 return B.CreateIntrinsic(
28588 Intrinsic::aarch64_sve_fcadd, Ty,
28589 {Mask, InputA, InputB, B.getInt32((int)Rotation * 90)});
28590 }
28591 return nullptr;
28592 }
28593
28594 Intrinsic::ID IntId = Intrinsic::not_intrinsic;
28595 if (Rotation == ComplexDeinterleavingRotation::Rotation_90)
28596 IntId = Intrinsic::aarch64_neon_vcadd_rot90;
28597 else if (Rotation == ComplexDeinterleavingRotation::Rotation_270)
28598 IntId = Intrinsic::aarch64_neon_vcadd_rot270;
28599
28600 if (IntId == Intrinsic::not_intrinsic)
28601 return nullptr;
28602
28603 return B.CreateIntrinsic(IntId, Ty, {InputA, InputB});
28604 }
28605
28606 return nullptr;
28607 }
28608
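// Keep splats of extended values as vectors when they feed a multiply, so the
// extend can still be folded into a widening multiply.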
bool AArch64TargetLowering::preferScalarizeSplat(SDNode *N) const {
28610 unsigned Opc = N->getOpcode();
28611 if (ISD::isExtOpcode(Opc)) {
28612 if (any_of(N->uses(),
28613 [&](SDNode *Use) { return Use->getOpcode() == ISD::MUL; }))
28614 return false;
28615 }
28616 return true;
28617 }
28618
unsigned AArch64TargetLowering::getMinimumJumpTableEntries() const {
28620 return Subtarget->getMinimumJumpTableEntries();
28621 }
28622
MVT AArch64TargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
                                                         CallingConv::ID CC,
                                                         EVT VT) const {
28626 bool NonUnitFixedLengthVector =
28627 VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
28628 if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
28629 return TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
28630
28631 EVT VT1;
28632 MVT RegisterVT;
28633 unsigned NumIntermediates;
28634 getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1, NumIntermediates,
28635 RegisterVT);
28636 return RegisterVT;
28637 }
28638
unsigned AArch64TargetLowering::getNumRegistersForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT) const {
28641 bool NonUnitFixedLengthVector =
28642 VT.isFixedLengthVector() && !VT.getVectorElementCount().isScalar();
28643 if (!NonUnitFixedLengthVector || !Subtarget->useSVEForFixedLengthVectors())
28644 return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
28645
28646 EVT VT1;
28647 MVT VT2;
28648 unsigned NumIntermediates;
28649 return getVectorTypeBreakdownForCallingConv(Context, CC, VT, VT1,
28650 NumIntermediates, VT2);
28651 }
28652
unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(
    LLVMContext &Context, CallingConv::ID CC, EVT VT, EVT &IntermediateVT,
    unsigned &NumIntermediates, MVT &RegisterVT) const {
28656 int NumRegs = TargetLowering::getVectorTypeBreakdownForCallingConv(
28657 Context, CC, VT, IntermediateVT, NumIntermediates, RegisterVT);
28658 if (!RegisterVT.isFixedLengthVector() ||
28659 RegisterVT.getFixedSizeInBits() <= 128)
28660 return NumRegs;
28661
28662 assert(Subtarget->useSVEForFixedLengthVectors() && "Unexpected mode!");
28663 assert(IntermediateVT == RegisterVT && "Unexpected VT mismatch!");
28664 assert(RegisterVT.getFixedSizeInBits() % 128 == 0 && "Unexpected size!");
28665
  // A size mismatch here implies either type promotion or widening and would
  // have resulted in scalarisation if larger vectors had not been available.
28668 if (RegisterVT.getSizeInBits() * NumRegs != VT.getSizeInBits()) {
28669 EVT EltTy = VT.getVectorElementType();
28670 EVT NewVT = EVT::getVectorVT(Context, EltTy, ElementCount::getFixed(1));
28671 if (!isTypeLegal(NewVT))
28672 NewVT = EltTy;
28673
28674 IntermediateVT = NewVT;
28675 NumIntermediates = VT.getVectorNumElements();
28676 RegisterVT = getRegisterType(Context, NewVT);
28677 return NumIntermediates;
28678 }
28679
28680 // SVE VLS support does not introduce a new ABI so we should use NEON sized
28681 // types for vector arguments and returns.
28682
28683 unsigned NumSubRegs = RegisterVT.getFixedSizeInBits() / 128;
28684 NumIntermediates *= NumSubRegs;
28685 NumRegs *= NumSubRegs;
28686
28687 switch (RegisterVT.getVectorElementType().SimpleTy) {
28688 default:
28689 llvm_unreachable("unexpected element type for vector");
28690 case MVT::i8:
28691 IntermediateVT = RegisterVT = MVT::v16i8;
28692 break;
28693 case MVT::i16:
28694 IntermediateVT = RegisterVT = MVT::v8i16;
28695 break;
28696 case MVT::i32:
28697 IntermediateVT = RegisterVT = MVT::v4i32;
28698 break;
28699 case MVT::i64:
28700 IntermediateVT = RegisterVT = MVT::v2i64;
28701 break;
28702 case MVT::f16:
28703 IntermediateVT = RegisterVT = MVT::v8f16;
28704 break;
28705 case MVT::f32:
28706 IntermediateVT = RegisterVT = MVT::v4f32;
28707 break;
28708 case MVT::f64:
28709 IntermediateVT = RegisterVT = MVT::v2f64;
28710 break;
28711 case MVT::bf16:
28712 IntermediateVT = RegisterVT = MVT::v8bf16;
28713 break;
28714 }
28715
28716 return NumRegs;
28717 }
28718
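// Use inline stack probing when the function requests stack probes, except on
// Windows targets.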
bool AArch64TargetLowering::hasInlineStackProbe(
    const MachineFunction &MF) const {
28721 return !Subtarget->isTargetWindows() &&
28722 MF.getInfo<AArch64FunctionInfo>()->hasStackProbing();
28723 }
28724
28725 #ifndef NDEBUG
void AArch64TargetLowering::verifyTargetSDNode(const SDNode *N) const {
28727 switch (N->getOpcode()) {
28728 default:
28729 break;
28730 case AArch64ISD::SUNPKLO:
28731 case AArch64ISD::SUNPKHI:
28732 case AArch64ISD::UUNPKLO:
28733 case AArch64ISD::UUNPKHI: {
28734 assert(N->getNumValues() == 1 && "Expected one result!");
28735 assert(N->getNumOperands() == 1 && "Expected one operand!");
28736 EVT VT = N->getValueType(0);
28737 EVT OpVT = N->getOperand(0).getValueType();
28738 assert(OpVT.isVector() && VT.isVector() && OpVT.isInteger() &&
28739 VT.isInteger() && "Expected integer vectors!");
28740 assert(OpVT.getSizeInBits() == VT.getSizeInBits() &&
28741 "Expected vectors of equal size!");
28742 // TODO: Enable assert once bogus creations have been fixed.
28743 // assert(OpVT.getVectorElementCount() == VT.getVectorElementCount()*2 &&
28744 // "Expected result vector with half the lanes of its input!");
28745 break;
28746 }
28747 case AArch64ISD::TRN1:
28748 case AArch64ISD::TRN2:
28749 case AArch64ISD::UZP1:
28750 case AArch64ISD::UZP2:
28751 case AArch64ISD::ZIP1:
28752 case AArch64ISD::ZIP2: {
28753 assert(N->getNumValues() == 1 && "Expected one result!");
28754 assert(N->getNumOperands() == 2 && "Expected two operands!");
28755 EVT VT = N->getValueType(0);
28756 EVT Op0VT = N->getOperand(0).getValueType();
28757 EVT Op1VT = N->getOperand(1).getValueType();
28758 assert(VT.isVector() && Op0VT.isVector() && Op1VT.isVector() &&
28759 "Expected vectors!");
28760 // TODO: Enable assert once bogus creations have been fixed.
28761 // assert(VT == Op0VT && VT == Op1VT && "Expected matching vectors!");
28762 break;
28763 }
28764 }
28765 }
28766 #endif
28767