xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (revision 5ca8e32633c4ffbbcd6762e5888b6a4ba0708c6c)
1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "AArch64PerfectShuffle.h"
12 #include "MCTargetDesc/AArch64AddressingModes.h"
13 #include "llvm/Analysis/IVDescriptors.h"
14 #include "llvm/Analysis/LoopInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/BasicTTIImpl.h"
17 #include "llvm/CodeGen/CostTable.h"
18 #include "llvm/CodeGen/TargetLowering.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/IR/IntrinsicsAArch64.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Transforms/InstCombine/InstCombiner.h"
25 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
26 #include <algorithm>
27 #include <optional>
28 using namespace llvm;
29 using namespace llvm::PatternMatch;
30 
31 #define DEBUG_TYPE "aarch64tti"
32 
33 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
34                                                cl::init(true), cl::Hidden);
35 
36 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
37                                            cl::Hidden);
38 
39 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
40                                             cl::init(10), cl::Hidden);
41 
42 static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
43                                                   cl::init(15), cl::Hidden);
44 
45 static cl::opt<unsigned>
46     NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
47                                cl::Hidden);
48 
49 namespace {
50 class TailFoldingOption {
51   // These bitfields will only ever be set to something non-zero in operator=,
52   // when setting the -sve-tail-folding option. This option should always be of
53   // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here
54   // InitialBits is one of (disabled|all|simple). EnableBits represents
55   // additional flags we're enabling, and DisableBits for those flags we're
56   // disabling. The default flag is tracked in the variable NeedsDefault, since
57   // at the time of setting the option we may not know what the default value
58   // for the CPU is.
59   TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
60   TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
61   TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
62 
63   // This value needs to be initialised to true in case the user does not
64   // explicitly set the -sve-tail-folding option.
65   bool NeedsDefault = true;
66 
67   void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
68 
69   void setNeedsDefault(bool V) { NeedsDefault = V; }
70 
71   void setEnableBit(TailFoldingOpts Bit) {
72     EnableBits |= Bit;
73     DisableBits &= ~Bit;
74   }
75 
76   void setDisableBit(TailFoldingOpts Bit) {
77     EnableBits &= ~Bit;
78     DisableBits |= Bit;
79   }
80 
81   TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
82     TailFoldingOpts Bits = TailFoldingOpts::Disabled;
83 
84     assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
85            "Initial bits should only include one of "
86            "(disabled|all|simple|default)");
87     Bits = NeedsDefault ? DefaultBits : InitialBits;
88     Bits |= EnableBits;
89     Bits &= ~DisableBits;
90 
91     return Bits;
92   }
93 
94   void reportError(std::string Opt) {
95     errs() << "invalid argument '" << Opt
96            << "' to -sve-tail-folding=; the option should be of the form\n"
97               "  (disabled|all|default|simple)[+(reductions|recurrences"
98               "|reverse|noreductions|norecurrences|noreverse)]\n";
99     report_fatal_error("Unrecognised tail-folding option");
100   }
101 
102 public:
103 
104   void operator=(const std::string &Val) {
105     // If the user explicitly sets -sve-tail-folding= then treat as an error.
106     if (Val.empty()) {
107       reportError("");
108       return;
109     }
110 
111     // Since the user is explicitly setting the option we don't automatically
112     // need the default unless they require it.
113     setNeedsDefault(false);
114 
115     SmallVector<StringRef, 4> TailFoldTypes;
116     StringRef(Val).split(TailFoldTypes, '+', -1, false);
117 
118     unsigned StartIdx = 1;
119     if (TailFoldTypes[0] == "disabled")
120       setInitialBits(TailFoldingOpts::Disabled);
121     else if (TailFoldTypes[0] == "all")
122       setInitialBits(TailFoldingOpts::All);
123     else if (TailFoldTypes[0] == "default")
124       setNeedsDefault(true);
125     else if (TailFoldTypes[0] == "simple")
126       setInitialBits(TailFoldingOpts::Simple);
127     else {
128       StartIdx = 0;
129       setInitialBits(TailFoldingOpts::Disabled);
130     }
131 
132     for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
133       if (TailFoldTypes[I] == "reductions")
134         setEnableBit(TailFoldingOpts::Reductions);
135       else if (TailFoldTypes[I] == "recurrences")
136         setEnableBit(TailFoldingOpts::Recurrences);
137       else if (TailFoldTypes[I] == "reverse")
138         setEnableBit(TailFoldingOpts::Reverse);
139       else if (TailFoldTypes[I] == "noreductions")
140         setDisableBit(TailFoldingOpts::Reductions);
141       else if (TailFoldTypes[I] == "norecurrences")
142         setDisableBit(TailFoldingOpts::Recurrences);
143       else if (TailFoldTypes[I] == "noreverse")
144         setDisableBit(TailFoldingOpts::Reverse);
145       else
146         reportError(Val);
147     }
148   }
149 
150   bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
151     return (getBits(DefaultBits) & Required) == Required;
152   }
153 };
154 } // namespace
155 
156 TailFoldingOption TailFoldingOptionLoc;
157 
158 cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
159     "sve-tail-folding",
160     cl::desc(
161         "Control the use of vectorisation using tail-folding for SVE where the"
162         " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
163         "\ndisabled      (Initial) No loop types will vectorize using "
164         "tail-folding"
165         "\ndefault       (Initial) Uses the default tail-folding settings for "
166         "the target CPU"
167         "\nall           (Initial) All legal loop types will vectorize using "
168         "tail-folding"
169         "\nsimple        (Initial) Use tail-folding for simple loops (not "
170         "reductions or recurrences)"
171         "\nreductions    Use tail-folding for loops containing reductions"
172         "\nnoreductions  Inverse of above"
173         "\nrecurrences   Use tail-folding for loops containing fixed order "
174         "recurrences"
175         "\nnorecurrences Inverse of above"
176         "\nreverse       Use tail-folding for loops requiring reversed "
177         "predicates"
178         "\nnoreverse     Inverse of above"),
179     cl::location(TailFoldingOptionLoc));
180 
181 // Experimental option that will only be fully functional when the
182 // code-generator is changed to use SVE instead of NEON for all fixed-width
183 // operations.
184 static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
185     "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
186 
187 // Experimental option that will only be fully functional when the cost-model
188 // and code-generator have been changed to avoid using scalable vector
189 // instructions that are not legal in streaming SVE mode.
190 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
191     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
192 
193 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
194                                          const Function *Callee) const {
195   SMEAttrs CallerAttrs(*Caller);
196   SMEAttrs CalleeAttrs(*Callee);
197   if (CallerAttrs.requiresSMChange(CalleeAttrs,
198                                    /*BodyOverridesInterface=*/true) ||
199       CallerAttrs.requiresLazySave(CalleeAttrs) ||
200       CalleeAttrs.hasNewZAInterface())
201     return false;
202 
203   const TargetMachine &TM = getTLI()->getTargetMachine();
204 
205   const FeatureBitset &CallerBits =
206       TM.getSubtargetImpl(*Caller)->getFeatureBits();
207   const FeatureBitset &CalleeBits =
208       TM.getSubtargetImpl(*Callee)->getFeatureBits();
209 
210   // Inline a callee if its target-features are a subset of the callers
211   // target-features.
212   return (CallerBits & CalleeBits) == CalleeBits;
213 }
214 
215 bool AArch64TTIImpl::areTypesABICompatible(
216     const Function *Caller, const Function *Callee,
217     const ArrayRef<Type *> &Types) const {
218   if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
219     return false;
220 
221   // We need to ensure that argument promotion does not attempt to promote
222   // pointers to fixed-length vector types larger than 128 bits like
223   // <8 x float> (and pointers to aggregate types which have such fixed-length
224   // vector type members) into the values of the pointees. Such vector types
225   // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
226   // backend cannot lower such value arguments. The 128-bit fixed-length SVE
227   // types can be safely treated as 128-bit NEON types and they cannot be
228   // distinguished in IR.
229   if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
230         auto FVTy = dyn_cast<FixedVectorType>(Ty);
231         return FVTy &&
232                FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
233       }))
234     return false;
235 
236   return true;
237 }
238 
239 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
240     TargetTransformInfo::RegisterKind K) const {
241   assert(K != TargetTransformInfo::RGK_Scalar);
242   return (K == TargetTransformInfo::RGK_FixedWidthVector &&
243           ST->isNeonAvailable());
244 }
245 
246 /// Calculate the cost of materializing a 64-bit value. This helper
247 /// method might only calculate a fraction of a larger immediate. Therefore it
248 /// is valid to return a cost of ZERO.
249 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
250   // Check if the immediate can be encoded within an instruction.
251   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
252     return 0;
253 
254   if (Val < 0)
255     Val = ~Val;
256 
257   // Calculate how many moves we will need to materialize this constant.
258   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
259   AArch64_IMM::expandMOVImm(Val, 64, Insn);
260   return Insn.size();
261 }
262 
263 /// Calculate the cost of materializing the given constant.
264 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
265                                               TTI::TargetCostKind CostKind) {
266   assert(Ty->isIntegerTy());
267 
268   unsigned BitSize = Ty->getPrimitiveSizeInBits();
269   if (BitSize == 0)
270     return ~0U;
271 
272   // Sign-extend all constants to a multiple of 64-bit.
273   APInt ImmVal = Imm;
274   if (BitSize & 0x3f)
275     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
276 
277   // Split the constant into 64-bit chunks and calculate the cost for each
278   // chunk.
279   InstructionCost Cost = 0;
280   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
281     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
282     int64_t Val = Tmp.getSExtValue();
283     Cost += getIntImmCost(Val);
284   }
285   // We need at least one instruction to materialze the constant.
286   return std::max<InstructionCost>(1, Cost);
287 }
288 
289 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
290                                                   const APInt &Imm, Type *Ty,
291                                                   TTI::TargetCostKind CostKind,
292                                                   Instruction *Inst) {
293   assert(Ty->isIntegerTy());
294 
295   unsigned BitSize = Ty->getPrimitiveSizeInBits();
296   // There is no cost model for constants with a bit size of 0. Return TCC_Free
297   // here, so that constant hoisting will ignore this constant.
298   if (BitSize == 0)
299     return TTI::TCC_Free;
300 
301   unsigned ImmIdx = ~0U;
302   switch (Opcode) {
303   default:
304     return TTI::TCC_Free;
305   case Instruction::GetElementPtr:
306     // Always hoist the base address of a GetElementPtr.
307     if (Idx == 0)
308       return 2 * TTI::TCC_Basic;
309     return TTI::TCC_Free;
310   case Instruction::Store:
311     ImmIdx = 0;
312     break;
313   case Instruction::Add:
314   case Instruction::Sub:
315   case Instruction::Mul:
316   case Instruction::UDiv:
317   case Instruction::SDiv:
318   case Instruction::URem:
319   case Instruction::SRem:
320   case Instruction::And:
321   case Instruction::Or:
322   case Instruction::Xor:
323   case Instruction::ICmp:
324     ImmIdx = 1;
325     break;
326   // Always return TCC_Free for the shift value of a shift instruction.
327   case Instruction::Shl:
328   case Instruction::LShr:
329   case Instruction::AShr:
330     if (Idx == 1)
331       return TTI::TCC_Free;
332     break;
333   case Instruction::Trunc:
334   case Instruction::ZExt:
335   case Instruction::SExt:
336   case Instruction::IntToPtr:
337   case Instruction::PtrToInt:
338   case Instruction::BitCast:
339   case Instruction::PHI:
340   case Instruction::Call:
341   case Instruction::Select:
342   case Instruction::Ret:
343   case Instruction::Load:
344     break;
345   }
346 
347   if (Idx == ImmIdx) {
348     int NumConstants = (BitSize + 63) / 64;
349     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
350     return (Cost <= NumConstants * TTI::TCC_Basic)
351                ? static_cast<int>(TTI::TCC_Free)
352                : Cost;
353   }
354   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
355 }
356 
357 InstructionCost
358 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
359                                     const APInt &Imm, Type *Ty,
360                                     TTI::TargetCostKind CostKind) {
361   assert(Ty->isIntegerTy());
362 
363   unsigned BitSize = Ty->getPrimitiveSizeInBits();
364   // There is no cost model for constants with a bit size of 0. Return TCC_Free
365   // here, so that constant hoisting will ignore this constant.
366   if (BitSize == 0)
367     return TTI::TCC_Free;
368 
369   // Most (all?) AArch64 intrinsics do not support folding immediates into the
370   // selected instruction, so we compute the materialization cost for the
371   // immediate directly.
372   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
373     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
374 
375   switch (IID) {
376   default:
377     return TTI::TCC_Free;
378   case Intrinsic::sadd_with_overflow:
379   case Intrinsic::uadd_with_overflow:
380   case Intrinsic::ssub_with_overflow:
381   case Intrinsic::usub_with_overflow:
382   case Intrinsic::smul_with_overflow:
383   case Intrinsic::umul_with_overflow:
384     if (Idx == 1) {
385       int NumConstants = (BitSize + 63) / 64;
386       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
387       return (Cost <= NumConstants * TTI::TCC_Basic)
388                  ? static_cast<int>(TTI::TCC_Free)
389                  : Cost;
390     }
391     break;
392   case Intrinsic::experimental_stackmap:
393     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
394       return TTI::TCC_Free;
395     break;
396   case Intrinsic::experimental_patchpoint_void:
397   case Intrinsic::experimental_patchpoint_i64:
398     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
399       return TTI::TCC_Free;
400     break;
401   case Intrinsic::experimental_gc_statepoint:
402     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
403       return TTI::TCC_Free;
404     break;
405   }
406   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
407 }
408 
409 TargetTransformInfo::PopcntSupportKind
410 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
411   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
412   if (TyWidth == 32 || TyWidth == 64)
413     return TTI::PSK_FastHardware;
414   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
415   return TTI::PSK_Software;
416 }
417 
418 InstructionCost
419 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
420                                       TTI::TargetCostKind CostKind) {
421   auto *RetTy = ICA.getReturnType();
422   switch (ICA.getID()) {
423   case Intrinsic::umin:
424   case Intrinsic::umax:
425   case Intrinsic::smin:
426   case Intrinsic::smax: {
427     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
428                                         MVT::v8i16, MVT::v2i32, MVT::v4i32,
429                                         MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
430                                         MVT::nxv2i64};
431     auto LT = getTypeLegalizationCost(RetTy);
432     // v2i64 types get converted to cmp+bif hence the cost of 2
433     if (LT.second == MVT::v2i64)
434       return LT.first * 2;
435     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
436       return LT.first;
437     break;
438   }
439   case Intrinsic::sadd_sat:
440   case Intrinsic::ssub_sat:
441   case Intrinsic::uadd_sat:
442   case Intrinsic::usub_sat: {
443     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
444                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
445                                      MVT::v2i64};
446     auto LT = getTypeLegalizationCost(RetTy);
447     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
448     // need to extend the type, as it uses shr(qadd(shl, shl)).
449     unsigned Instrs =
450         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
451     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
452       return LT.first * Instrs;
453     break;
454   }
455   case Intrinsic::abs: {
456     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
457                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
458                                      MVT::v2i64};
459     auto LT = getTypeLegalizationCost(RetTy);
460     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
461       return LT.first;
462     break;
463   }
464   case Intrinsic::bswap: {
465     static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
466                                      MVT::v4i32, MVT::v2i64};
467     auto LT = getTypeLegalizationCost(RetTy);
468     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
469         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
470       return LT.first;
471     break;
472   }
473   case Intrinsic::experimental_stepvector: {
474     InstructionCost Cost = 1; // Cost of the `index' instruction
475     auto LT = getTypeLegalizationCost(RetTy);
476     // Legalisation of illegal vectors involves an `index' instruction plus
477     // (LT.first - 1) vector adds.
478     if (LT.first > 1) {
479       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
480       InstructionCost AddCost =
481           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
482       Cost += AddCost * (LT.first - 1);
483     }
484     return Cost;
485   }
486   case Intrinsic::bitreverse: {
487     static const CostTblEntry BitreverseTbl[] = {
488         {Intrinsic::bitreverse, MVT::i32, 1},
489         {Intrinsic::bitreverse, MVT::i64, 1},
490         {Intrinsic::bitreverse, MVT::v8i8, 1},
491         {Intrinsic::bitreverse, MVT::v16i8, 1},
492         {Intrinsic::bitreverse, MVT::v4i16, 2},
493         {Intrinsic::bitreverse, MVT::v8i16, 2},
494         {Intrinsic::bitreverse, MVT::v2i32, 2},
495         {Intrinsic::bitreverse, MVT::v4i32, 2},
496         {Intrinsic::bitreverse, MVT::v1i64, 2},
497         {Intrinsic::bitreverse, MVT::v2i64, 2},
498     };
499     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
500     const auto *Entry =
501         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
502     if (Entry) {
503       // Cost Model is using the legal type(i32) that i8 and i16 will be
504       // converted to +1 so that we match the actual lowering cost
505       if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
506           TLI->getValueType(DL, RetTy, true) == MVT::i16)
507         return LegalisationCost.first * Entry->Cost + 1;
508 
509       return LegalisationCost.first * Entry->Cost;
510     }
511     break;
512   }
513   case Intrinsic::ctpop: {
514     if (!ST->hasNEON()) {
515       // 32-bit or 64-bit ctpop without NEON is 12 instructions.
516       return getTypeLegalizationCost(RetTy).first * 12;
517     }
518     static const CostTblEntry CtpopCostTbl[] = {
519         {ISD::CTPOP, MVT::v2i64, 4},
520         {ISD::CTPOP, MVT::v4i32, 3},
521         {ISD::CTPOP, MVT::v8i16, 2},
522         {ISD::CTPOP, MVT::v16i8, 1},
523         {ISD::CTPOP, MVT::i64,   4},
524         {ISD::CTPOP, MVT::v2i32, 3},
525         {ISD::CTPOP, MVT::v4i16, 2},
526         {ISD::CTPOP, MVT::v8i8,  1},
527         {ISD::CTPOP, MVT::i32,   5},
528     };
529     auto LT = getTypeLegalizationCost(RetTy);
530     MVT MTy = LT.second;
531     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
532       // Extra cost of +1 when illegal vector types are legalized by promoting
533       // the integer type.
534       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
535                                             RetTy->getScalarSizeInBits()
536                           ? 1
537                           : 0;
538       return LT.first * Entry->Cost + ExtraCost;
539     }
540     break;
541   }
542   case Intrinsic::sadd_with_overflow:
543   case Intrinsic::uadd_with_overflow:
544   case Intrinsic::ssub_with_overflow:
545   case Intrinsic::usub_with_overflow:
546   case Intrinsic::smul_with_overflow:
547   case Intrinsic::umul_with_overflow: {
548     static const CostTblEntry WithOverflowCostTbl[] = {
549         {Intrinsic::sadd_with_overflow, MVT::i8, 3},
550         {Intrinsic::uadd_with_overflow, MVT::i8, 3},
551         {Intrinsic::sadd_with_overflow, MVT::i16, 3},
552         {Intrinsic::uadd_with_overflow, MVT::i16, 3},
553         {Intrinsic::sadd_with_overflow, MVT::i32, 1},
554         {Intrinsic::uadd_with_overflow, MVT::i32, 1},
555         {Intrinsic::sadd_with_overflow, MVT::i64, 1},
556         {Intrinsic::uadd_with_overflow, MVT::i64, 1},
557         {Intrinsic::ssub_with_overflow, MVT::i8, 3},
558         {Intrinsic::usub_with_overflow, MVT::i8, 3},
559         {Intrinsic::ssub_with_overflow, MVT::i16, 3},
560         {Intrinsic::usub_with_overflow, MVT::i16, 3},
561         {Intrinsic::ssub_with_overflow, MVT::i32, 1},
562         {Intrinsic::usub_with_overflow, MVT::i32, 1},
563         {Intrinsic::ssub_with_overflow, MVT::i64, 1},
564         {Intrinsic::usub_with_overflow, MVT::i64, 1},
565         {Intrinsic::smul_with_overflow, MVT::i8, 5},
566         {Intrinsic::umul_with_overflow, MVT::i8, 4},
567         {Intrinsic::smul_with_overflow, MVT::i16, 5},
568         {Intrinsic::umul_with_overflow, MVT::i16, 4},
569         {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
570         {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
571         {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
572         {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
573     };
574     EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
575     if (MTy.isSimple())
576       if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
577                                               MTy.getSimpleVT()))
578         return Entry->Cost;
579     break;
580   }
581   case Intrinsic::fptosi_sat:
582   case Intrinsic::fptoui_sat: {
583     if (ICA.getArgTypes().empty())
584       break;
585     bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
586     auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
587     EVT MTy = TLI->getValueType(DL, RetTy);
588     // Check for the legal types, which are where the size of the input and the
589     // output are the same, or we are using cvt f64->i32 or f32->i64.
590     if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
591          LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
592          LT.second == MVT::v2f64) &&
593         (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
594          (LT.second == MVT::f64 && MTy == MVT::i32) ||
595          (LT.second == MVT::f32 && MTy == MVT::i64)))
596       return LT.first;
597     // Similarly for fp16 sizes
598     if (ST->hasFullFP16() &&
599         ((LT.second == MVT::f16 && MTy == MVT::i32) ||
600          ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
601           (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
602       return LT.first;
603 
604     // Otherwise we use a legal convert followed by a min+max
605     if ((LT.second.getScalarType() == MVT::f32 ||
606          LT.second.getScalarType() == MVT::f64 ||
607          (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
608         LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
609       Type *LegalTy =
610           Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
611       if (LT.second.isVector())
612         LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
613       InstructionCost Cost = 1;
614       IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
615                                     LegalTy, {LegalTy, LegalTy});
616       Cost += getIntrinsicInstrCost(Attrs1, CostKind);
617       IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
618                                     LegalTy, {LegalTy, LegalTy});
619       Cost += getIntrinsicInstrCost(Attrs2, CostKind);
620       return LT.first * Cost;
621     }
622     break;
623   }
624   case Intrinsic::fshl:
625   case Intrinsic::fshr: {
626     if (ICA.getArgs().empty())
627       break;
628 
629     // TODO: Add handling for fshl where third argument is not a constant.
630     const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
631     if (!OpInfoZ.isConstant())
632       break;
633 
634     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
635     if (OpInfoZ.isUniform()) {
636       // FIXME: The costs could be lower if the codegen is better.
637       static const CostTblEntry FshlTbl[] = {
638           {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
639           {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
640           {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
641           {Intrinsic::fshl, MVT::v8i8, 4},  {Intrinsic::fshl, MVT::v4i16, 4}};
642       // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
643       // to avoid having to duplicate the costs.
644       const auto *Entry =
645           CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
646       if (Entry)
647         return LegalisationCost.first * Entry->Cost;
648     }
649 
650     auto TyL = getTypeLegalizationCost(RetTy);
651     if (!RetTy->isIntegerTy())
652       break;
653 
654     // Estimate cost manually, as types like i8 and i16 will get promoted to
655     // i32 and CostTableLookup will ignore the extra conversion cost.
656     bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
657                        RetTy->getScalarSizeInBits() < 64) ||
658                       (RetTy->getScalarSizeInBits() % 64 != 0);
659     unsigned ExtraCost = HigherCost ? 1 : 0;
660     if (RetTy->getScalarSizeInBits() == 32 ||
661         RetTy->getScalarSizeInBits() == 64)
662       ExtraCost = 0; // fhsl/fshr for i32 and i64 can be lowered to a single
663                      // extr instruction.
664     else if (HigherCost)
665       ExtraCost = 1;
666     else
667       break;
668     return TyL.first + ExtraCost;
669   }
670   default:
671     break;
672   }
673   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
674 }
675 
676 /// The function will remove redundant reinterprets casting in the presence
677 /// of the control flow
678 static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
679                                                    IntrinsicInst &II) {
680   SmallVector<Instruction *, 32> Worklist;
681   auto RequiredType = II.getType();
682 
683   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
684   assert(PN && "Expected Phi Node!");
685 
686   // Don't create a new Phi unless we can remove the old one.
687   if (!PN->hasOneUse())
688     return std::nullopt;
689 
690   for (Value *IncValPhi : PN->incoming_values()) {
691     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
692     if (!Reinterpret ||
693         Reinterpret->getIntrinsicID() !=
694             Intrinsic::aarch64_sve_convert_to_svbool ||
695         RequiredType != Reinterpret->getArgOperand(0)->getType())
696       return std::nullopt;
697   }
698 
699   // Create the new Phi
700   IC.Builder.SetInsertPoint(PN);
701   PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
702   Worklist.push_back(PN);
703 
704   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
705     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
706     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
707     Worklist.push_back(Reinterpret);
708   }
709 
710   // Cleanup Phi Node and reinterprets
711   return IC.replaceInstUsesWith(II, NPN);
712 }
713 
714 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
715 // => (binop (pred) (from_svbool _) (from_svbool _))
716 //
717 // The above transformation eliminates a `to_svbool` in the predicate
718 // operand of bitwise operation `binop` by narrowing the vector width of
719 // the operation. For example, it would convert a `<vscale x 16 x i1>
720 // and` into a `<vscale x 4 x i1> and`. This is profitable because
721 // to_svbool must zero the new lanes during widening, whereas
722 // from_svbool is free.
723 static std::optional<Instruction *>
724 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
725   auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
726   if (!BinOp)
727     return std::nullopt;
728 
729   auto IntrinsicID = BinOp->getIntrinsicID();
730   switch (IntrinsicID) {
731   case Intrinsic::aarch64_sve_and_z:
732   case Intrinsic::aarch64_sve_bic_z:
733   case Intrinsic::aarch64_sve_eor_z:
734   case Intrinsic::aarch64_sve_nand_z:
735   case Intrinsic::aarch64_sve_nor_z:
736   case Intrinsic::aarch64_sve_orn_z:
737   case Intrinsic::aarch64_sve_orr_z:
738     break;
739   default:
740     return std::nullopt;
741   }
742 
743   auto BinOpPred = BinOp->getOperand(0);
744   auto BinOpOp1 = BinOp->getOperand(1);
745   auto BinOpOp2 = BinOp->getOperand(2);
746 
747   auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
748   if (!PredIntr ||
749       PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
750     return std::nullopt;
751 
752   auto PredOp = PredIntr->getOperand(0);
753   auto PredOpTy = cast<VectorType>(PredOp->getType());
754   if (PredOpTy != II.getType())
755     return std::nullopt;
756 
757   SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
758   auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
759       Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
760   NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
761   if (BinOpOp1 == BinOpOp2)
762     NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
763   else
764     NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
765         Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
766 
767   auto NarrowedBinOp =
768       IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
769   return IC.replaceInstUsesWith(II, NarrowedBinOp);
770 }
771 
772 static std::optional<Instruction *>
773 instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
774   // If the reinterpret instruction operand is a PHI Node
775   if (isa<PHINode>(II.getArgOperand(0)))
776     return processPhiNode(IC, II);
777 
778   if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
779     return BinOpCombine;
780 
781   // Ignore converts to/from svcount_t.
782   if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
783       isa<TargetExtType>(II.getType()))
784     return std::nullopt;
785 
786   SmallVector<Instruction *, 32> CandidatesForRemoval;
787   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
788 
789   const auto *IVTy = cast<VectorType>(II.getType());
790 
791   // Walk the chain of conversions.
792   while (Cursor) {
793     // If the type of the cursor has fewer lanes than the final result, zeroing
794     // must take place, which breaks the equivalence chain.
795     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
796     if (CursorVTy->getElementCount().getKnownMinValue() <
797         IVTy->getElementCount().getKnownMinValue())
798       break;
799 
800     // If the cursor has the same type as I, it is a viable replacement.
801     if (Cursor->getType() == IVTy)
802       EarliestReplacement = Cursor;
803 
804     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
805 
806     // If this is not an SVE conversion intrinsic, this is the end of the chain.
807     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
808                                   Intrinsic::aarch64_sve_convert_to_svbool ||
809                               IntrinsicCursor->getIntrinsicID() ==
810                                   Intrinsic::aarch64_sve_convert_from_svbool))
811       break;
812 
813     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
814     Cursor = IntrinsicCursor->getOperand(0);
815   }
816 
817   // If no viable replacement in the conversion chain was found, there is
818   // nothing to do.
819   if (!EarliestReplacement)
820     return std::nullopt;
821 
822   return IC.replaceInstUsesWith(II, EarliestReplacement);
823 }
824 
825 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
826                                                       IntrinsicInst &II) {
827   auto Select = IC.Builder.CreateSelect(II.getOperand(0), II.getOperand(1),
828                                         II.getOperand(2));
829   return IC.replaceInstUsesWith(II, Select);
830 }
831 
832 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
833                                                       IntrinsicInst &II) {
834   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
835   if (!Pg)
836     return std::nullopt;
837 
838   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
839     return std::nullopt;
840 
841   const auto PTruePattern =
842       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
843   if (PTruePattern != AArch64SVEPredPattern::vl1)
844     return std::nullopt;
845 
846   // The intrinsic is inserting into lane zero so use an insert instead.
847   auto *IdxTy = Type::getInt64Ty(II.getContext());
848   auto *Insert = InsertElementInst::Create(
849       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
850   Insert->insertBefore(&II);
851   Insert->takeName(&II);
852 
853   return IC.replaceInstUsesWith(II, Insert);
854 }
855 
856 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
857                                                        IntrinsicInst &II) {
858   // Replace DupX with a regular IR splat.
859   auto *RetTy = cast<ScalableVectorType>(II.getType());
860   Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
861                                               II.getArgOperand(0));
862   Splat->takeName(&II);
863   return IC.replaceInstUsesWith(II, Splat);
864 }
865 
// Fold cmpne(ptrue(all), dupq_lane(vector_insert(undef, C, 0), 0), splat(0)),
// where C is a fixed-length constant vector, into a constant predicate. The
// predicate is either all-false, or ptrue(all) of the widest element type
// whose lanes line up with C's non-zero elements. That ptrue is reinterpreted
// to II's type through convert.to/from.svbool. Bails whenever the pattern
// does not match exactly.
static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return std::nullopt;

  // Check that we have a compare of zero..
  auto *SplatValue =
      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
  if (!SplatValue || !SplatValue->isZero())
    return std::nullopt;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return std::nullopt;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return std::nullopt;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
    return std::nullopt;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return std::nullopt;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return std::nullopt;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return std::nullopt;

  // The inserted constant must have exactly as many elements as the minimum
  // lane count of the result, i.e. it fills one replicated block.
  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return std::nullopt;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  // (bit I*(16/NumElts) is set for each non-zero element I; each element
  // spans 16/NumElts of the 16 bit positions).
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return std::nullopt;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  // The lowest set bit of Mask is the largest power of two (capped at 8 by the
  // initial value) that divides every set bit position.
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  // Isolate the lowest set bit: the element size used, in bit positions.
  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  // (only an all-true pattern at this element size can be expressed as
  // ptrue(all); anything sparser bails out).
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return std::nullopt;

  // Materialise ptrue(all) at PredType and reinterpret it to II's type via
  // the svbool conversions.
  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                           {PredType}, {PTruePat});
  auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                 {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}
964 
965 static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
966                                                        IntrinsicInst &II) {
967   Value *Pg = II.getArgOperand(0);
968   Value *Vec = II.getArgOperand(1);
969   auto IntrinsicID = II.getIntrinsicID();
970   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
971 
972   // lastX(splat(X)) --> X
973   if (auto *SplatVal = getSplatValue(Vec))
974     return IC.replaceInstUsesWith(II, SplatVal);
975 
976   // If x and/or y is a splat value then:
977   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
978   Value *LHS, *RHS;
979   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
980     if (isSplatValue(LHS) || isSplatValue(RHS)) {
981       auto *OldBinOp = cast<BinaryOperator>(Vec);
982       auto OpC = OldBinOp->getOpcode();
983       auto *NewLHS =
984           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
985       auto *NewRHS =
986           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
987       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
988           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
989       return IC.replaceInstUsesWith(II, NewBinOp);
990     }
991   }
992 
993   auto *C = dyn_cast<Constant>(Pg);
994   if (IsAfter && C && C->isNullValue()) {
995     // The intrinsic is extracting lane 0 so use an extract instead.
996     auto *IdxTy = Type::getInt64Ty(II.getContext());
997     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
998     Extract->insertBefore(&II);
999     Extract->takeName(&II);
1000     return IC.replaceInstUsesWith(II, Extract);
1001   }
1002 
1003   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
1004   if (!IntrPG)
1005     return std::nullopt;
1006 
1007   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1008     return std::nullopt;
1009 
1010   const auto PTruePattern =
1011       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
1012 
1013   // Can the intrinsic's predicate be converted to a known constant index?
1014   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
1015   if (!MinNumElts)
1016     return std::nullopt;
1017 
1018   unsigned Idx = MinNumElts - 1;
1019   // Increment the index if extracting the element after the last active
1020   // predicate element.
1021   if (IsAfter)
1022     ++Idx;
1023 
1024   // Ignore extracts whose index is larger than the known minimum vector
1025   // length. NOTE: This is an artificial constraint where we prefer to
1026   // maintain what the user asked for until an alternative is proven faster.
1027   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
1028   if (Idx >= PgVTy->getMinNumElements())
1029     return std::nullopt;
1030 
1031   // The intrinsic is extracting a fixed lane so use an extract instead.
1032   auto *IdxTy = Type::getInt64Ty(II.getContext());
1033   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
1034   Extract->insertBefore(&II);
1035   Extract->takeName(&II);
1036   return IC.replaceInstUsesWith(II, Extract);
1037 }
1038 
1039 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1040                                                            IntrinsicInst &II) {
1041   // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1042   // integer variant across a variety of micro-architectures. Replace scalar
1043   // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1044   // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1045   // depending on the micro-architecture, but has been observed as generally
1046   // being faster, particularly when the CLAST[AB] op is a loop-carried
1047   // dependency.
1048   Value *Pg = II.getArgOperand(0);
1049   Value *Fallback = II.getArgOperand(1);
1050   Value *Vec = II.getArgOperand(2);
1051   Type *Ty = II.getType();
1052 
1053   if (!Ty->isIntegerTy())
1054     return std::nullopt;
1055 
1056   Type *FPTy;
1057   switch (cast<IntegerType>(Ty)->getBitWidth()) {
1058   default:
1059     return std::nullopt;
1060   case 16:
1061     FPTy = IC.Builder.getHalfTy();
1062     break;
1063   case 32:
1064     FPTy = IC.Builder.getFloatTy();
1065     break;
1066   case 64:
1067     FPTy = IC.Builder.getDoubleTy();
1068     break;
1069   }
1070 
1071   Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
1072   auto *FPVTy = VectorType::get(
1073       FPTy, cast<VectorType>(Vec->getType())->getElementCount());
1074   Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
1075   auto *FPII = IC.Builder.CreateIntrinsic(
1076       II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
1077   Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
1078   return IC.replaceInstUsesWith(II, FPIItoInt);
1079 }
1080 
1081 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
1082                                                      IntrinsicInst &II) {
1083   LLVMContext &Ctx = II.getContext();
1084   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
1085   // can work with RDFFR_PP for ptest elimination.
1086   auto *AllPat =
1087       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1088   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1089                                            {II.getType()}, {AllPat});
1090   auto *RDFFR =
1091       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
1092   RDFFR->takeName(&II);
1093   return IC.replaceInstUsesWith(II, RDFFR);
1094 }
1095 
1096 static std::optional<Instruction *>
1097 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
1098   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
1099 
1100   if (Pattern == AArch64SVEPredPattern::all) {
1101     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
1102     auto *VScale = IC.Builder.CreateVScale(StepVal);
1103     VScale->takeName(&II);
1104     return IC.replaceInstUsesWith(II, VScale);
1105   }
1106 
1107   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
1108 
1109   return MinNumElts && NumElts >= MinNumElts
1110              ? std::optional<Instruction *>(IC.replaceInstUsesWith(
1111                    II, ConstantInt::get(II.getType(), MinNumElts)))
1112              : std::nullopt;
1113 }
1114 
1115 static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
1116                                                         IntrinsicInst &II) {
1117   Value *PgVal = II.getArgOperand(0);
1118   Value *OpVal = II.getArgOperand(1);
1119 
1120   // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
1121   // Later optimizations prefer this form.
1122   if (PgVal == OpVal &&
1123       (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
1124        II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
1125     Value *Ops[] = {PgVal, OpVal};
1126     Type *Tys[] = {PgVal->getType()};
1127 
1128     auto *PTest =
1129         IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
1130     PTest->takeName(&II);
1131 
1132     return IC.replaceInstUsesWith(II, PTest);
1133   }
1134 
1135   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
1136   IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
1137 
1138   if (!Pg || !Op)
1139     return std::nullopt;
1140 
1141   Intrinsic::ID OpIID = Op->getIntrinsicID();
1142 
1143   if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
1144       OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
1145       Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
1146     Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
1147     Type *Tys[] = {Pg->getArgOperand(0)->getType()};
1148 
1149     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1150 
1151     PTest->takeName(&II);
1152     return IC.replaceInstUsesWith(II, PTest);
1153   }
1154 
1155   // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X)).
1156   // Later optimizations may rewrite sequence to use the flag-setting variant
1157   // of instruction X to remove PTEST.
1158   if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
1159       ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
1160        (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
1161        (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
1162        (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
1163        (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
1164        (OpIID == Intrinsic::aarch64_sve_and_z) ||
1165        (OpIID == Intrinsic::aarch64_sve_bic_z) ||
1166        (OpIID == Intrinsic::aarch64_sve_eor_z) ||
1167        (OpIID == Intrinsic::aarch64_sve_nand_z) ||
1168        (OpIID == Intrinsic::aarch64_sve_nor_z) ||
1169        (OpIID == Intrinsic::aarch64_sve_orn_z) ||
1170        (OpIID == Intrinsic::aarch64_sve_orr_z))) {
1171     Value *Ops[] = {Pg->getArgOperand(0), Pg};
1172     Type *Tys[] = {Pg->getType()};
1173 
1174     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1175     PTest->takeName(&II);
1176 
1177     return IC.replaceInstUsesWith(II, PTest);
1178   }
1179 
1180   return std::nullopt;
1181 }
1182 
1183 template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
1184 static std::optional<Instruction *>
1185 instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
1186                                   bool MergeIntoAddendOp) {
1187   Value *P = II.getOperand(0);
1188   Value *MulOp0, *MulOp1, *AddendOp, *Mul;
1189   if (MergeIntoAddendOp) {
1190     AddendOp = II.getOperand(1);
1191     Mul = II.getOperand(2);
1192   } else {
1193     AddendOp = II.getOperand(2);
1194     Mul = II.getOperand(1);
1195   }
1196 
1197   if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
1198                                       m_Value(MulOp1))))
1199     return std::nullopt;
1200 
1201   if (!Mul->hasOneUse())
1202     return std::nullopt;
1203 
1204   Instruction *FMFSource = nullptr;
1205   if (II.getType()->isFPOrFPVectorTy()) {
1206     llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
1207     // Stop the combine when the flags on the inputs differ in case dropping
1208     // flags would lead to us missing out on more beneficial optimizations.
1209     if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
1210       return std::nullopt;
1211     if (!FAddFlags.allowContract())
1212       return std::nullopt;
1213     FMFSource = &II;
1214   }
1215 
1216   CallInst *Res;
1217   if (MergeIntoAddendOp)
1218     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1219                                      {P, AddendOp, MulOp0, MulOp1}, FMFSource);
1220   else
1221     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1222                                      {P, MulOp0, MulOp1, AddendOp}, FMFSource);
1223 
1224   return IC.replaceInstUsesWith(II, Res);
1225 }
1226 
1227 static bool isAllActivePredicate(Value *Pred) {
1228   // Look through convert.from.svbool(convert.to.svbool(...) chain.
1229   Value *UncastedPred;
1230   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1231                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1232                           m_Value(UncastedPred)))))
1233     // If the predicate has the same or less lanes than the uncasted
1234     // predicate then we know the casting has no effect.
1235     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1236         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1237       Pred = UncastedPred;
1238 
1239   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1240                          m_ConstantInt<AArch64SVEPredPattern::all>()));
1241 }
1242 
1243 static std::optional<Instruction *>
1244 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1245   Value *Pred = II.getOperand(0);
1246   Value *PtrOp = II.getOperand(1);
1247   Type *VecTy = II.getType();
1248 
1249   if (isAllActivePredicate(Pred)) {
1250     LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
1251     Load->copyMetadata(II);
1252     return IC.replaceInstUsesWith(II, Load);
1253   }
1254 
1255   CallInst *MaskedLoad =
1256       IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
1257                                   Pred, ConstantAggregateZero::get(VecTy));
1258   MaskedLoad->copyMetadata(II);
1259   return IC.replaceInstUsesWith(II, MaskedLoad);
1260 }
1261 
1262 static std::optional<Instruction *>
1263 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1264   Value *VecOp = II.getOperand(0);
1265   Value *Pred = II.getOperand(1);
1266   Value *PtrOp = II.getOperand(2);
1267 
1268   if (isAllActivePredicate(Pred)) {
1269     StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
1270     Store->copyMetadata(II);
1271     return IC.eraseInstFromFunction(II);
1272   }
1273 
1274   CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
1275       VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
1276   MaskedStore->copyMetadata(II);
1277   return IC.eraseInstFromFunction(II);
1278 }
1279 
1280 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1281   switch (Intrinsic) {
1282   case Intrinsic::aarch64_sve_fmul_u:
1283     return Instruction::BinaryOps::FMul;
1284   case Intrinsic::aarch64_sve_fadd_u:
1285     return Instruction::BinaryOps::FAdd;
1286   case Intrinsic::aarch64_sve_fsub_u:
1287     return Instruction::BinaryOps::FSub;
1288   default:
1289     return Instruction::BinaryOpsEnd;
1290   }
1291 }
1292 
1293 static std::optional<Instruction *>
1294 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
1295   // Bail due to missing support for ISD::STRICT_ scalable vector operations.
1296   if (II.isStrictFP())
1297     return std::nullopt;
1298 
1299   auto *OpPredicate = II.getOperand(0);
1300   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1301   if (BinOpCode == Instruction::BinaryOpsEnd ||
1302       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1303                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1304     return std::nullopt;
1305   IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
1306   IC.Builder.setFastMathFlags(II.getFastMathFlags());
1307   auto BinOp =
1308       IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1309   return IC.replaceInstUsesWith(II, BinOp);
1310 }
1311 
1312 // Canonicalise operations that take an all active predicate (e.g. sve.add ->
1313 // sve.add_u).
1314 static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
1315                                                             Intrinsic::ID IID) {
1316   auto *OpPredicate = II.getOperand(0);
1317   if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1318                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1319     return std::nullopt;
1320 
1321   auto *Mod = II.getModule();
1322   auto *NewDecl = Intrinsic::getDeclaration(Mod, IID, {II.getType()});
1323   II.setCalledFunction(NewDecl);
1324 
1325   return &II;
1326 }
1327 
1328 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
1329                                                             IntrinsicInst &II) {
1330   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_add_u))
1331     return II_U;
1332   if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1333                                                    Intrinsic::aarch64_sve_mla>(
1334           IC, II, true))
1335     return MLA;
1336   if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1337                                                    Intrinsic::aarch64_sve_mad>(
1338           IC, II, false))
1339     return MAD;
1340   return std::nullopt;
1341 }
1342 
1343 static std::optional<Instruction *>
1344 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
1345   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fadd_u))
1346     return II_U;
1347   if (auto FMLA =
1348           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1349                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1350                                                                          true))
1351     return FMLA;
1352   if (auto FMAD =
1353           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1354                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1355                                                                          false))
1356     return FMAD;
1357   if (auto FMLA =
1358           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1359                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1360                                                                          true))
1361     return FMLA;
1362   return std::nullopt;
1363 }
1364 
1365 static std::optional<Instruction *>
1366 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
1367   if (auto FMLA =
1368           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1369                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1370                                                                          true))
1371     return FMLA;
1372   if (auto FMAD =
1373           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1374                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1375                                                                          false))
1376     return FMAD;
1377   if (auto FMLA_U =
1378           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1379                                             Intrinsic::aarch64_sve_fmla_u>(
1380               IC, II, true))
1381     return FMLA_U;
1382   return instCombineSVEVectorBinOp(IC, II);
1383 }
1384 
1385 static std::optional<Instruction *>
1386 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
1387   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fsub_u))
1388     return II_U;
1389   if (auto FMLS =
1390           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1391                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1392                                                                          true))
1393     return FMLS;
1394   if (auto FMSB =
1395           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1396                                             Intrinsic::aarch64_sve_fnmsb>(
1397               IC, II, false))
1398     return FMSB;
1399   if (auto FMLS =
1400           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1401                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1402                                                                          true))
1403     return FMLS;
1404   return std::nullopt;
1405 }
1406 
1407 static std::optional<Instruction *>
1408 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
1409   if (auto FMLS =
1410           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1411                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1412                                                                          true))
1413     return FMLS;
1414   if (auto FMSB =
1415           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1416                                             Intrinsic::aarch64_sve_fnmsb>(
1417               IC, II, false))
1418     return FMSB;
1419   if (auto FMLS_U =
1420           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1421                                             Intrinsic::aarch64_sve_fmls_u>(
1422               IC, II, true))
1423     return FMLS_U;
1424   return instCombineSVEVectorBinOp(IC, II);
1425 }
1426 
1427 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
1428                                                             IntrinsicInst &II) {
1429   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sub_u))
1430     return II_U;
1431   if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1432                                                    Intrinsic::aarch64_sve_mls>(
1433           IC, II, true))
1434     return MLS;
1435   return std::nullopt;
1436 }
1437 
1438 static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
1439                                                             IntrinsicInst &II,
1440                                                             Intrinsic::ID IID) {
1441   auto *OpPredicate = II.getOperand(0);
1442   auto *OpMultiplicand = II.getOperand(1);
1443   auto *OpMultiplier = II.getOperand(2);
1444 
1445   // Canonicalise a non _u intrinsic only.
1446   if (II.getIntrinsicID() != IID)
1447     if (auto II_U = instCombineSVEAllActive(II, IID))
1448       return II_U;
1449 
1450   // Return true if a given instruction is a unit splat value, false otherwise.
1451   auto IsUnitSplat = [](auto *I) {
1452     auto *SplatValue = getSplatValue(I);
1453     if (!SplatValue)
1454       return false;
1455     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1456   };
1457 
1458   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
1459   // with a unit splat value, false otherwise.
1460   auto IsUnitDup = [](auto *I) {
1461     auto *IntrI = dyn_cast<IntrinsicInst>(I);
1462     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
1463       return false;
1464 
1465     auto *SplatValue = IntrI->getOperand(2);
1466     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1467   };
1468 
1469   if (IsUnitSplat(OpMultiplier)) {
1470     // [f]mul pg %n, (dupx 1) => %n
1471     OpMultiplicand->takeName(&II);
1472     return IC.replaceInstUsesWith(II, OpMultiplicand);
1473   } else if (IsUnitDup(OpMultiplier)) {
1474     // [f]mul pg %n, (dup pg 1) => %n
1475     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
1476     auto *DupPg = DupInst->getOperand(1);
1477     // TODO: this is naive. The optimization is still valid if DupPg
1478     // 'encompasses' OpPredicate, not only if they're the same predicate.
1479     if (OpPredicate == DupPg) {
1480       OpMultiplicand->takeName(&II);
1481       return IC.replaceInstUsesWith(II, OpMultiplicand);
1482     }
1483   }
1484 
1485   return instCombineSVEVectorBinOp(IC, II);
1486 }
1487 
1488 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1489                                                          IntrinsicInst &II) {
1490   Value *UnpackArg = II.getArgOperand(0);
1491   auto *RetTy = cast<ScalableVectorType>(II.getType());
1492   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1493                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1494 
1495   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1496   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1497   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1498     ScalarArg =
1499         IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1500     Value *NewVal =
1501         IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1502     NewVal->takeName(&II);
1503     return IC.replaceInstUsesWith(II, NewVal);
1504   }
1505 
1506   return std::nullopt;
1507 }
1508 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
1509                                                       IntrinsicInst &II) {
1510   auto *OpVal = II.getOperand(0);
1511   auto *OpIndices = II.getOperand(1);
1512   VectorType *VTy = cast<VectorType>(II.getType());
1513 
1514   // Check whether OpIndices is a constant splat value < minimal element count
1515   // of result.
1516   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
1517   if (!SplatValue ||
1518       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1519     return std::nullopt;
1520 
1521   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
1522   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
1523   auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
1524   auto *VectorSplat =
1525       IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
1526 
1527   VectorSplat->takeName(&II);
1528   return IC.replaceInstUsesWith(II, VectorSplat);
1529 }
1530 
1531 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1532                                                       IntrinsicInst &II) {
1533   // zip1(uzp1(A, B), uzp2(A, B)) --> A
1534   // zip2(uzp1(A, B), uzp2(A, B)) --> B
1535   Value *A, *B;
1536   if (match(II.getArgOperand(0),
1537             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1538       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1539                                      m_Specific(A), m_Specific(B))))
1540     return IC.replaceInstUsesWith(
1541         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1542 
1543   return std::nullopt;
1544 }
1545 
1546 static std::optional<Instruction *>
1547 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
1548   Value *Mask = II.getOperand(0);
1549   Value *BasePtr = II.getOperand(1);
1550   Value *Index = II.getOperand(2);
1551   Type *Ty = II.getType();
1552   Value *PassThru = ConstantAggregateZero::get(Ty);
1553 
1554   // Contiguous gather => masked load.
1555   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1556   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1557   Value *IndexBase;
1558   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1559                        m_Value(IndexBase), m_SpecificInt(1)))) {
1560     Align Alignment =
1561         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1562 
1563     Type *VecPtrTy = PointerType::getUnqual(Ty);
1564     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1565                                       BasePtr, IndexBase);
1566     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1567     CallInst *MaskedLoad =
1568         IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1569     MaskedLoad->takeName(&II);
1570     return IC.replaceInstUsesWith(II, MaskedLoad);
1571   }
1572 
1573   return std::nullopt;
1574 }
1575 
1576 static std::optional<Instruction *>
1577 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
1578   Value *Val = II.getOperand(0);
1579   Value *Mask = II.getOperand(1);
1580   Value *BasePtr = II.getOperand(2);
1581   Value *Index = II.getOperand(3);
1582   Type *Ty = Val->getType();
1583 
1584   // Contiguous scatter => masked store.
1585   // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1586   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1587   Value *IndexBase;
1588   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1589                        m_Value(IndexBase), m_SpecificInt(1)))) {
1590     Align Alignment =
1591         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1592 
1593     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1594                                       BasePtr, IndexBase);
1595     Type *VecPtrTy = PointerType::getUnqual(Ty);
1596     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1597 
1598     (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1599 
1600     return IC.eraseInstFromFunction(II);
1601   }
1602 
1603   return std::nullopt;
1604 }
1605 
1606 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1607                                                        IntrinsicInst &II) {
1608   Type *Int32Ty = IC.Builder.getInt32Ty();
1609   Value *Pred = II.getOperand(0);
1610   Value *Vec = II.getOperand(1);
1611   Value *DivVec = II.getOperand(2);
1612 
1613   Value *SplatValue = getSplatValue(DivVec);
1614   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1615   if (!SplatConstantInt)
1616     return std::nullopt;
1617   APInt Divisor = SplatConstantInt->getValue();
1618 
1619   if (Divisor.isPowerOf2()) {
1620     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1621     auto ASRD = IC.Builder.CreateIntrinsic(
1622         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1623     return IC.replaceInstUsesWith(II, ASRD);
1624   }
1625   if (Divisor.isNegatedPowerOf2()) {
1626     Divisor.negate();
1627     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1628     auto ASRD = IC.Builder.CreateIntrinsic(
1629         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1630     auto NEG = IC.Builder.CreateIntrinsic(
1631         Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
1632     return IC.replaceInstUsesWith(II, NEG);
1633   }
1634 
1635   return std::nullopt;
1636 }
1637 
1638 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
1639   size_t VecSize = Vec.size();
1640   if (VecSize == 1)
1641     return true;
1642   if (!isPowerOf2_64(VecSize))
1643     return false;
1644   size_t HalfVecSize = VecSize / 2;
1645 
1646   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
1647        RHS != Vec.end(); LHS++, RHS++) {
1648     if (*LHS != nullptr && *RHS != nullptr) {
1649       if (*LHS == *RHS)
1650         continue;
1651       else
1652         return false;
1653     }
1654     if (!AllowPoison)
1655       return false;
1656     if (*LHS == nullptr && *RHS != nullptr)
1657       *LHS = *RHS;
1658   }
1659 
1660   Vec.resize(HalfVecSize);
1661   SimplifyValuePattern(Vec, AllowPoison);
1662   return true;
1663 }
1664 
1665 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
1666 // to dupqlane(f64(C)) where C is A concatenated with B
1667 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
1668                                                            IntrinsicInst &II) {
1669   Value *CurrentInsertElt = nullptr, *Default = nullptr;
1670   if (!match(II.getOperand(0),
1671              m_Intrinsic<Intrinsic::vector_insert>(
1672                  m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
1673       !isa<FixedVectorType>(CurrentInsertElt->getType()))
1674     return std::nullopt;
1675   auto IIScalableTy = cast<ScalableVectorType>(II.getType());
1676 
1677   // Insert the scalars into a container ordered by InsertElement index
1678   SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
1679   while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
1680     auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
1681     Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
1682     CurrentInsertElt = InsertElt->getOperand(0);
1683   }
1684 
1685   bool AllowPoison =
1686       isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
1687   if (!SimplifyValuePattern(Elts, AllowPoison))
1688     return std::nullopt;
1689 
1690   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
1691   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
1692   for (size_t I = 0; I < Elts.size(); I++) {
1693     if (Elts[I] == nullptr)
1694       continue;
1695     InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
1696                                                     IC.Builder.getInt64(I));
1697   }
1698   if (InsertEltChain == nullptr)
1699     return std::nullopt;
1700 
1701   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
1702   // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
1703   // be bitcast to a type wide enough to fit the sequence, be splatted, and then
1704   // be narrowed back to the original type.
1705   unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
1706   unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
1707                                  IIScalableTy->getMinNumElements() /
1708                                  PatternWidth;
1709 
1710   IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
1711   auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
1712   auto *WideShuffleMaskTy =
1713       ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);
1714 
1715   auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0));
1716   auto InsertSubvector = IC.Builder.CreateInsertVector(
1717       II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
1718   auto WideBitcast =
1719       IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
1720   auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
1721   auto WideShuffle = IC.Builder.CreateShuffleVector(
1722       WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
1723   auto NarrowBitcast =
1724       IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
1725 
1726   return IC.replaceInstUsesWith(II, NarrowBitcast);
1727 }
1728 
1729 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
1730                                                         IntrinsicInst &II) {
1731   Value *A = II.getArgOperand(0);
1732   Value *B = II.getArgOperand(1);
1733   if (A == B)
1734     return IC.replaceInstUsesWith(II, A);
1735 
1736   return std::nullopt;
1737 }
1738 
1739 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
1740                                                         IntrinsicInst &II) {
1741   Value *Pred = II.getOperand(0);
1742   Value *Vec = II.getOperand(1);
1743   Value *Shift = II.getOperand(2);
1744 
1745   // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
1746   Value *AbsPred, *MergedValue;
1747   if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
1748                       m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
1749       !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
1750                       m_Value(MergedValue), m_Value(AbsPred), m_Value())))
1751 
1752     return std::nullopt;
1753 
1754   // Transform is valid if any of the following are true:
1755   // * The ABS merge value is an undef or non-negative
1756   // * The ABS predicate is all active
1757   // * The ABS predicate and the SRSHL predicates are the same
1758   if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
1759       AbsPred != Pred && !isAllActivePredicate(AbsPred))
1760     return std::nullopt;
1761 
1762   // Only valid when the shift amount is non-negative, otherwise the rounding
1763   // behaviour of SRSHL cannot be ignored.
1764   if (!match(Shift, m_NonNegative()))
1765     return std::nullopt;
1766 
1767   auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
1768                                         {II.getType()}, {Pred, Vec, Shift});
1769 
1770   return IC.replaceInstUsesWith(II, LSL);
1771 }
1772 
1773 std::optional<Instruction *>
1774 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
1775                                      IntrinsicInst &II) const {
1776   Intrinsic::ID IID = II.getIntrinsicID();
1777   switch (IID) {
1778   default:
1779     break;
1780   case Intrinsic::aarch64_neon_fmaxnm:
1781   case Intrinsic::aarch64_neon_fminnm:
1782     return instCombineMaxMinNM(IC, II);
1783   case Intrinsic::aarch64_sve_convert_from_svbool:
1784     return instCombineConvertFromSVBool(IC, II);
1785   case Intrinsic::aarch64_sve_dup:
1786     return instCombineSVEDup(IC, II);
1787   case Intrinsic::aarch64_sve_dup_x:
1788     return instCombineSVEDupX(IC, II);
1789   case Intrinsic::aarch64_sve_cmpne:
1790   case Intrinsic::aarch64_sve_cmpne_wide:
1791     return instCombineSVECmpNE(IC, II);
1792   case Intrinsic::aarch64_sve_rdffr:
1793     return instCombineRDFFR(IC, II);
1794   case Intrinsic::aarch64_sve_lasta:
1795   case Intrinsic::aarch64_sve_lastb:
1796     return instCombineSVELast(IC, II);
1797   case Intrinsic::aarch64_sve_clasta_n:
1798   case Intrinsic::aarch64_sve_clastb_n:
1799     return instCombineSVECondLast(IC, II);
1800   case Intrinsic::aarch64_sve_cntd:
1801     return instCombineSVECntElts(IC, II, 2);
1802   case Intrinsic::aarch64_sve_cntw:
1803     return instCombineSVECntElts(IC, II, 4);
1804   case Intrinsic::aarch64_sve_cnth:
1805     return instCombineSVECntElts(IC, II, 8);
1806   case Intrinsic::aarch64_sve_cntb:
1807     return instCombineSVECntElts(IC, II, 16);
1808   case Intrinsic::aarch64_sve_ptest_any:
1809   case Intrinsic::aarch64_sve_ptest_first:
1810   case Intrinsic::aarch64_sve_ptest_last:
1811     return instCombineSVEPTest(IC, II);
1812   case Intrinsic::aarch64_sve_fabd:
1813     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fabd_u);
1814   case Intrinsic::aarch64_sve_fadd:
1815     return instCombineSVEVectorFAdd(IC, II);
1816   case Intrinsic::aarch64_sve_fadd_u:
1817     return instCombineSVEVectorFAddU(IC, II);
1818   case Intrinsic::aarch64_sve_fdiv:
1819     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fdiv_u);
1820   case Intrinsic::aarch64_sve_fmax:
1821     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmax_u);
1822   case Intrinsic::aarch64_sve_fmaxnm:
1823     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmaxnm_u);
1824   case Intrinsic::aarch64_sve_fmin:
1825     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmin_u);
1826   case Intrinsic::aarch64_sve_fminnm:
1827     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fminnm_u);
1828   case Intrinsic::aarch64_sve_fmla:
1829     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmla_u);
1830   case Intrinsic::aarch64_sve_fmls:
1831     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmls_u);
1832   case Intrinsic::aarch64_sve_fmul:
1833   case Intrinsic::aarch64_sve_fmul_u:
1834     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
1835   case Intrinsic::aarch64_sve_fmulx:
1836     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmulx_u);
1837   case Intrinsic::aarch64_sve_fnmla:
1838     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmla_u);
1839   case Intrinsic::aarch64_sve_fnmls:
1840     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmls_u);
1841   case Intrinsic::aarch64_sve_fsub:
1842     return instCombineSVEVectorFSub(IC, II);
1843   case Intrinsic::aarch64_sve_fsub_u:
1844     return instCombineSVEVectorFSubU(IC, II);
1845   case Intrinsic::aarch64_sve_add:
1846     return instCombineSVEVectorAdd(IC, II);
1847   case Intrinsic::aarch64_sve_add_u:
1848     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1849                                              Intrinsic::aarch64_sve_mla_u>(
1850         IC, II, true);
1851   case Intrinsic::aarch64_sve_mla:
1852     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mla_u);
1853   case Intrinsic::aarch64_sve_mls:
1854     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mls_u);
1855   case Intrinsic::aarch64_sve_mul:
1856   case Intrinsic::aarch64_sve_mul_u:
1857     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
1858   case Intrinsic::aarch64_sve_sabd:
1859     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sabd_u);
1860   case Intrinsic::aarch64_sve_smax:
1861     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smax_u);
1862   case Intrinsic::aarch64_sve_smin:
1863     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smin_u);
1864   case Intrinsic::aarch64_sve_smulh:
1865     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smulh_u);
1866   case Intrinsic::aarch64_sve_sub:
1867     return instCombineSVEVectorSub(IC, II);
1868   case Intrinsic::aarch64_sve_sub_u:
1869     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1870                                              Intrinsic::aarch64_sve_mls_u>(
1871         IC, II, true);
1872   case Intrinsic::aarch64_sve_uabd:
1873     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uabd_u);
1874   case Intrinsic::aarch64_sve_umax:
1875     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umax_u);
1876   case Intrinsic::aarch64_sve_umin:
1877     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umin_u);
1878   case Intrinsic::aarch64_sve_umulh:
1879     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umulh_u);
1880   case Intrinsic::aarch64_sve_asr:
1881     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_asr_u);
1882   case Intrinsic::aarch64_sve_lsl:
1883     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsl_u);
1884   case Intrinsic::aarch64_sve_lsr:
1885     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsr_u);
1886   case Intrinsic::aarch64_sve_and:
1887     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_and_u);
1888   case Intrinsic::aarch64_sve_bic:
1889     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_bic_u);
1890   case Intrinsic::aarch64_sve_eor:
1891     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_eor_u);
1892   case Intrinsic::aarch64_sve_orr:
1893     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_orr_u);
1894   case Intrinsic::aarch64_sve_sqsub:
1895     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sqsub_u);
1896   case Intrinsic::aarch64_sve_uqsub:
1897     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uqsub_u);
1898   case Intrinsic::aarch64_sve_tbl:
1899     return instCombineSVETBL(IC, II);
1900   case Intrinsic::aarch64_sve_uunpkhi:
1901   case Intrinsic::aarch64_sve_uunpklo:
1902   case Intrinsic::aarch64_sve_sunpkhi:
1903   case Intrinsic::aarch64_sve_sunpklo:
1904     return instCombineSVEUnpack(IC, II);
1905   case Intrinsic::aarch64_sve_zip1:
1906   case Intrinsic::aarch64_sve_zip2:
1907     return instCombineSVEZip(IC, II);
1908   case Intrinsic::aarch64_sve_ld1_gather_index:
1909     return instCombineLD1GatherIndex(IC, II);
1910   case Intrinsic::aarch64_sve_st1_scatter_index:
1911     return instCombineST1ScatterIndex(IC, II);
1912   case Intrinsic::aarch64_sve_ld1:
1913     return instCombineSVELD1(IC, II, DL);
1914   case Intrinsic::aarch64_sve_st1:
1915     return instCombineSVEST1(IC, II, DL);
1916   case Intrinsic::aarch64_sve_sdiv:
1917     return instCombineSVESDIV(IC, II);
1918   case Intrinsic::aarch64_sve_sel:
1919     return instCombineSVESel(IC, II);
1920   case Intrinsic::aarch64_sve_srshl:
1921     return instCombineSVESrshl(IC, II);
1922   case Intrinsic::aarch64_sve_dupq_lane:
1923     return instCombineSVEDupqLane(IC, II);
1924   }
1925 
1926   return std::nullopt;
1927 }
1928 
1929 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
1930     InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
1931     APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
1932     std::function<void(Instruction *, unsigned, APInt, APInt &)>
1933         SimplifyAndSetOp) const {
1934   switch (II.getIntrinsicID()) {
1935   default:
1936     break;
1937   case Intrinsic::aarch64_neon_fcvtxn:
1938   case Intrinsic::aarch64_neon_rshrn:
1939   case Intrinsic::aarch64_neon_sqrshrn:
1940   case Intrinsic::aarch64_neon_sqrshrun:
1941   case Intrinsic::aarch64_neon_sqshrn:
1942   case Intrinsic::aarch64_neon_sqshrun:
1943   case Intrinsic::aarch64_neon_sqxtn:
1944   case Intrinsic::aarch64_neon_sqxtun:
1945   case Intrinsic::aarch64_neon_uqrshrn:
1946   case Intrinsic::aarch64_neon_uqshrn:
1947   case Intrinsic::aarch64_neon_uqxtn:
1948     SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
1949     break;
1950   }
1951 
1952   return std::nullopt;
1953 }
1954 
1955 TypeSize
1956 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
1957   switch (K) {
1958   case TargetTransformInfo::RGK_Scalar:
1959     return TypeSize::getFixed(64);
1960   case TargetTransformInfo::RGK_FixedWidthVector:
1961     if (!ST->isNeonAvailable() && !EnableFixedwidthAutovecInStreamingMode)
1962       return TypeSize::getFixed(0);
1963 
1964     if (ST->hasSVE())
1965       return TypeSize::getFixed(
1966           std::max(ST->getMinSVEVectorSizeInBits(), 128u));
1967 
1968     return TypeSize::getFixed(ST->hasNEON() ? 128 : 0);
1969   case TargetTransformInfo::RGK_ScalableVector:
1970     if ((ST->isStreaming() || ST->isStreamingCompatible()) &&
1971         !EnableScalableAutovecInStreamingMode)
1972       return TypeSize::getScalable(0);
1973 
1974     return TypeSize::getScalable(ST->hasSVE() ? 128 : 0);
1975   }
1976   llvm_unreachable("Unsupported register kind");
1977 }
1978 
1979 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
1980                                            ArrayRef<const Value *> Args,
1981                                            Type *SrcOverrideTy) {
1982   // A helper that returns a vector type from the given type. The number of
1983   // elements in type Ty determines the vector width.
1984   auto toVectorTy = [&](Type *ArgTy) {
1985     return VectorType::get(ArgTy->getScalarType(),
1986                            cast<VectorType>(DstTy)->getElementCount());
1987   };
1988 
1989   // Exit early if DstTy is not a vector type whose elements are one of [i16,
1990   // i32, i64]. SVE doesn't generally have the same set of instructions to
1991   // perform an extend with the add/sub/mul. There are SMULLB style
1992   // instructions, but they operate on top/bottom, requiring some sort of lane
1993   // interleaving to be used with zext/sext.
1994   unsigned DstEltSize = DstTy->getScalarSizeInBits();
1995   if (!useNeonVector(DstTy) || Args.size() != 2 ||
1996       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
1997     return false;
1998 
1999   // Determine if the operation has a widening variant. We consider both the
2000   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2001   // instructions.
2002   //
2003   // TODO: Add additional widening operations (e.g., shl, etc.) once we
2004   //       verify that their extending operands are eliminated during code
2005   //       generation.
2006   Type *SrcTy = SrcOverrideTy;
2007   switch (Opcode) {
2008   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2009   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2010     // The second operand needs to be an extend
2011     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
2012       if (!SrcTy)
2013         SrcTy =
2014             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2015     } else
2016       return false;
2017     break;
2018   case Instruction::Mul: { // SMULL(2), UMULL(2)
2019     // Both operands need to be extends of the same type.
2020     if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2021         (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2022       if (!SrcTy)
2023         SrcTy =
2024             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2025     } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2026       // If one of the operands is a Zext and the other has enough zero bits to
2027       // be treated as unsigned, we can still general a umull, meaning the zext
2028       // is free.
2029       KnownBits Known =
2030           computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2031       if (Args[0]->getType()->getScalarSizeInBits() -
2032               Known.Zero.countLeadingOnes() >
2033           DstTy->getScalarSizeInBits() / 2)
2034         return false;
2035       if (!SrcTy)
2036         SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2037                                            DstTy->getScalarSizeInBits() / 2));
2038     } else
2039       return false;
2040     break;
2041   }
2042   default:
2043     return false;
2044   }
2045 
2046   // Legalize the destination type and ensure it can be used in a widening
2047   // operation.
2048   auto DstTyL = getTypeLegalizationCost(DstTy);
2049   if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
2050     return false;
2051 
2052   // Legalize the source type and ensure it can be used in a widening
2053   // operation.
2054   assert(SrcTy && "Expected some SrcTy");
2055   auto SrcTyL = getTypeLegalizationCost(SrcTy);
2056   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
2057   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
2058     return false;
2059 
2060   // Get the total number of vector elements in the legalized types.
2061   InstructionCost NumDstEls =
2062       DstTyL.first * DstTyL.second.getVectorMinNumElements();
2063   InstructionCost NumSrcEls =
2064       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
2065 
2066   // Return true if the legalized types have the same number of vector elements
2067   // and the destination element type size is twice that of the source type.
2068   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
2069 }
2070 
2071 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2072                                                  Type *Src,
2073                                                  TTI::CastContextHint CCH,
2074                                                  TTI::TargetCostKind CostKind,
2075                                                  const Instruction *I) {
2076   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2077   assert(ISD && "Invalid opcode");
2078   // If the cast is observable, and it is used by a widening instruction (e.g.,
2079   // uaddl, saddw, etc.), it may be free.
2080   if (I && I->hasOneUser()) {
2081     auto *SingleUser = cast<Instruction>(*I->user_begin());
2082     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
2083     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
2084       // For adds only count the second operand as free if both operands are
2085       // extends but not the same operation. (i.e both operands are not free in
2086       // add(sext, zext)).
2087       if (SingleUser->getOpcode() == Instruction::Add) {
2088         if (I == SingleUser->getOperand(1) ||
2089             (isa<CastInst>(SingleUser->getOperand(1)) &&
2090              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
2091           return 0;
2092       } else // Others are free so long as isWideningInstruction returned true.
2093         return 0;
2094     }
2095   }
2096 
2097   // TODO: Allow non-throughput costs that aren't binary.
2098   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
2099     if (CostKind != TTI::TCK_RecipThroughput)
2100       return Cost == 0 ? 0 : 1;
2101     return Cost;
2102   };
2103 
2104   EVT SrcTy = TLI->getValueType(DL, Src);
2105   EVT DstTy = TLI->getValueType(DL, Dst);
2106 
2107   if (!SrcTy.isSimple() || !DstTy.isSimple())
2108     return AdjustCost(
2109         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2110 
2111   static const TypeConversionCostTblEntry
2112   ConversionTbl[] = {
2113     { ISD::TRUNCATE, MVT::v2i8,   MVT::v2i64,  1},  // xtn
2114     { ISD::TRUNCATE, MVT::v2i16,  MVT::v2i64,  1},  // xtn
2115     { ISD::TRUNCATE, MVT::v2i32,  MVT::v2i64,  1},  // xtn
2116     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i32,  1},  // xtn
2117     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i64,  3},  // 2 xtn + 1 uzp1
2118     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i32,  1},  // xtn
2119     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i64,  2},  // 1 uzp1 + 1 xtn
2120     { ISD::TRUNCATE, MVT::v4i32,  MVT::v4i64,  1},  // 1 uzp1
2121     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i16,  1},  // 1 xtn
2122     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i32,  2},  // 1 uzp1 + 1 xtn
2123     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i64,  4},  // 3 x uzp1 + xtn
2124     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i32,  1},  // 1 uzp1
2125     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i64,  3},  // 3 x uzp1
2126     { ISD::TRUNCATE, MVT::v8i32,  MVT::v8i64,  2},  // 2 x uzp1
2127     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i16, 1},  // uzp1
2128     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i32, 3},  // (2 + 1) x uzp1
2129     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i64, 7},  // (4 + 2 + 1) x uzp1
2130     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2},  // 2 x uzp1
2131     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6},  // (4 + 2) x uzp1
2132     { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4},  // 4 x uzp1
2133 
2134     // Truncations on nxvmiN
2135     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
2136     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
2137     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
2138     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
2139     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
2140     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
2141     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
2142     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
2143     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
2144     { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
2145     { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
2146     { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
2147     { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
2148     { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
2149     { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
2150     { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
2151 
2152     // The number of shll instructions for the extension.
2153     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2154     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2155     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2156     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2157     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2158     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2159     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2160     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2161     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2162     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2163     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2164     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2165     { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2166     { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2167     { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2168     { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2169 
2170     // LowerVectorINT_TO_FP:
2171     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2172     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2173     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2174     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2175     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2176     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2177 
2178     // Complex: to v2f32
2179     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2180     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2181     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2182     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2183     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2184     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2185 
2186     // Complex: to v4f32
2187     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
2188     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2189     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
2190     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2191 
2192     // Complex: to v8f32
2193     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2194     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2195     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2196     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2197 
2198     // Complex: to v16f32
2199     { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2200     { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2201 
2202     // Complex: to v2f64
2203     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2204     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2205     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2206     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2207     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2208     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2209 
2210     // Complex: to v4f64
2211     { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2212     { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2213 
2214     // LowerVectorFP_TO_INT
2215     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
2216     { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
2217     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
2218     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
2219     { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
2220     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
2221 
2222     // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
2223     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
2224     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
2225     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
2226     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
2227     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
2228     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },
2229 
2230     // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
2231     { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
2232     { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
2233     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
2234     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
2235 
2236     // Complex, from nxv2f32.
2237     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2238     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2239     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2240     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2241     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2242     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2243     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2244     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2245 
2246     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
2247     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
2248     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
2249     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
2250     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
2251     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
2252     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
2253 
2254     // Complex, from nxv2f64.
2255     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2256     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2257     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2258     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2259     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2260     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2261     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2262     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2263 
2264     // Complex, from nxv4f32.
2265     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2266     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2267     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2268     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2269     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2270     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2271     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2272     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2273 
2274     // Complex, from nxv8f64. Illegal -> illegal conversions not required.
2275     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2276     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2277     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2278     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2279 
2280     // Complex, from nxv4f64. Illegal -> illegal conversions not required.
2281     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2282     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2283     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2284     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2285     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2286     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2287 
2288     // Complex, from nxv8f32. Illegal -> illegal conversions not required.
2289     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2290     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2291     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2292     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2293 
2294     // Complex, from nxv8f16.
2295     { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2296     { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2297     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2298     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2299     { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2300     { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2301     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2302     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2303 
2304     // Complex, from nxv4f16.
2305     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2306     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2307     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2308     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2309     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2310     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2311     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2312     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2313 
2314     // Complex, from nxv2f16.
2315     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2316     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2317     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2318     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2319     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2320     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2321     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2322     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2323 
2324     // Truncate from nxvmf32 to nxvmf16.
2325     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
2326     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
2327     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
2328 
2329     // Truncate from nxvmf64 to nxvmf16.
2330     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
2331     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
2332     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
2333 
2334     // Truncate from nxvmf64 to nxvmf32.
2335     { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
2336     { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
2337     { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
2338 
2339     // Extend from nxvmf16 to nxvmf32.
2340     { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
2341     { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
2342     { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
2343 
2344     // Extend from nxvmf16 to nxvmf64.
2345     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
2346     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
2347     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
2348 
2349     // Extend from nxvmf32 to nxvmf64.
2350     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
2351     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
2352     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
2353 
    // Bitcasts from integer to float (rows are {ISD, Dst, Src, Cost}).
2355     { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 },
2356     { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 },
2357     { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 },
2358 
    // Bitcasts from float to integer (rows are {ISD, Dst, Src, Cost}).
2360     { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
2361     { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
2362     { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
2363 
2364     // Add cost for extending to illegal -too wide- scalable vectors.
2365     // zero/sign extend are implemented by multiple unpack operations,
2366     // where each operation has a cost of 1.
2367     { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2368     { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2369     { ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2370     { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2371     { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2372     { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2373 
2374     { ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2375     { ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2376     { ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2377     { ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2378     { ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2379     { ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2380   };
2381 
  // A fixed-length conversion that is lowered onto SVE registers is costed as
  // the equivalent scalable conversion, multiplied by the legalization cost
  // (roughly the number of SVE registers) of the wider fixed-length type.
2385   EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
2386   if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
2387       SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
2388       ST->useSVEForFixedLengthVectors(WiderTy)) {
2389     std::pair<InstructionCost, MVT> LT =
2390         getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
2391     unsigned NumElements = AArch64::SVEBitsPerBlock /
2392                            LT.second.getVectorElementType().getSizeInBits();
2393     return AdjustCost(
2394         LT.first *
2395         getCastInstrCost(
2396             Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
2397             ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
2398             CostKind, I));
2399   }
2400 
2401   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
2402                                                  DstTy.getSimpleVT(),
2403                                                  SrcTy.getSimpleVT()))
2404     return AdjustCost(Entry->Cost);
2405 
2406   static const TypeConversionCostTblEntry FP16Tbl[] = {
2407       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
2408       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
2409       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
2410       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
2411       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
2412       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
2413       {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
2414       {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
2415       {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
2416       {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
2417       {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
2418       {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
2419       {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
2420       {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
2421       {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
2422       {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
2423       {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
2424       {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
2425       {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // ushll + ucvtf
2426       {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // sshll + scvtf
2427       {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
2428       {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
2429   };
2430 
2431   if (ST->hasFullFP16())
2432     if (const auto *Entry = ConvertCostTableLookup(
2433             FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
2434       return AdjustCost(Entry->Cost);
2435 
2436   // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
2437   // but we also want to include the TTI::CastContextHint::Masked case too.
2438   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
2439       CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() &&
2440       TLI->isTypeLegal(DstTy))
2441     CCH = TTI::CastContextHint::Normal;
2442 
2443   return AdjustCost(
2444       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2445 }
2446 
2447 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
2448                                                          Type *Dst,
2449                                                          VectorType *VecTy,
2450                                                          unsigned Index) {
2451 
2452   // Make sure we were given a valid extend opcode.
2453   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
2454          "Invalid opcode");
2455 
2456   // We are extending an element we extract from a vector, so the source type
2457   // of the extend is the element type of the vector.
2458   auto *Src = VecTy->getElementType();
2459 
2460   // Sign- and zero-extends are for integer types only.
2461   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
2462 
2463   // Get the cost for the extract. We compute the cost (if any) for the extend
2464   // below.
2465   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2466   InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
2467                                             CostKind, Index, nullptr, nullptr);
2468 
2469   // Legalize the types.
2470   auto VecLT = getTypeLegalizationCost(VecTy);
2471   auto DstVT = TLI->getValueType(DL, Dst);
2472   auto SrcVT = TLI->getValueType(DL, Src);
2473 
2474   // If the resulting type is still a vector and the destination type is legal,
2475   // we may get the extension for free. If not, get the default cost for the
2476   // extend.
2477   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
2478     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2479                                    CostKind);
2480 
2481   // The destination type should be larger than the element type. If not, get
2482   // the default cost for the extend.
2483   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
2484     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2485                                    CostKind);
2486 
2487   switch (Opcode) {
2488   default:
2489     llvm_unreachable("Opcode should be either SExt or ZExt");
2490 
2491   // For sign-extends, we only need a smov, which performs the extension
2492   // automatically.
2493   case Instruction::SExt:
2494     return Cost;
2495 
2496   // For zero-extends, the extend is performed automatically by a umov unless
2497   // the destination type is i64 and the element type is i8 or i16.
2498   case Instruction::ZExt:
2499     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
2500       return Cost;
2501   }
2502 
2503   // If we are unable to perform the extend for free, get the default cost.
2504   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2505                                  CostKind);
2506 }
2507 
2508 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
2509                                                TTI::TargetCostKind CostKind,
2510                                                const Instruction *I) {
2511   if (CostKind != TTI::TCK_RecipThroughput)
2512     return Opcode == Instruction::PHI ? 0 : 1;
2513   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
2514   // Branches are assumed to be predicted.
2515   return 0;
2516 }
2517 
2518 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
2519                                                          Type *Val,
2520                                                          unsigned Index,
2521                                                          bool HasRealUse) {
2522   assert(Val->isVectorTy() && "This must be a vector type");
2523 
2524   if (Index != -1U) {
2525     // Legalize the type.
2526     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
2527 
2528     // This type is legalized to a scalar type.
2529     if (!LT.second.isVector())
2530       return 0;
2531 
2532     // The type may be split. For fixed-width vectors we can normalize the
2533     // index to the new type.
2534     if (LT.second.isFixedLengthVector()) {
2535       unsigned Width = LT.second.getVectorNumElements();
2536       Index = Index % Width;
2537     }
2538 
2539     // The element at index zero is already inside the vector.
2540     // - For a physical (HasRealUse==true) insert-element or extract-element
2541     // instruction that extracts integers, an explicit FPR -> GPR move is
2542     // needed. So it has non-zero cost.
2543     // - For the rest of cases (virtual instruction or element type is float),
2544     // consider the instruction free.
2545     if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy()))
2546       return 0;
2547 
2548     // This is recognising a LD1 single-element structure to one lane of one
2549     // register instruction. I.e., if this is an `insertelement` instruction,
2550     // and its second operand is a load, then we will generate a LD1, which
2551     // are expensive instructions.
2552     if (I && dyn_cast<LoadInst>(I->getOperand(1)))
2553       return ST->getVectorInsertExtractBaseCost() + 1;
2554 
2555     // i1 inserts and extract will include an extra cset or cmp of the vector
2556     // value. Increase the cost by 1 to account.
2557     if (Val->getScalarSizeInBits() == 1)
2558       return ST->getVectorInsertExtractBaseCost() + 1;
2559 
2560     // FIXME:
2561     // If the extract-element and insert-element instructions could be
2562     // simplified away (e.g., could be combined into users by looking at use-def
2563     // context), they have no cost. This is not done in the first place for
2564     // compile-time considerations.
2565   }
2566 
2567   // All other insert/extracts cost this much.
2568   return ST->getVectorInsertExtractBaseCost();
2569 }
2570 
2571 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
2572                                                    TTI::TargetCostKind CostKind,
2573                                                    unsigned Index, Value *Op0,
2574                                                    Value *Op1) {
2575   bool HasRealUse =
2576       Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
2577   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
2578 }
2579 
2580 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
2581                                                    Type *Val,
2582                                                    TTI::TargetCostKind CostKind,
2583                                                    unsigned Index) {
2584   return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
2585 }
2586 
2587 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
2588     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
2589     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
2590     ArrayRef<const Value *> Args,
2591     const Instruction *CxtI) {
2592 
2593   // TODO: Handle more cost kinds.
2594   if (CostKind != TTI::TCK_RecipThroughput)
2595     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2596                                          Op2Info, Args, CxtI);
2597 
2598   // Legalize the type.
2599   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
2600   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2601 
2602   switch (ISD) {
2603   default:
2604     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2605                                          Op2Info);
2606   case ISD::SDIV:
2607     if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
2608       // On AArch64, scalar signed division by constants power-of-two are
2609       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
2610       // The OperandValue properties many not be same as that of previous
2611       // operation; conservatively assume OP_None.
2612       InstructionCost Cost = getArithmeticInstrCost(
2613           Instruction::Add, Ty, CostKind,
2614           Op1Info.getNoProps(), Op2Info.getNoProps());
2615       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
2616                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2617       Cost += getArithmeticInstrCost(
2618           Instruction::Select, Ty, CostKind,
2619           Op1Info.getNoProps(), Op2Info.getNoProps());
2620       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
2621                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2622       return Cost;
2623     }
2624     [[fallthrough]];
2625   case ISD::UDIV: {
2626     if (Op2Info.isConstant() && Op2Info.isUniform()) {
2627       auto VT = TLI->getValueType(DL, Ty);
2628       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
2629         // Vector signed division by constant are expanded to the
2630         // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
2631         // to MULHS + SUB + SRL + ADD + SRL.
2632         InstructionCost MulCost = getArithmeticInstrCost(
2633             Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2634         InstructionCost AddCost = getArithmeticInstrCost(
2635             Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2636         InstructionCost ShrCost = getArithmeticInstrCost(
2637             Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2638         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
2639       }
2640     }
2641 
2642     InstructionCost Cost = BaseT::getArithmeticInstrCost(
2643         Opcode, Ty, CostKind, Op1Info, Op2Info);
2644     if (Ty->isVectorTy()) {
2645       if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
2646         // SDIV/UDIV operations are lowered using SVE, then we can have less
2647         // costs.
2648         if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty)
2649                                                 ->getPrimitiveSizeInBits()
2650                                                 .getFixedValue() < 128) {
2651           EVT VT = TLI->getValueType(DL, Ty);
2652           static const CostTblEntry DivTbl[]{
2653               {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
2654               {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
2655               {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
2656               {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
2657               {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
2658               {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
2659 
2660           const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
2661           if (nullptr != Entry)
2662             return Entry->Cost;
2663         }
2664         // For 8/16-bit elements, the cost is higher because the type
2665         // requires promotion and possibly splitting:
2666         if (LT.second.getScalarType() == MVT::i8)
2667           Cost *= 8;
2668         else if (LT.second.getScalarType() == MVT::i16)
2669           Cost *= 4;
2670         return Cost;
2671       } else {
2672         // If one of the operands is a uniform constant then the cost for each
2673         // element is Cost for insertion, extraction and division.
2674         // Insertion cost = 2, Extraction Cost = 2, Division = cost for the
2675         // operation with scalar type
2676         if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
2677             (Op2Info.isConstant() && Op2Info.isUniform())) {
2678           if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
2679             InstructionCost DivCost = BaseT::getArithmeticInstrCost(
2680                 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
2681             return (4 + DivCost) * VTy->getNumElements();
2682           }
2683         }
2684         // On AArch64, without SVE, vector divisions are expanded
2685         // into scalar divisions of each pair of elements.
2686         Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
2687                                        CostKind, Op1Info, Op2Info);
2688         Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
2689                                        Op1Info, Op2Info);
2690       }
2691 
2692       // TODO: if one of the arguments is scalar, then it's not necessary to
2693       // double the cost of handling the vector elements.
2694       Cost += Cost;
2695     }
2696     return Cost;
2697   }
2698   case ISD::MUL:
2699     // When SVE is available, then we can lower the v2i64 operation using
2700     // the SVE mul instruction, which has a lower cost.
2701     if (LT.second == MVT::v2i64 && ST->hasSVE())
2702       return LT.first;
2703 
2704     // When SVE is not available, there is no MUL.2d instruction,
2705     // which means mul <2 x i64> is expensive as elements are extracted
2706     // from the vectors and the muls scalarized.
2707     // As getScalarizationOverhead is a bit too pessimistic, we
2708     // estimate the cost for a i64 vector directly here, which is:
2709     // - four 2-cost i64 extracts,
2710     // - two 2-cost i64 inserts, and
2711     // - two 1-cost muls.
2712     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
2713     // LT.first = 2 the cost is 28. If both operands are extensions it will not
2714     // need to scalarize so the cost can be cheaper (smull or umull).
2715     // so the cost can be cheaper (smull or umull).
2716     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2717       return LT.first;
2718     return LT.first * 14;
2719   case ISD::ADD:
2720   case ISD::XOR:
2721   case ISD::OR:
2722   case ISD::AND:
2723   case ISD::SRL:
2724   case ISD::SRA:
2725   case ISD::SHL:
2726     // These nodes are marked as 'custom' for combining purposes only.
2727     // We know that they are legal. See LowerAdd in ISelLowering.
2728     return LT.first;
2729 
2730   case ISD::FNEG:
2731   case ISD::FADD:
2732   case ISD::FSUB:
2733     // Increase the cost for half and bfloat types if not architecturally
2734     // supported.
2735     if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
2736         (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
2737       return 2 * LT.first;
2738     if (!Ty->getScalarType()->isFP128Ty())
2739       return LT.first;
2740     [[fallthrough]];
2741   case ISD::FMUL:
2742   case ISD::FDIV:
2743     // These nodes are marked as 'custom' just to lower them to SVE.
2744     // We know said lowering will incur no additional cost.
2745     if (!Ty->getScalarType()->isFP128Ty())
2746       return 2 * LT.first;
2747 
2748     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2749                                          Op2Info);
2750   }
2751 }
2752 
2753 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
2754                                                           ScalarEvolution *SE,
2755                                                           const SCEV *Ptr) {
2756   // Address computations in vectorized code with non-consecutive addresses will
2757   // likely result in more instructions compared to scalar code where the
2758   // computation can more often be merged into the index mode. The resulting
2759   // extra micro-ops can significantly decrease throughput.
2760   unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
2761   int MaxMergeDistance = 64;
2762 
2763   if (Ty->isVectorTy() && SE &&
2764       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
2765     return NumVectorInstToHideOverhead;
2766 
2767   // In many cases the address computation is not merged into the instruction
2768   // addressing mode.
2769   return 1;
2770 }
2771 
2772 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2773                                                    Type *CondTy,
2774                                                    CmpInst::Predicate VecPred,
2775                                                    TTI::TargetCostKind CostKind,
2776                                                    const Instruction *I) {
2777   // TODO: Handle other cost kinds.
2778   if (CostKind != TTI::TCK_RecipThroughput)
2779     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2780                                      I);
2781 
2782   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2783   // We don't lower some vector selects well that are wider than the register
2784   // width.
2785   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
2786     // We would need this many instructions to hide the scalarization happening.
2787     const int AmortizationCost = 20;
2788 
2789     // If VecPred is not set, check if we can get a predicate from the context
2790     // instruction, if its type matches the requested ValTy.
2791     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
2792       CmpInst::Predicate CurrentPred;
2793       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
2794                             m_Value())))
2795         VecPred = CurrentPred;
2796     }
2797     // Check if we have a compare/select chain that can be lowered using
2798     // a (F)CMxx & BFI pair.
2799     if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2800         VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2801         VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2802         VecPred == CmpInst::FCMP_UNE) {
2803       static const auto ValidMinMaxTys = {
2804           MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2805           MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2806       static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2807 
2808       auto LT = getTypeLegalizationCost(ValTy);
2809       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2810           (ST->hasFullFP16() &&
2811            any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
2812         return LT.first;
2813     }
2814 
2815     static const TypeConversionCostTblEntry
2816     VectorSelectTbl[] = {
2817       { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
2818       { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
2819       { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
2820       { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
2821       { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
2822       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
2823       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
2824       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
2825       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
2826       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
2827       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
2828     };
2829 
2830     EVT SelCondTy = TLI->getValueType(DL, CondTy);
2831     EVT SelValTy = TLI->getValueType(DL, ValTy);
2832     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
2833       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
2834                                                      SelCondTy.getSimpleVT(),
2835                                                      SelValTy.getSimpleVT()))
2836         return Entry->Cost;
2837     }
2838   }
2839 
2840   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
2841     auto LT = getTypeLegalizationCost(ValTy);
2842     // Cost v4f16 FCmp without FP16 support via converting to v4f32 and back.
2843     if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
2844       return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
2845   }
2846 
2847   // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
2848   // FIXME: This can apply to more conditions and add/sub if it can be shown to
2849   // be profitable.
2850   if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
2851       ICmpInst::isEquality(VecPred) &&
2852       TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
2853       match(I->getOperand(1), m_Zero()) &&
2854       match(I->getOperand(0), m_And(m_Value(), m_Value())))
2855     return 0;
2856 
2857   // The base case handles scalable vectors fine for now, since it treats the
2858   // cost as 1 * legalization cost.
2859   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
2860 }
2861 
2862 AArch64TTIImpl::TTI::MemCmpExpansionOptions
2863 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
2864   TTI::MemCmpExpansionOptions Options;
2865   if (ST->requiresStrictAlign()) {
2866     // TODO: Add cost modeling for strict align. Misaligned loads expand to
2867     // a bunch of instructions when strict align is enabled.
2868     return Options;
2869   }
2870   Options.AllowOverlappingLoads = true;
2871   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
2872   Options.NumLoadsPerBlock = Options.MaxNumLoads;
2873   // TODO: Though vector loads usually perform well on AArch64, in some targets
2874   // they may wake up the FP unit, which raises the power consumption.  Perhaps
2875   // they could be used with no holds barred (-O3).
2876   Options.LoadSizes = {8, 4, 2, 1};
2877   return Options;
2878 }
2879 
2880 bool AArch64TTIImpl::prefersVectorizedAddressing() const {
2881   return ST->hasSVE();
2882 }
2883 
2884 InstructionCost
2885 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
2886                                       Align Alignment, unsigned AddressSpace,
2887                                       TTI::TargetCostKind CostKind) {
2888   if (useNeonVector(Src))
2889     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
2890                                         CostKind);
2891   auto LT = getTypeLegalizationCost(Src);
2892   if (!LT.first.isValid())
2893     return InstructionCost::getInvalid();
2894 
2895   // The code-generator is currently not able to handle scalable vectors
2896   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2897   // it. This change will be removed when code-generation for these types is
2898   // sufficiently reliable.
2899   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
2900     return InstructionCost::getInvalid();
2901 
2902   return LT.first;
2903 }
2904 
2905 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
2906   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
2907 }
2908 
2909 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
2910     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
2911     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
2912   if (useNeonVector(DataTy))
2913     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
2914                                          Alignment, CostKind, I);
2915   auto *VT = cast<VectorType>(DataTy);
2916   auto LT = getTypeLegalizationCost(DataTy);
2917   if (!LT.first.isValid())
2918     return InstructionCost::getInvalid();
2919 
2920   // The code-generator is currently not able to handle scalable vectors
2921   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2922   // it. This change will be removed when code-generation for these types is
2923   // sufficiently reliable.
2924   if (cast<VectorType>(DataTy)->getElementCount() ==
2925       ElementCount::getScalable(1))
2926     return InstructionCost::getInvalid();
2927 
2928   ElementCount LegalVF = LT.second.getVectorElementCount();
2929   InstructionCost MemOpCost =
2930       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
2931                       {TTI::OK_AnyValue, TTI::OP_None}, I);
2932   // Add on an overhead cost for using gathers/scatters.
2933   // TODO: At the moment this is applied unilaterally for all CPUs, but at some
2934   // point we may want a per-CPU overhead.
2935   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
2936   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
2937 }
2938 
2939 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
2940   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
2941 }
2942 
2943 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
2944                                                 MaybeAlign Alignment,
2945                                                 unsigned AddressSpace,
2946                                                 TTI::TargetCostKind CostKind,
2947                                                 TTI::OperandValueInfo OpInfo,
2948                                                 const Instruction *I) {
2949   EVT VT = TLI->getValueType(DL, Ty, true);
2950   // Type legalization can't handle structs
2951   if (VT == MVT::Other)
2952     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
2953                                   CostKind);
2954 
2955   auto LT = getTypeLegalizationCost(Ty);
2956   if (!LT.first.isValid())
2957     return InstructionCost::getInvalid();
2958 
2959   // The code-generator is currently not able to handle scalable vectors
2960   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
2961   // it. This change will be removed when code-generation for these types is
2962   // sufficiently reliable.
2963   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
2964     if (VTy->getElementCount() == ElementCount::getScalable(1))
2965       return InstructionCost::getInvalid();
2966 
2967   // TODO: consider latency as well for TCK_SizeAndLatency.
2968   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
2969     return LT.first;
2970 
2971   if (CostKind != TTI::TCK_RecipThroughput)
2972     return 1;
2973 
2974   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
2975       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
2976     // Unaligned stores are extremely inefficient. We don't split all
2977     // unaligned 128-bit stores because the negative impact that has shown in
2978     // practice on inlined block copy code.
2979     // We make such stores expensive so that we will only vectorize if there
2980     // are 6 other instructions getting vectorized.
2981     const int AmortizationCost = 6;
2982 
2983     return LT.first * 2 * AmortizationCost;
2984   }
2985 
2986   // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
2987   if (Ty->isPtrOrPtrVectorTy())
2988     return LT.first;
2989 
2990   // Check truncating stores and extending loads.
2991   if (useNeonVector(Ty) &&
2992       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
2993     // v4i8 types are lowered to scalar a load/store and sshll/xtn.
2994     if (VT == MVT::v4i8)
2995       return 2;
2996     // Otherwise we need to scalarize.
2997     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
2998   }
2999 
3000   return LT.first;
3001 }
3002 
3003 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
3004     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
3005     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
3006     bool UseMaskForCond, bool UseMaskForGaps) {
3007   assert(Factor >= 2 && "Invalid interleave factor");
3008   auto *VecVTy = cast<VectorType>(VecTy);
3009 
3010   if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2))
3011     return InstructionCost::getInvalid();
3012 
3013   // Vectorization for masked interleaved accesses is only enabled for scalable
3014   // VF.
3015   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
3016     return InstructionCost::getInvalid();
3017 
3018   if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
3019     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
3020     auto *SubVecTy =
3021         VectorType::get(VecVTy->getElementType(),
3022                         VecVTy->getElementCount().divideCoefficientBy(Factor));
3023 
3024     // ldN/stN only support legal vector types of size 64 or 128 in bits.
3025     // Accesses having vector types that are a multiple of 128 bits can be
3026     // matched to more than one ldN/stN instruction.
3027     bool UseScalable;
3028     if (MinElts % Factor == 0 &&
3029         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
3030       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
3031   }
3032 
3033   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
3034                                            Alignment, AddressSpace, CostKind,
3035                                            UseMaskForCond, UseMaskForGaps);
3036 }
3037 
3038 InstructionCost
3039 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
3040   InstructionCost Cost = 0;
3041   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3042   for (auto *I : Tys) {
3043     if (!I->isVectorTy())
3044       continue;
3045     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
3046         128)
3047       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
3048               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
3049   }
3050   return Cost;
3051 }
3052 
3053 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
3054   return ST->getMaxInterleaveFactor();
3055 }
3056 
3057 // For Falkor, we want to avoid having too many strided loads in a loop since
3058 // that can exhaust the HW prefetcher resources.  We adjust the unroller
3059 // MaxCount preference below to attempt to ensure unrolling doesn't create too
3060 // many strided loads.
3061 static void
3062 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3063                               TargetTransformInfo::UnrollingPreferences &UP) {
3064   enum { MaxStridedLoads = 7 };
3065   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
3066     int StridedLoads = 0;
3067     // FIXME? We could make this more precise by looking at the CFG and
3068     // e.g. not counting loads in each side of an if-then-else diamond.
3069     for (const auto BB : L->blocks()) {
3070       for (auto &I : *BB) {
3071         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
3072         if (!LMemI)
3073           continue;
3074 
3075         Value *PtrValue = LMemI->getPointerOperand();
3076         if (L->isLoopInvariant(PtrValue))
3077           continue;
3078 
3079         const SCEV *LSCEV = SE.getSCEV(PtrValue);
3080         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
3081         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
3082           continue;
3083 
3084         // FIXME? We could take pairing of unrolled load copies into account
3085         // by looking at the AddRec, but we would probably have to limit this
3086         // to loops with no stores or other memory optimization barriers.
3087         ++StridedLoads;
3088         // We've seen enough strided loads that seeing more won't make a
3089         // difference.
3090         if (StridedLoads > MaxStridedLoads / 2)
3091           return StridedLoads;
3092       }
3093     }
3094     return StridedLoads;
3095   };
3096 
3097   int StridedLoads = countStridedLoads(L, SE);
3098   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
3099                     << " strided loads\n");
3100   // Pick the largest power of 2 unroll count that won't result in too many
3101   // strided loads.
3102   if (StridedLoads) {
3103     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
3104     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
3105                       << UP.MaxCount << '\n');
3106   }
3107 }
3108 
3109 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3110                                              TTI::UnrollingPreferences &UP,
3111                                              OptimizationRemarkEmitter *ORE) {
3112   // Enable partial unrolling and runtime unrolling.
3113   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
3114 
3115   UP.UpperBound = true;
3116 
3117   // For inner loop, it is more likely to be a hot one, and the runtime check
3118   // can be promoted out from LICM pass, so the overhead is less, let's try
3119   // a larger threshold to unroll more loops.
3120   if (L->getLoopDepth() > 1)
3121     UP.PartialThreshold *= 2;
3122 
3123   // Disable partial & runtime unrolling on -Os.
3124   UP.PartialOptSizeThreshold = 0;
3125 
3126   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
3127       EnableFalkorHWPFUnrollFix)
3128     getFalkorUnrollingPreferences(L, SE, UP);
3129 
3130   // Scan the loop: don't unroll loops with calls as this could prevent
3131   // inlining. Don't unroll vector loops either, as they don't benefit much from
3132   // unrolling.
3133   for (auto *BB : L->getBlocks()) {
3134     for (auto &I : *BB) {
3135       // Don't unroll vectorised loop.
3136       if (I.getType()->isVectorTy())
3137         return;
3138 
3139       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
3140         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
3141           if (!isLoweredToCall(F))
3142             continue;
3143         }
3144         return;
3145       }
3146     }
3147   }
3148 
3149   // Enable runtime unrolling for in-order models
3150   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
3151   // checking for that case, we can ensure that the default behaviour is
3152   // unchanged
3153   if (ST->getProcFamily() != AArch64Subtarget::Others &&
3154       !ST->getSchedModel().isOutOfOrder()) {
3155     UP.Runtime = true;
3156     UP.Partial = true;
3157     UP.UnrollRemainder = true;
3158     UP.DefaultUnrollRuntimeCount = 4;
3159 
3160     UP.UnrollAndJam = true;
3161     UP.UnrollAndJamInnerLoopThreshold = 60;
3162   }
3163 }
3164 
3165 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
3166                                            TTI::PeelingPreferences &PP) {
3167   BaseT::getPeelingPreferences(L, SE, PP);
3168 }
3169 
3170 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
3171                                                          Type *ExpectedType) {
3172   switch (Inst->getIntrinsicID()) {
3173   default:
3174     return nullptr;
3175   case Intrinsic::aarch64_neon_st2:
3176   case Intrinsic::aarch64_neon_st3:
3177   case Intrinsic::aarch64_neon_st4: {
3178     // Create a struct type
3179     StructType *ST = dyn_cast<StructType>(ExpectedType);
3180     if (!ST)
3181       return nullptr;
3182     unsigned NumElts = Inst->arg_size() - 1;
3183     if (ST->getNumElements() != NumElts)
3184       return nullptr;
3185     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3186       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
3187         return nullptr;
3188     }
3189     Value *Res = PoisonValue::get(ExpectedType);
3190     IRBuilder<> Builder(Inst);
3191     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3192       Value *L = Inst->getArgOperand(i);
3193       Res = Builder.CreateInsertValue(Res, L, i);
3194     }
3195     return Res;
3196   }
3197   case Intrinsic::aarch64_neon_ld2:
3198   case Intrinsic::aarch64_neon_ld3:
3199   case Intrinsic::aarch64_neon_ld4:
3200     if (Inst->getType() == ExpectedType)
3201       return Inst;
3202     return nullptr;
3203   }
3204 }
3205 
3206 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
3207                                         MemIntrinsicInfo &Info) {
3208   switch (Inst->getIntrinsicID()) {
3209   default:
3210     break;
3211   case Intrinsic::aarch64_neon_ld2:
3212   case Intrinsic::aarch64_neon_ld3:
3213   case Intrinsic::aarch64_neon_ld4:
3214     Info.ReadMem = true;
3215     Info.WriteMem = false;
3216     Info.PtrVal = Inst->getArgOperand(0);
3217     break;
3218   case Intrinsic::aarch64_neon_st2:
3219   case Intrinsic::aarch64_neon_st3:
3220   case Intrinsic::aarch64_neon_st4:
3221     Info.ReadMem = false;
3222     Info.WriteMem = true;
3223     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
3224     break;
3225   }
3226 
3227   switch (Inst->getIntrinsicID()) {
3228   default:
3229     return false;
3230   case Intrinsic::aarch64_neon_ld2:
3231   case Intrinsic::aarch64_neon_st2:
3232     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
3233     break;
3234   case Intrinsic::aarch64_neon_ld3:
3235   case Intrinsic::aarch64_neon_st3:
3236     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
3237     break;
3238   case Intrinsic::aarch64_neon_ld4:
3239   case Intrinsic::aarch64_neon_st4:
3240     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
3241     break;
3242   }
3243   return true;
3244 }
3245 
3246 /// See if \p I should be considered for address type promotion. We check if \p
3247 /// I is a sext with right type and used in memory accesses. If it used in a
3248 /// "complex" getelementptr, we allow it to be promoted without finding other
3249 /// sext instructions that sign extended the same initial value. A getelementptr
3250 /// is considered as "complex" if it has more than 2 operands.
3251 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
3252     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
3253   bool Considerable = false;
3254   AllowPromotionWithoutCommonHeader = false;
3255   if (!isa<SExtInst>(&I))
3256     return false;
3257   Type *ConsideredSExtType =
3258       Type::getInt64Ty(I.getParent()->getParent()->getContext());
3259   if (I.getType() != ConsideredSExtType)
3260     return false;
3261   // See if the sext is the one with the right type and used in at least one
3262   // GetElementPtrInst.
3263   for (const User *U : I.users()) {
3264     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
3265       Considerable = true;
3266       // A getelementptr is considered as "complex" if it has more than 2
3267       // operands. We will promote a SExt used in such complex GEP as we
3268       // expect some computation to be merged if they are done on 64 bits.
3269       if (GEPInst->getNumOperands() > 2) {
3270         AllowPromotionWithoutCommonHeader = true;
3271         break;
3272       }
3273     }
3274   }
3275   return Considerable;
3276 }
3277 
3278 bool AArch64TTIImpl::isLegalToVectorizeReduction(
3279     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
3280   if (!VF.isScalable())
3281     return true;
3282 
3283   Type *Ty = RdxDesc.getRecurrenceType();
3284   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
3285     return false;
3286 
3287   switch (RdxDesc.getRecurrenceKind()) {
3288   case RecurKind::Add:
3289   case RecurKind::FAdd:
3290   case RecurKind::And:
3291   case RecurKind::Or:
3292   case RecurKind::Xor:
3293   case RecurKind::SMin:
3294   case RecurKind::SMax:
3295   case RecurKind::UMin:
3296   case RecurKind::UMax:
3297   case RecurKind::FMin:
3298   case RecurKind::FMax:
3299   case RecurKind::SelectICmp:
3300   case RecurKind::SelectFCmp:
3301   case RecurKind::FMulAdd:
3302     return true;
3303   default:
3304     return false;
3305   }
3306 }
3307 
3308 InstructionCost
3309 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
3310                                        FastMathFlags FMF,
3311                                        TTI::TargetCostKind CostKind) {
3312   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
3313 
3314   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
3315     return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);
3316 
3317   InstructionCost LegalizationCost = 0;
3318   if (LT.first > 1) {
3319     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
3320     IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
3321     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
3322   }
3323 
3324   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
3325 }
3326 
3327 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
3328     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
3329   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3330   InstructionCost LegalizationCost = 0;
3331   if (LT.first > 1) {
3332     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
3333     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
3334     LegalizationCost *= LT.first - 1;
3335   }
3336 
3337   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3338   assert(ISD && "Invalid opcode");
3339   // Add the final reduction cost for the legal horizontal reduction
3340   switch (ISD) {
3341   case ISD::ADD:
3342   case ISD::AND:
3343   case ISD::OR:
3344   case ISD::XOR:
3345   case ISD::FADD:
3346     return LegalizationCost + 2;
3347   default:
3348     return InstructionCost::getInvalid();
3349   }
3350 }
3351 
3352 InstructionCost
3353 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
3354                                            std::optional<FastMathFlags> FMF,
3355                                            TTI::TargetCostKind CostKind) {
3356   if (TTI::requiresOrderedReduction(FMF)) {
3357     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
3358       InstructionCost BaseCost =
3359           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3360       // Add on extra cost to reflect the extra overhead on some CPUs. We still
3361       // end up vectorizing for more computationally intensive loops.
3362       return BaseCost + FixedVTy->getNumElements();
3363     }
3364 
3365     if (Opcode != Instruction::FAdd)
3366       return InstructionCost::getInvalid();
3367 
3368     auto *VTy = cast<ScalableVectorType>(ValTy);
3369     InstructionCost Cost =
3370         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
3371     Cost *= getMaxNumElements(VTy->getElementCount());
3372     return Cost;
3373   }
3374 
3375   if (isa<ScalableVectorType>(ValTy))
3376     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
3377 
3378   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3379   MVT MTy = LT.second;
3380   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3381   assert(ISD && "Invalid opcode");
3382 
3383   // Horizontal adds can use the 'addv' instruction. We model the cost of these
3384   // instructions as twice a normal vector add, plus 1 for each legalization
3385   // step (LT.first). This is the only arithmetic vector reduction operation for
3386   // which we have an instruction.
3387   // OR, XOR and AND costs should match the codegen from:
3388   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
3389   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
3390   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
3391   static const CostTblEntry CostTblNoPairwise[]{
3392       {ISD::ADD, MVT::v8i8,   2},
3393       {ISD::ADD, MVT::v16i8,  2},
3394       {ISD::ADD, MVT::v4i16,  2},
3395       {ISD::ADD, MVT::v8i16,  2},
3396       {ISD::ADD, MVT::v4i32,  2},
3397       {ISD::ADD, MVT::v2i64,  2},
3398       {ISD::OR,  MVT::v8i8,  15},
3399       {ISD::OR,  MVT::v16i8, 17},
3400       {ISD::OR,  MVT::v4i16,  7},
3401       {ISD::OR,  MVT::v8i16,  9},
3402       {ISD::OR,  MVT::v2i32,  3},
3403       {ISD::OR,  MVT::v4i32,  5},
3404       {ISD::OR,  MVT::v2i64,  3},
3405       {ISD::XOR, MVT::v8i8,  15},
3406       {ISD::XOR, MVT::v16i8, 17},
3407       {ISD::XOR, MVT::v4i16,  7},
3408       {ISD::XOR, MVT::v8i16,  9},
3409       {ISD::XOR, MVT::v2i32,  3},
3410       {ISD::XOR, MVT::v4i32,  5},
3411       {ISD::XOR, MVT::v2i64,  3},
3412       {ISD::AND, MVT::v8i8,  15},
3413       {ISD::AND, MVT::v16i8, 17},
3414       {ISD::AND, MVT::v4i16,  7},
3415       {ISD::AND, MVT::v8i16,  9},
3416       {ISD::AND, MVT::v2i32,  3},
3417       {ISD::AND, MVT::v4i32,  5},
3418       {ISD::AND, MVT::v2i64,  3},
3419   };
3420   switch (ISD) {
3421   default:
3422     break;
3423   case ISD::ADD:
3424     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
3425       return (LT.first - 1) + Entry->Cost;
3426     break;
3427   case ISD::XOR:
3428   case ISD::AND:
3429   case ISD::OR:
3430     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
3431     if (!Entry)
3432       break;
3433     auto *ValVTy = cast<FixedVectorType>(ValTy);
3434     if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
3435         isPowerOf2_32(ValVTy->getNumElements())) {
3436       InstructionCost ExtraCost = 0;
3437       if (LT.first != 1) {
3438         // Type needs to be split, so there is an extra cost of LT.first - 1
3439         // arithmetic ops.
3440         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
3441                                         MTy.getVectorNumElements());
3442         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
3443         ExtraCost *= LT.first - 1;
3444       }
3445       // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
3446       auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
3447       return Cost + ExtraCost;
3448     }
3449     break;
3450   }
3451   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3452 }
3453 
3454 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
3455   static const CostTblEntry ShuffleTbl[] = {
3456       { TTI::SK_Splice, MVT::nxv16i8,  1 },
3457       { TTI::SK_Splice, MVT::nxv8i16,  1 },
3458       { TTI::SK_Splice, MVT::nxv4i32,  1 },
3459       { TTI::SK_Splice, MVT::nxv2i64,  1 },
3460       { TTI::SK_Splice, MVT::nxv2f16,  1 },
3461       { TTI::SK_Splice, MVT::nxv4f16,  1 },
3462       { TTI::SK_Splice, MVT::nxv8f16,  1 },
3463       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
3464       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
3465       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
3466       { TTI::SK_Splice, MVT::nxv2f32,  1 },
3467       { TTI::SK_Splice, MVT::nxv4f32,  1 },
3468       { TTI::SK_Splice, MVT::nxv2f64,  1 },
3469   };
3470 
3471   // The code-generator is currently not able to handle scalable vectors
3472   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3473   // it. This change will be removed when code-generation for these types is
3474   // sufficiently reliable.
3475   if (Tp->getElementCount() == ElementCount::getScalable(1))
3476     return InstructionCost::getInvalid();
3477 
3478   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3479   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
3480   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3481   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
3482                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
3483                        : LT.second;
3484   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
3485   InstructionCost LegalizationCost = 0;
3486   if (Index < 0) {
3487     LegalizationCost =
3488         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
3489                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
3490         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
3491                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
3492   }
3493 
3494   // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
3495   // Cost performed on a promoted type.
3496   if (LT.second.getScalarType() == MVT::i1) {
3497     LegalizationCost +=
3498         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
3499                          TTI::CastContextHint::None, CostKind) +
3500         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
3501                          TTI::CastContextHint::None, CostKind);
3502   }
3503   const auto *Entry =
3504       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
3505   assert(Entry && "Illegal Type for Splice");
3506   LegalizationCost += Entry->Cost;
3507   return LegalizationCost * LT.first;
3508 }
3509 
3510 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
3511                                                VectorType *Tp,
3512                                                ArrayRef<int> Mask,
3513                                                TTI::TargetCostKind CostKind,
3514                                                int Index, VectorType *SubTp,
3515                                                ArrayRef<const Value *> Args) {
3516   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3517   // If we have a Mask, and the LT is being legalized somehow, split the Mask
3518   // into smaller vectors and sum the cost of each shuffle.
3519   if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
3520       Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
3521       cast<FixedVectorType>(Tp)->getNumElements() >
3522           LT.second.getVectorNumElements() &&
3523       !Index && !SubTp) {
3524     unsigned TpNumElts = cast<FixedVectorType>(Tp)->getNumElements();
3525     assert(Mask.size() == TpNumElts && "Expected Mask and Tp size to match!");
3526     unsigned LTNumElts = LT.second.getVectorNumElements();
3527     unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
3528     VectorType *NTp =
3529         VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
3530     InstructionCost Cost;
3531     for (unsigned N = 0; N < NumVecs; N++) {
3532       SmallVector<int> NMask;
3533       // Split the existing mask into chunks of size LTNumElts. Track the source
3534       // sub-vectors to ensure the result has at most 2 inputs.
3535       unsigned Source1, Source2;
3536       unsigned NumSources = 0;
3537       for (unsigned E = 0; E < LTNumElts; E++) {
3538         int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
3539                                                       : PoisonMaskElem;
3540         if (MaskElt < 0) {
3541           NMask.push_back(PoisonMaskElem);
3542           continue;
3543         }
3544 
3545         // Calculate which source from the input this comes from and whether it
3546         // is new to us.
3547         unsigned Source = MaskElt / LTNumElts;
3548         if (NumSources == 0) {
3549           Source1 = Source;
3550           NumSources = 1;
3551         } else if (NumSources == 1 && Source != Source1) {
3552           Source2 = Source;
3553           NumSources = 2;
3554         } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
3555           NumSources++;
3556         }
3557 
3558         // Add to the new mask. For the NumSources>2 case these are not correct,
3559         // but are only used for the modular lane number.
3560         if (Source == Source1)
3561           NMask.push_back(MaskElt % LTNumElts);
3562         else if (Source == Source2)
3563           NMask.push_back(MaskElt % LTNumElts + LTNumElts);
3564         else
3565           NMask.push_back(MaskElt % LTNumElts);
3566       }
3567       // If the sub-mask has at most 2 input sub-vectors then re-cost it using
3568       // getShuffleCost. If not then cost it using the worst case.
3569       if (NumSources <= 2)
3570         Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
3571                                                : TTI::SK_PermuteTwoSrc,
3572                                NTp, NMask, CostKind, 0, nullptr, Args);
3573       else if (any_of(enumerate(NMask), [&](const auto &ME) {
3574                  return ME.value() % LTNumElts == ME.index();
3575                }))
3576         Cost += LTNumElts - 1;
3577       else
3578         Cost += LTNumElts;
3579     }
3580     return Cost;
3581   }
3582 
3583   Kind = improveShuffleKindFromMask(Kind, Mask);
3584 
3585   // Check for broadcast loads, which are supported by the LD1R instruction.
3586   // In terms of code-size, the shuffle vector is free when a load + dup get
3587   // folded into a LD1R. That's what we check and return here. For performance
3588   // and reciprocal throughput, a LD1R is not completely free. In this case, we
3589   // return the cost for the broadcast below (i.e. 1 for most/all types), so
3590   // that we model the load + dup sequence slightly higher because LD1R is a
3591   // high latency instruction.
3592   if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) {
3593     bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
3594     if (IsLoad && LT.second.isVector() &&
3595         isLegalBroadcastLoad(Tp->getElementType(),
3596                              LT.second.getVectorElementCount()))
3597       return 0;
3598   }
3599 
3600   // If we have 4 elements for the shuffle and a Mask, get the cost straight
3601   // from the perfect shuffle tables.
3602   if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) &&
3603       (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) &&
3604       all_of(Mask, [](int E) { return E < 8; }))
3605     return getPerfectShuffleCost(Mask);
3606 
3607   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
3608       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
3609       Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
3610     static const CostTblEntry ShuffleTbl[] = {
3611         // Broadcast shuffle kinds can be performed with 'dup'.
3612         {TTI::SK_Broadcast, MVT::v8i8, 1},
3613         {TTI::SK_Broadcast, MVT::v16i8, 1},
3614         {TTI::SK_Broadcast, MVT::v4i16, 1},
3615         {TTI::SK_Broadcast, MVT::v8i16, 1},
3616         {TTI::SK_Broadcast, MVT::v2i32, 1},
3617         {TTI::SK_Broadcast, MVT::v4i32, 1},
3618         {TTI::SK_Broadcast, MVT::v2i64, 1},
3619         {TTI::SK_Broadcast, MVT::v4f16, 1},
3620         {TTI::SK_Broadcast, MVT::v8f16, 1},
3621         {TTI::SK_Broadcast, MVT::v2f32, 1},
3622         {TTI::SK_Broadcast, MVT::v4f32, 1},
3623         {TTI::SK_Broadcast, MVT::v2f64, 1},
3624         // Transpose shuffle kinds can be performed with 'trn1/trn2' and
3625         // 'zip1/zip2' instructions.
3626         {TTI::SK_Transpose, MVT::v8i8, 1},
3627         {TTI::SK_Transpose, MVT::v16i8, 1},
3628         {TTI::SK_Transpose, MVT::v4i16, 1},
3629         {TTI::SK_Transpose, MVT::v8i16, 1},
3630         {TTI::SK_Transpose, MVT::v2i32, 1},
3631         {TTI::SK_Transpose, MVT::v4i32, 1},
3632         {TTI::SK_Transpose, MVT::v2i64, 1},
3633         {TTI::SK_Transpose, MVT::v4f16, 1},
3634         {TTI::SK_Transpose, MVT::v8f16, 1},
3635         {TTI::SK_Transpose, MVT::v2f32, 1},
3636         {TTI::SK_Transpose, MVT::v4f32, 1},
3637         {TTI::SK_Transpose, MVT::v2f64, 1},
3638         // Select shuffle kinds.
3639         // TODO: handle vXi8/vXi16.
3640         {TTI::SK_Select, MVT::v2i32, 1}, // mov.
3641         {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
3642         {TTI::SK_Select, MVT::v2i64, 1}, // mov.
3643         {TTI::SK_Select, MVT::v2f32, 1}, // mov.
3644         {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
3645         {TTI::SK_Select, MVT::v2f64, 1}, // mov.
3646         // PermuteSingleSrc shuffle kinds.
3647         {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
3648         {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
3649         {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
3650         {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
3651         {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
3652         {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
3653         {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
3654         {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
3655         {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same
3656         {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8},  // constpool + load + tbl
3657         {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8},  // constpool + load + tbl
3658         {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
3659         {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8},   // constpool + load + tbl
3660         {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8},  // constpool + load + tbl
3661         // Reverse can be lowered with `rev`.
3662         {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
3663         {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
3664         {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
3665         {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
3666         {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
3667         {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
3668         {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
3669         {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
3670         {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
3671         {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
3672         {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
3673         {TTI::SK_Reverse, MVT::v8i8, 1},  // REV64
3674         // Splice can all be lowered as `ext`.
3675         {TTI::SK_Splice, MVT::v2i32, 1},
3676         {TTI::SK_Splice, MVT::v4i32, 1},
3677         {TTI::SK_Splice, MVT::v2i64, 1},
3678         {TTI::SK_Splice, MVT::v2f32, 1},
3679         {TTI::SK_Splice, MVT::v4f32, 1},
3680         {TTI::SK_Splice, MVT::v2f64, 1},
3681         {TTI::SK_Splice, MVT::v8f16, 1},
3682         {TTI::SK_Splice, MVT::v8bf16, 1},
3683         {TTI::SK_Splice, MVT::v8i16, 1},
3684         {TTI::SK_Splice, MVT::v16i8, 1},
3685         {TTI::SK_Splice, MVT::v4bf16, 1},
3686         {TTI::SK_Splice, MVT::v4f16, 1},
3687         {TTI::SK_Splice, MVT::v4i16, 1},
3688         {TTI::SK_Splice, MVT::v8i8, 1},
3689         // Broadcast shuffle kinds for scalable vectors
3690         {TTI::SK_Broadcast, MVT::nxv16i8, 1},
3691         {TTI::SK_Broadcast, MVT::nxv8i16, 1},
3692         {TTI::SK_Broadcast, MVT::nxv4i32, 1},
3693         {TTI::SK_Broadcast, MVT::nxv2i64, 1},
3694         {TTI::SK_Broadcast, MVT::nxv2f16, 1},
3695         {TTI::SK_Broadcast, MVT::nxv4f16, 1},
3696         {TTI::SK_Broadcast, MVT::nxv8f16, 1},
3697         {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
3698         {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
3699         {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
3700         {TTI::SK_Broadcast, MVT::nxv2f32, 1},
3701         {TTI::SK_Broadcast, MVT::nxv4f32, 1},
3702         {TTI::SK_Broadcast, MVT::nxv2f64, 1},
3703         {TTI::SK_Broadcast, MVT::nxv16i1, 1},
3704         {TTI::SK_Broadcast, MVT::nxv8i1, 1},
3705         {TTI::SK_Broadcast, MVT::nxv4i1, 1},
3706         {TTI::SK_Broadcast, MVT::nxv2i1, 1},
3707         // Handle the cases for vector.reverse with scalable vectors
3708         {TTI::SK_Reverse, MVT::nxv16i8, 1},
3709         {TTI::SK_Reverse, MVT::nxv8i16, 1},
3710         {TTI::SK_Reverse, MVT::nxv4i32, 1},
3711         {TTI::SK_Reverse, MVT::nxv2i64, 1},
3712         {TTI::SK_Reverse, MVT::nxv2f16, 1},
3713         {TTI::SK_Reverse, MVT::nxv4f16, 1},
3714         {TTI::SK_Reverse, MVT::nxv8f16, 1},
3715         {TTI::SK_Reverse, MVT::nxv2bf16, 1},
3716         {TTI::SK_Reverse, MVT::nxv4bf16, 1},
3717         {TTI::SK_Reverse, MVT::nxv8bf16, 1},
3718         {TTI::SK_Reverse, MVT::nxv2f32, 1},
3719         {TTI::SK_Reverse, MVT::nxv4f32, 1},
3720         {TTI::SK_Reverse, MVT::nxv2f64, 1},
3721         {TTI::SK_Reverse, MVT::nxv16i1, 1},
3722         {TTI::SK_Reverse, MVT::nxv8i1, 1},
3723         {TTI::SK_Reverse, MVT::nxv4i1, 1},
3724         {TTI::SK_Reverse, MVT::nxv2i1, 1},
3725     };
3726     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
3727       return LT.first * Entry->Cost;
3728   }
3729 
3730   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
3731     return getSpliceCost(Tp, Index);
3732 
3733   // Inserting a subvector can often be done with either a D, S or H register
3734   // move, so long as the inserted vector is "aligned".
3735   if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
3736       LT.second.getSizeInBits() <= 128 && SubTp) {
3737     std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
3738     if (SubLT.second.isVector()) {
3739       int NumElts = LT.second.getVectorNumElements();
3740       int NumSubElts = SubLT.second.getVectorNumElements();
3741       if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
3742         return SubLT.first;
3743     }
3744   }
3745 
3746   return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
3747 }
3748 
3749 static bool containsDecreasingPointers(Loop *TheLoop,
3750                                        PredicatedScalarEvolution *PSE) {
3751   const auto &Strides = DenseMap<Value *, const SCEV *>();
3752   for (BasicBlock *BB : TheLoop->blocks()) {
3753     // Scan the instructions in the block and look for addresses that are
3754     // consecutive and decreasing.
3755     for (Instruction &I : *BB) {
3756       if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
3757         Value *Ptr = getLoadStorePointerOperand(&I);
3758         Type *AccessTy = getLoadStoreType(&I);
3759         if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
3760                          /*ShouldCheckWrap=*/false)
3761                 .value_or(0) < 0)
3762           return true;
3763       }
3764     }
3765   }
3766   return false;
3767 }
3768 
3769 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
3770   if (!ST->hasSVE())
3771     return false;
3772 
3773   // We don't currently support vectorisation with interleaving for SVE - with
3774   // such loops we're better off not using tail-folding. This gives us a chance
3775   // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
3776   if (TFI->IAI->hasGroups())
3777     return false;
3778 
3779   TailFoldingOpts Required = TailFoldingOpts::Disabled;
3780   if (TFI->LVL->getReductionVars().size())
3781     Required |= TailFoldingOpts::Reductions;
3782   if (TFI->LVL->getFixedOrderRecurrences().size())
3783     Required |= TailFoldingOpts::Recurrences;
3784 
3785   // We call this to discover whether any load/store pointers in the loop have
3786   // negative strides. This will require extra work to reverse the loop
3787   // predicate, which may be expensive.
3788   if (containsDecreasingPointers(TFI->LVL->getLoop(),
3789                                  TFI->LVL->getPredicatedScalarEvolution()))
3790     Required |= TailFoldingOpts::Reverse;
3791   if (Required == TailFoldingOpts::Disabled)
3792     Required |= TailFoldingOpts::Simple;
3793 
3794   if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(),
3795                                       Required))
3796     return false;
3797 
3798   // Don't tail-fold for tight loops where we would be better off interleaving
3799   // with an unpredicated loop.
3800   unsigned NumInsns = 0;
3801   for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) {
3802     NumInsns += BB->sizeWithoutDebug();
3803   }
3804 
3805   // We expect 4 of these to be a IV PHI, IV add, IV compare and branch.
3806   return NumInsns >= SVETailFoldInsnThreshold;
3807 }
3808 
3809 InstructionCost
3810 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
3811                                      int64_t BaseOffset, bool HasBaseReg,
3812                                      int64_t Scale, unsigned AddrSpace) const {
3813   // Scaling factors are not free at all.
3814   // Operands                     | Rt Latency
3815   // -------------------------------------------
3816   // Rt, [Xn, Xm]                 | 4
3817   // -------------------------------------------
3818   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
3819   // Rt, [Xn, Wm, <extend> #imm]  |
3820   TargetLoweringBase::AddrMode AM;
3821   AM.BaseGV = BaseGV;
3822   AM.BaseOffs = BaseOffset;
3823   AM.HasBaseReg = HasBaseReg;
3824   AM.Scale = Scale;
3825   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
3826     // Scale represents reg2 * scale, thus account for 1 if
3827     // it is not equal to 0 or 1.
3828     return AM.Scale != 0 && AM.Scale != 1;
3829   return -1;
3830 }
3831