xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (revision 46c59ea9b61755455ff6bf9f3e7b834e1af634ea)
1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "AArch64PerfectShuffle.h"
12 #include "MCTargetDesc/AArch64AddressingModes.h"
13 #include "llvm/Analysis/IVDescriptors.h"
14 #include "llvm/Analysis/LoopInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/BasicTTIImpl.h"
17 #include "llvm/CodeGen/CostTable.h"
18 #include "llvm/CodeGen/TargetLowering.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/IR/IntrinsicsAArch64.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Transforms/InstCombine/InstCombiner.h"
25 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
26 #include <algorithm>
27 #include <optional>
28 using namespace llvm;
29 using namespace llvm::PatternMatch;
30 
31 #define DEBUG_TYPE "aarch64tti"
32 
33 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
34                                                cl::init(true), cl::Hidden);
35 
36 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
37                                            cl::Hidden);
38 
39 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
40                                             cl::init(10), cl::Hidden);
41 
42 static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
43                                                   cl::init(15), cl::Hidden);
44 
45 static cl::opt<unsigned>
46     NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
47                                cl::Hidden);
48 
49 static cl::opt<unsigned> CallPenaltyChangeSM(
50     "call-penalty-sm-change", cl::init(5), cl::Hidden,
51     cl::desc(
52         "Penalty of calling a function that requires a change to PSTATE.SM"));
53 
54 static cl::opt<unsigned> InlineCallPenaltyChangeSM(
55     "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
56     cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
57 
58 namespace {
59 class TailFoldingOption {
60   // These bitfields will only ever be set to something non-zero in operator=,
61   // when setting the -sve-tail-folding option. This option should always be of
62   // the form (default|simple|all|disabled)[+(Flag1|Flag2|etc)], where
63   // InitialBits is one of (disabled|all|simple). EnableBits represents
64   // additional flags we're enabling, and DisableBits for those flags we're
65   // disabling. The default flag is tracked in the variable NeedsDefault, since
66   // at the time of setting the option we may not know what the default value
67   // for the CPU is.
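  // For example (illustrative): "-sve-tail-folding=all+noreverse" sets
  // InitialBits to TailFoldingOpts::All and records Reverse in DisableBits,
  // whereas "-sve-tail-folding=default+reductions" keeps NeedsDefault true
  // and records Reductions in EnableBits.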
68   TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
69   TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
70   TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
71 
72   // This value needs to be initialised to true in case the user does not
73   // explicitly set the -sve-tail-folding option.
74   bool NeedsDefault = true;
75 
76   void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
77 
78   void setNeedsDefault(bool V) { NeedsDefault = V; }
79 
80   void setEnableBit(TailFoldingOpts Bit) {
81     EnableBits |= Bit;
82     DisableBits &= ~Bit;
83   }
84 
85   void setDisableBit(TailFoldingOpts Bit) {
86     EnableBits &= ~Bit;
87     DisableBits |= Bit;
88   }
89 
90   TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
91     TailFoldingOpts Bits = TailFoldingOpts::Disabled;
92 
93     assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
94            "Initial bits should only include one of "
95            "(disabled|all|simple|default)");
96     Bits = NeedsDefault ? DefaultBits : InitialBits;
97     Bits |= EnableBits;
98     Bits &= ~DisableBits;
99 
100     return Bits;
101   }
102 
103   void reportError(std::string Opt) {
104     errs() << "invalid argument '" << Opt
105            << "' to -sve-tail-folding=; the option should be of the form\n"
106               "  (disabled|all|default|simple)[+(reductions|recurrences"
107               "|reverse|noreductions|norecurrences|noreverse)]\n";
108     report_fatal_error("Unrecognised tail-folding option");
109   }
110 
111 public:
112 
113   void operator=(const std::string &Val) {
114     // If the user explicitly sets -sve-tail-folding= then treat as an error.
115     if (Val.empty()) {
116       reportError("");
117       return;
118     }
119 
120     // Since the user is explicitly setting the option we don't automatically
121     // need the default unless they require it.
122     setNeedsDefault(false);
123 
124     SmallVector<StringRef, 4> TailFoldTypes;
125     StringRef(Val).split(TailFoldTypes, '+', -1, false);
126 
127     unsigned StartIdx = 1;
128     if (TailFoldTypes[0] == "disabled")
129       setInitialBits(TailFoldingOpts::Disabled);
130     else if (TailFoldTypes[0] == "all")
131       setInitialBits(TailFoldingOpts::All);
132     else if (TailFoldTypes[0] == "default")
133       setNeedsDefault(true);
134     else if (TailFoldTypes[0] == "simple")
135       setInitialBits(TailFoldingOpts::Simple);
136     else {
137       StartIdx = 0;
138       setInitialBits(TailFoldingOpts::Disabled);
139     }
140 
141     for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
142       if (TailFoldTypes[I] == "reductions")
143         setEnableBit(TailFoldingOpts::Reductions);
144       else if (TailFoldTypes[I] == "recurrences")
145         setEnableBit(TailFoldingOpts::Recurrences);
146       else if (TailFoldTypes[I] == "reverse")
147         setEnableBit(TailFoldingOpts::Reverse);
148       else if (TailFoldTypes[I] == "noreductions")
149         setDisableBit(TailFoldingOpts::Reductions);
150       else if (TailFoldTypes[I] == "norecurrences")
151         setDisableBit(TailFoldingOpts::Recurrences);
152       else if (TailFoldTypes[I] == "noreverse")
153         setDisableBit(TailFoldingOpts::Reverse);
154       else
155         reportError(Val);
156     }
157   }
158 
159   bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
160     return (getBits(DefaultBits) & Required) == Required;
161   }
162 };
163 } // namespace
164 
165 TailFoldingOption TailFoldingOptionLoc;
166 
167 cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
168     "sve-tail-folding",
169     cl::desc(
170         "Control the use of vectorisation using tail-folding for SVE where the"
171         " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
172         "\ndisabled      (Initial) No loop types will vectorize using "
173         "tail-folding"
174         "\ndefault       (Initial) Uses the default tail-folding settings for "
175         "the target CPU"
176         "\nall           (Initial) All legal loop types will vectorize using "
177         "tail-folding"
178         "\nsimple        (Initial) Use tail-folding for simple loops (not "
179         "reductions or recurrences)"
180         "\nreductions    Use tail-folding for loops containing reductions"
181         "\nnoreductions  Inverse of above"
182         "\nrecurrences   Use tail-folding for loops containing fixed order "
183         "recurrences"
184         "\nnorecurrences Inverse of above"
185         "\nreverse       Use tail-folding for loops requiring reversed "
186         "predicates"
187         "\nnoreverse     Inverse of above"),
188     cl::location(TailFoldingOptionLoc));
189 
190 // Experimental option that will only be fully functional when the
191 // code-generator is changed to use SVE instead of NEON for all fixed-width
192 // operations.
193 static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
194     "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
195 
196 // Experimental option that will only be fully functional when the cost-model
197 // and code-generator have been changed to avoid using scalable vector
198 // instructions that are not legal in streaming SVE mode.
199 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
200     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
201 
202 static bool isSMEABIRoutineCall(const CallInst &CI) {
203   const auto *F = CI.getCalledFunction();
204   return F && StringSwitch<bool>(F->getName())
205                   .Case("__arm_sme_state", true)
206                   .Case("__arm_tpidr2_save", true)
207                   .Case("__arm_tpidr2_restore", true)
208                   .Case("__arm_za_disable", true)
209                   .Default(false);
210 }
211 
212 /// Returns true if the function has explicit operations that can only be
213 /// lowered using incompatible instructions for the selected mode. This also
214 /// returns true if the function F may use or modify ZA state.
215 static bool hasPossibleIncompatibleOps(const Function *F) {
216   for (const BasicBlock &BB : *F) {
217     for (const Instruction &I : BB) {
218       // Be conservative for now and assume that any call to inline asm or to
219       // intrinsics could result in non-streaming ops (e.g. calls to
220       // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
221       // all native LLVM instructions can be lowered to compatible instructions.
222       if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
223           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
224            isSMEABIRoutineCall(cast<CallInst>(I))))
225         return true;
226     }
227   }
228   return false;
229 }
230 
231 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
232                                          const Function *Callee) const {
233   SMEAttrs CallerAttrs(*Caller);
234   SMEAttrs CalleeAttrs(*Callee);
235   if (CalleeAttrs.hasNewZABody())
236     return false;
237 
238   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
239       CallerAttrs.requiresSMChange(CalleeAttrs,
240                                    /*BodyOverridesInterface=*/true)) {
241     if (hasPossibleIncompatibleOps(Callee))
242       return false;
243   }
244 
245   const TargetMachine &TM = getTLI()->getTargetMachine();
246 
247   const FeatureBitset &CallerBits =
248       TM.getSubtargetImpl(*Caller)->getFeatureBits();
249   const FeatureBitset &CalleeBits =
250       TM.getSubtargetImpl(*Callee)->getFeatureBits();
251 
252   // Inline a callee if its target-features are a subset of the caller's
253   // target-features.
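  // For example (illustrative), a callee built with "+sve" can be inlined
  // into a caller built with "+sve,+sve2", but not the other way around.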
254   return (CallerBits & CalleeBits) == CalleeBits;
255 }
256 
257 bool AArch64TTIImpl::areTypesABICompatible(
258     const Function *Caller, const Function *Callee,
259     const ArrayRef<Type *> &Types) const {
260   if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
261     return false;
262 
263   // We need to ensure that argument promotion does not attempt to promote
264   // pointers to fixed-length vector types larger than 128 bits like
265   // <8 x float> (and pointers to aggregate types which have such fixed-length
266   // vector type members) into the values of the pointees. Such vector types
267   // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
268   // backend cannot lower such value arguments. The 128-bit fixed-length SVE
269   // types can be safely treated as 128-bit NEON types and they cannot be
270   // distinguished in IR.
271   if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
272         auto FVTy = dyn_cast<FixedVectorType>(Ty);
273         return FVTy &&
274                FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
275       }))
276     return false;
277 
278   return true;
279 }
280 
281 unsigned
282 AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
283                                      unsigned DefaultCallPenalty) const {
284   // This function calculates a penalty for executing Call in F.
285   //
286   // There are two ways this function can be called:
287   // (1)  F:
288   //       call from F -> G (the call here is Call)
289   //
290   // For (1), Call.getCaller() == F, so it will always return a high cost if
291   // a streaming-mode change is required (thus promoting the need to inline the
292   // function)
293   //
294   // (2)  F:
295   //       call from F -> G (the call here is not Call)
296   //      G:
297   //       call from G -> H (the call here is Call)
298   //
299   // For (2), if after inlining the body of G into F the call to H requires a
300   // streaming-mode change, and the call to G from F would also require a
301   // streaming-mode change, then there is benefit to do the streaming-mode
302   // change only once and avoid inlining of G into F.
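  // For illustration, with the default flag values: case (1) returns
  // CallPenaltyChangeSM * DefaultCallPenalty (5x), case (2) returns
  // InlineCallPenaltyChangeSM * DefaultCallPenalty (10x), and otherwise the
  // unmodified DefaultCallPenalty is returned.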
303   SMEAttrs FAttrs(*F);
304   SMEAttrs CalleeAttrs(Call);
305   if (FAttrs.requiresSMChange(CalleeAttrs)) {
306     if (F == Call.getCaller()) // (1)
307       return CallPenaltyChangeSM * DefaultCallPenalty;
308     if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
309       return InlineCallPenaltyChangeSM * DefaultCallPenalty;
310   }
311 
312   return DefaultCallPenalty;
313 }
314 
315 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
316     TargetTransformInfo::RegisterKind K) const {
317   assert(K != TargetTransformInfo::RGK_Scalar);
318   return (K == TargetTransformInfo::RGK_FixedWidthVector &&
319           ST->isNeonAvailable());
320 }
321 
322 /// Calculate the cost of materializing a 64-bit value. This helper
323 /// method might only calculate a fraction of a larger immediate. Therefore it
324 /// is valid to return a cost of ZERO.
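/// For example (illustrative), a value encodable as a logical immediate, such
/// as 0x0000ffff0000ffff, costs 0, while a constant whose four 16-bit chunks
/// all differ expands to a MOVZ plus three MOVKs, giving a cost of 4.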
325 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
326   // Check if the immediate can be encoded within an instruction.
327   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
328     return 0;
329 
330   if (Val < 0)
331     Val = ~Val;
332 
333   // Calculate how many moves we will need to materialize this constant.
334   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
335   AArch64_IMM::expandMOVImm(Val, 64, Insn);
336   return Insn.size();
337 }
338 
339 /// Calculate the cost of materializing the given constant.
340 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
341                                               TTI::TargetCostKind CostKind) {
342   assert(Ty->isIntegerTy());
343 
344   unsigned BitSize = Ty->getPrimitiveSizeInBits();
345   if (BitSize == 0)
346     return ~0U;
347 
348   // Sign-extend all constants to a multiple of 64-bit.
349   APInt ImmVal = Imm;
350   if (BitSize & 0x3f)
351     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
352 
353   // Split the constant into 64-bit chunks and calculate the cost for each
354   // chunk.
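  // For example (illustrative), a 128-bit constant is costed as the sum of
  // the costs of its low and high 64-bit halves, with a minimum of 1 below.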
355   InstructionCost Cost = 0;
356   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
357     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
358     int64_t Val = Tmp.getSExtValue();
359     Cost += getIntImmCost(Val);
360   }
361   // We need at least one instruction to materialize the constant.
362   return std::max<InstructionCost>(1, Cost);
363 }
364 
365 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
366                                                   const APInt &Imm, Type *Ty,
367                                                   TTI::TargetCostKind CostKind,
368                                                   Instruction *Inst) {
369   assert(Ty->isIntegerTy());
370 
371   unsigned BitSize = Ty->getPrimitiveSizeInBits();
372   // There is no cost model for constants with a bit size of 0. Return TCC_Free
373   // here, so that constant hoisting will ignore this constant.
374   if (BitSize == 0)
375     return TTI::TCC_Free;
376 
377   unsigned ImmIdx = ~0U;
378   switch (Opcode) {
379   default:
380     return TTI::TCC_Free;
381   case Instruction::GetElementPtr:
382     // Always hoist the base address of a GetElementPtr.
383     if (Idx == 0)
384       return 2 * TTI::TCC_Basic;
385     return TTI::TCC_Free;
386   case Instruction::Store:
387     ImmIdx = 0;
388     break;
389   case Instruction::Add:
390   case Instruction::Sub:
391   case Instruction::Mul:
392   case Instruction::UDiv:
393   case Instruction::SDiv:
394   case Instruction::URem:
395   case Instruction::SRem:
396   case Instruction::And:
397   case Instruction::Or:
398   case Instruction::Xor:
399   case Instruction::ICmp:
400     ImmIdx = 1;
401     break;
402   // Always return TCC_Free for the shift value of a shift instruction.
403   case Instruction::Shl:
404   case Instruction::LShr:
405   case Instruction::AShr:
406     if (Idx == 1)
407       return TTI::TCC_Free;
408     break;
409   case Instruction::Trunc:
410   case Instruction::ZExt:
411   case Instruction::SExt:
412   case Instruction::IntToPtr:
413   case Instruction::PtrToInt:
414   case Instruction::BitCast:
415   case Instruction::PHI:
416   case Instruction::Call:
417   case Instruction::Select:
418   case Instruction::Ret:
419   case Instruction::Load:
420     break;
421   }
422 
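  // For illustration: a small immediate such as 42 used by an add needs at
  // most one MOV to materialize, so it is reported as TCC_Free below and
  // constant hoisting leaves it in place.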
423   if (Idx == ImmIdx) {
424     int NumConstants = (BitSize + 63) / 64;
425     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
426     return (Cost <= NumConstants * TTI::TCC_Basic)
427                ? static_cast<int>(TTI::TCC_Free)
428                : Cost;
429   }
430   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
431 }
432 
433 InstructionCost
434 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
435                                     const APInt &Imm, Type *Ty,
436                                     TTI::TargetCostKind CostKind) {
437   assert(Ty->isIntegerTy());
438 
439   unsigned BitSize = Ty->getPrimitiveSizeInBits();
440   // There is no cost model for constants with a bit size of 0. Return TCC_Free
441   // here, so that constant hoisting will ignore this constant.
442   if (BitSize == 0)
443     return TTI::TCC_Free;
444 
445   // Most (all?) AArch64 intrinsics do not support folding immediates into the
446   // selected instruction, so we compute the materialization cost for the
447   // immediate directly.
448   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
449     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
450 
451   switch (IID) {
452   default:
453     return TTI::TCC_Free;
454   case Intrinsic::sadd_with_overflow:
455   case Intrinsic::uadd_with_overflow:
456   case Intrinsic::ssub_with_overflow:
457   case Intrinsic::usub_with_overflow:
458   case Intrinsic::smul_with_overflow:
459   case Intrinsic::umul_with_overflow:
460     if (Idx == 1) {
461       int NumConstants = (BitSize + 63) / 64;
462       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
463       return (Cost <= NumConstants * TTI::TCC_Basic)
464                  ? static_cast<int>(TTI::TCC_Free)
465                  : Cost;
466     }
467     break;
468   case Intrinsic::experimental_stackmap:
469     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
470       return TTI::TCC_Free;
471     break;
472   case Intrinsic::experimental_patchpoint_void:
473   case Intrinsic::experimental_patchpoint_i64:
474     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
475       return TTI::TCC_Free;
476     break;
477   case Intrinsic::experimental_gc_statepoint:
478     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
479       return TTI::TCC_Free;
480     break;
481   }
482   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
483 }
484 
485 TargetTransformInfo::PopcntSupportKind
486 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
487   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
488   if (TyWidth == 32 || TyWidth == 64)
489     return TTI::PSK_FastHardware;
490   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128-bit popcount.
491   return TTI::PSK_Software;
492 }
493 
494 InstructionCost
495 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
496                                       TTI::TargetCostKind CostKind) {
497   auto *RetTy = ICA.getReturnType();
498   switch (ICA.getID()) {
499   case Intrinsic::umin:
500   case Intrinsic::umax:
501   case Intrinsic::smin:
502   case Intrinsic::smax: {
503     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
504                                         MVT::v8i16, MVT::v2i32, MVT::v4i32,
505                                         MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
506                                         MVT::nxv2i64};
507     auto LT = getTypeLegalizationCost(RetTy);
508     // v2i64 types get converted to cmp+bif hence the cost of 2
509     if (LT.second == MVT::v2i64)
510       return LT.first * 2;
511     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
512       return LT.first;
513     break;
514   }
515   case Intrinsic::sadd_sat:
516   case Intrinsic::ssub_sat:
517   case Intrinsic::uadd_sat:
518   case Intrinsic::usub_sat: {
519     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
520                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
521                                      MVT::v2i64};
522     auto LT = getTypeLegalizationCost(RetTy);
523     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
524     // need to extend the type, as it uses shr(qadd(shl, shl)).
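    // For example (illustrative), a v8i16 sadd.sat maps directly onto SQADD
    // (cost 1 per legalized vector), while a narrower type that needs
    // promotion is costed as the 4-instruction shl+shl+sqadd+shr sequence.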
525     unsigned Instrs =
526         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
527     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
528       return LT.first * Instrs;
529     break;
530   }
531   case Intrinsic::abs: {
532     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
533                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
534                                      MVT::v2i64};
535     auto LT = getTypeLegalizationCost(RetTy);
536     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
537       return LT.first;
538     break;
539   }
540   case Intrinsic::bswap: {
541     static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
542                                      MVT::v4i32, MVT::v2i64};
543     auto LT = getTypeLegalizationCost(RetTy);
544     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
545         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
546       return LT.first;
547     break;
548   }
549   case Intrinsic::experimental_stepvector: {
550     InstructionCost Cost = 1; // Cost of the `index' instruction
551     auto LT = getTypeLegalizationCost(RetTy);
552     // Legalisation of illegal vectors involves an `index' instruction plus
553     // (LT.first - 1) vector adds.
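    // For example (illustrative), an illegal nxv8i32 stepvector is split into
    // two nxv4i32 halves, giving one `index' instruction plus one vector add.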
554     if (LT.first > 1) {
555       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
556       InstructionCost AddCost =
557           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
558       Cost += AddCost * (LT.first - 1);
559     }
560     return Cost;
561   }
562   case Intrinsic::bitreverse: {
563     static const CostTblEntry BitreverseTbl[] = {
564         {Intrinsic::bitreverse, MVT::i32, 1},
565         {Intrinsic::bitreverse, MVT::i64, 1},
566         {Intrinsic::bitreverse, MVT::v8i8, 1},
567         {Intrinsic::bitreverse, MVT::v16i8, 1},
568         {Intrinsic::bitreverse, MVT::v4i16, 2},
569         {Intrinsic::bitreverse, MVT::v8i16, 2},
570         {Intrinsic::bitreverse, MVT::v2i32, 2},
571         {Intrinsic::bitreverse, MVT::v4i32, 2},
572         {Intrinsic::bitreverse, MVT::v1i64, 2},
573         {Intrinsic::bitreverse, MVT::v2i64, 2},
574     };
575     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
576     const auto *Entry =
577         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
578     if (Entry) {
579       // The cost model uses the legal type (i32) that i8 and i16 will be
580       // promoted to, plus 1 so that we match the actual lowering cost.
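      // For example (illustrative), an i8 bitreverse is costed as the i32
      // table entry (1) plus 1 for the extra shift needed after RBIT, i.e. 2.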
581       if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
582           TLI->getValueType(DL, RetTy, true) == MVT::i16)
583         return LegalisationCost.first * Entry->Cost + 1;
584 
585       return LegalisationCost.first * Entry->Cost;
586     }
587     break;
588   }
589   case Intrinsic::ctpop: {
590     if (!ST->hasNEON()) {
591       // 32-bit or 64-bit ctpop without NEON is 12 instructions.
592       return getTypeLegalizationCost(RetTy).first * 12;
593     }
594     static const CostTblEntry CtpopCostTbl[] = {
595         {ISD::CTPOP, MVT::v2i64, 4},
596         {ISD::CTPOP, MVT::v4i32, 3},
597         {ISD::CTPOP, MVT::v8i16, 2},
598         {ISD::CTPOP, MVT::v16i8, 1},
599         {ISD::CTPOP, MVT::i64,   4},
600         {ISD::CTPOP, MVT::v2i32, 3},
601         {ISD::CTPOP, MVT::v4i16, 2},
602         {ISD::CTPOP, MVT::v8i8,  1},
603         {ISD::CTPOP, MVT::i32,   5},
604     };
605     auto LT = getTypeLegalizationCost(RetTy);
606     MVT MTy = LT.second;
607     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
608       // Extra cost of +1 when illegal vector types are legalized by promoting
609       // the integer type.
610       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
611                                             RetTy->getScalarSizeInBits()
612                           ? 1
613                           : 0;
614       return LT.first * Entry->Cost + ExtraCost;
615     }
616     break;
617   }
618   case Intrinsic::sadd_with_overflow:
619   case Intrinsic::uadd_with_overflow:
620   case Intrinsic::ssub_with_overflow:
621   case Intrinsic::usub_with_overflow:
622   case Intrinsic::smul_with_overflow:
623   case Intrinsic::umul_with_overflow: {
624     static const CostTblEntry WithOverflowCostTbl[] = {
625         {Intrinsic::sadd_with_overflow, MVT::i8, 3},
626         {Intrinsic::uadd_with_overflow, MVT::i8, 3},
627         {Intrinsic::sadd_with_overflow, MVT::i16, 3},
628         {Intrinsic::uadd_with_overflow, MVT::i16, 3},
629         {Intrinsic::sadd_with_overflow, MVT::i32, 1},
630         {Intrinsic::uadd_with_overflow, MVT::i32, 1},
631         {Intrinsic::sadd_with_overflow, MVT::i64, 1},
632         {Intrinsic::uadd_with_overflow, MVT::i64, 1},
633         {Intrinsic::ssub_with_overflow, MVT::i8, 3},
634         {Intrinsic::usub_with_overflow, MVT::i8, 3},
635         {Intrinsic::ssub_with_overflow, MVT::i16, 3},
636         {Intrinsic::usub_with_overflow, MVT::i16, 3},
637         {Intrinsic::ssub_with_overflow, MVT::i32, 1},
638         {Intrinsic::usub_with_overflow, MVT::i32, 1},
639         {Intrinsic::ssub_with_overflow, MVT::i64, 1},
640         {Intrinsic::usub_with_overflow, MVT::i64, 1},
641         {Intrinsic::smul_with_overflow, MVT::i8, 5},
642         {Intrinsic::umul_with_overflow, MVT::i8, 4},
643         {Intrinsic::smul_with_overflow, MVT::i16, 5},
644         {Intrinsic::umul_with_overflow, MVT::i16, 4},
645         {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
646         {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
647         {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
648         {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
649     };
650     EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
651     if (MTy.isSimple())
652       if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
653                                               MTy.getSimpleVT()))
654         return Entry->Cost;
655     break;
656   }
657   case Intrinsic::fptosi_sat:
658   case Intrinsic::fptoui_sat: {
659     if (ICA.getArgTypes().empty())
660       break;
661     bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
662     auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
663     EVT MTy = TLI->getValueType(DL, RetTy);
664     // Check for the legal types, which are where the size of the input and the
665     // output are the same, or we are using cvt f64->i32 or f32->i64.
666     if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
667          LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
668          LT.second == MVT::v2f64) &&
669         (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
670          (LT.second == MVT::f64 && MTy == MVT::i32) ||
671          (LT.second == MVT::f32 && MTy == MVT::i64)))
672       return LT.first;
673     // Similarly for fp16 sizes
674     if (ST->hasFullFP16() &&
675         ((LT.second == MVT::f16 && MTy == MVT::i32) ||
676          ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
677           (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
678       return LT.first;
679 
680     // Otherwise we use a legal convert followed by a min+max
681     if ((LT.second.getScalarType() == MVT::f32 ||
682          LT.second.getScalarType() == MVT::f64 ||
683          (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
684         LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
685       Type *LegalTy =
686           Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
687       if (LT.second.isVector())
688         LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
689       InstructionCost Cost = 1;
690       IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
691                                     LegalTy, {LegalTy, LegalTy});
692       Cost += getIntrinsicInstrCost(Attrs1, CostKind);
693       IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
694                                     LegalTy, {LegalTy, LegalTy});
695       Cost += getIntrinsicInstrCost(Attrs2, CostKind);
696       return LT.first * Cost;
697     }
698     break;
699   }
700   case Intrinsic::fshl:
701   case Intrinsic::fshr: {
702     if (ICA.getArgs().empty())
703       break;
704 
705     // TODO: Add handling for fshl where third argument is not a constant.
706     const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
707     if (!OpInfoZ.isConstant())
708       break;
709 
710     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
711     if (OpInfoZ.isUniform()) {
712       // FIXME: The costs could be lower if the codegen is better.
713       static const CostTblEntry FshlTbl[] = {
714           {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
715           {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
716           {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
717           {Intrinsic::fshl, MVT::v8i8, 4},  {Intrinsic::fshl, MVT::v4i16, 4}};
718       // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
719       // to avoid having to duplicate the costs.
720       const auto *Entry =
721           CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
722       if (Entry)
723         return LegalisationCost.first * Entry->Cost;
724     }
725 
726     auto TyL = getTypeLegalizationCost(RetTy);
727     if (!RetTy->isIntegerTy())
728       break;
729 
730     // Estimate cost manually, as types like i8 and i16 will get promoted to
731     // i32 and CostTableLookup will ignore the extra conversion cost.
732     bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
733                        RetTy->getScalarSizeInBits() < 64) ||
734                       (RetTy->getScalarSizeInBits() % 64 != 0);
735     unsigned ExtraCost = HigherCost ? 1 : 0;
736     if (RetTy->getScalarSizeInBits() == 32 ||
737         RetTy->getScalarSizeInBits() == 64)
738       ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
739                      // extr instruction.
740     else if (HigherCost)
741       ExtraCost = 1;
742     else
743       break;
744     return TyL.first + ExtraCost;
745   }
746   default:
747     break;
748   }
749   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
750 }
751 
752 /// The function will remove redundant reinterpret casts in the presence
753 /// of control flow.
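/// For example (illustrative), when every incoming value of the phi is a
/// convert.to.svbool of a value that already has the type required by this
/// convert.from.svbool, the phi is rebuilt directly on that narrower type and
/// the redundant conversions can then be cleaned up.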
754 static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
755                                                    IntrinsicInst &II) {
756   SmallVector<Instruction *, 32> Worklist;
757   auto RequiredType = II.getType();
758 
759   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
760   assert(PN && "Expected Phi Node!");
761 
762   // Don't create a new Phi unless we can remove the old one.
763   if (!PN->hasOneUse())
764     return std::nullopt;
765 
766   for (Value *IncValPhi : PN->incoming_values()) {
767     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
768     if (!Reinterpret ||
769         Reinterpret->getIntrinsicID() !=
770             Intrinsic::aarch64_sve_convert_to_svbool ||
771         RequiredType != Reinterpret->getArgOperand(0)->getType())
772       return std::nullopt;
773   }
774 
775   // Create the new Phi
776   IC.Builder.SetInsertPoint(PN);
777   PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
778   Worklist.push_back(PN);
779 
780   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
781     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
782     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
783     Worklist.push_back(Reinterpret);
784   }
785 
786   // Cleanup Phi Node and reinterprets
787   return IC.replaceInstUsesWith(II, NPN);
788 }
789 
790 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
791 // => (binop (pred) (from_svbool _) (from_svbool _))
792 //
793 // The above transformation eliminates a `to_svbool` in the predicate
794 // operand of bitwise operation `binop` by narrowing the vector width of
795 // the operation. For example, it would convert a `<vscale x 16 x i1>
796 // and` into a `<vscale x 4 x i1> and`. This is profitable because
797 // to_svbool must zero the new lanes during widening, whereas
798 // from_svbool is free.
799 static std::optional<Instruction *>
800 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
801   auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
802   if (!BinOp)
803     return std::nullopt;
804 
805   auto IntrinsicID = BinOp->getIntrinsicID();
806   switch (IntrinsicID) {
807   case Intrinsic::aarch64_sve_and_z:
808   case Intrinsic::aarch64_sve_bic_z:
809   case Intrinsic::aarch64_sve_eor_z:
810   case Intrinsic::aarch64_sve_nand_z:
811   case Intrinsic::aarch64_sve_nor_z:
812   case Intrinsic::aarch64_sve_orn_z:
813   case Intrinsic::aarch64_sve_orr_z:
814     break;
815   default:
816     return std::nullopt;
817   }
818 
819   auto BinOpPred = BinOp->getOperand(0);
820   auto BinOpOp1 = BinOp->getOperand(1);
821   auto BinOpOp2 = BinOp->getOperand(2);
822 
823   auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
824   if (!PredIntr ||
825       PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
826     return std::nullopt;
827 
828   auto PredOp = PredIntr->getOperand(0);
829   auto PredOpTy = cast<VectorType>(PredOp->getType());
830   if (PredOpTy != II.getType())
831     return std::nullopt;
832 
833   SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
834   auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
835       Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
836   NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
837   if (BinOpOp1 == BinOpOp2)
838     NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
839   else
840     NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
841         Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
842 
843   auto NarrowedBinOp =
844       IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
845   return IC.replaceInstUsesWith(II, NarrowedBinOp);
846 }
847 
848 static std::optional<Instruction *>
849 instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
850   // If the reinterpret instruction operand is a PHI Node
851   if (isa<PHINode>(II.getArgOperand(0)))
852     return processPhiNode(IC, II);
853 
854   if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
855     return BinOpCombine;
856 
857   // Ignore converts to/from svcount_t.
858   if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
859       isa<TargetExtType>(II.getType()))
860     return std::nullopt;
861 
862   SmallVector<Instruction *, 32> CandidatesForRemoval;
863   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
864 
865   const auto *IVTy = cast<VectorType>(II.getType());
866 
867   // Walk the chain of conversions.
868   while (Cursor) {
869     // If the type of the cursor has fewer lanes than the final result, zeroing
870     // must take place, which breaks the equivalence chain.
871     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
872     if (CursorVTy->getElementCount().getKnownMinValue() <
873         IVTy->getElementCount().getKnownMinValue())
874       break;
875 
876     // If the cursor has the same type as II, it is a viable replacement.
877     if (Cursor->getType() == IVTy)
878       EarliestReplacement = Cursor;
879 
880     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
881 
882     // If this is not an SVE conversion intrinsic, this is the end of the chain.
883     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
884                                   Intrinsic::aarch64_sve_convert_to_svbool ||
885                               IntrinsicCursor->getIntrinsicID() ==
886                                   Intrinsic::aarch64_sve_convert_from_svbool))
887       break;
888 
889     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
890     Cursor = IntrinsicCursor->getOperand(0);
891   }
892 
893   // If no viable replacement in the conversion chain was found, there is
894   // nothing to do.
895   if (!EarliestReplacement)
896     return std::nullopt;
897 
898   return IC.replaceInstUsesWith(II, EarliestReplacement);
899 }
900 
901 static bool isAllActivePredicate(Value *Pred) {
902   // Look through a convert.from.svbool(convert.to.svbool(...)) chain.
903   Value *UncastedPred;
904   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
905                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
906                           m_Value(UncastedPred)))))
907     // If the predicate has the same or fewer lanes than the uncasted
908     // predicate then we know the casting has no effect.
909     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
910         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
911       Pred = UncastedPred;
912 
913   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
914                          m_ConstantInt<AArch64SVEPredPattern::all>()));
915 }
916 
917 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
918                                                       IntrinsicInst &II) {
919   // svsel(ptrue, x, y) => x
920   auto *OpPredicate = II.getOperand(0);
921   if (isAllActivePredicate(OpPredicate))
922     return IC.replaceInstUsesWith(II, II.getOperand(1));
923 
924   auto Select =
925       IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
926   return IC.replaceInstUsesWith(II, Select);
927 }
928 
929 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
930                                                       IntrinsicInst &II) {
931   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
932   if (!Pg)
933     return std::nullopt;
934 
935   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
936     return std::nullopt;
937 
938   const auto PTruePattern =
939       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
940   if (PTruePattern != AArch64SVEPredPattern::vl1)
941     return std::nullopt;
942 
943   // The intrinsic is inserting into lane zero so use an insert instead.
944   auto *IdxTy = Type::getInt64Ty(II.getContext());
945   auto *Insert = InsertElementInst::Create(
946       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
947   Insert->insertBefore(&II);
948   Insert->takeName(&II);
949 
950   return IC.replaceInstUsesWith(II, Insert);
951 }
952 
953 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
954                                                        IntrinsicInst &II) {
955   // Replace DupX with a regular IR splat.
956   auto *RetTy = cast<ScalableVectorType>(II.getType());
957   Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
958                                               II.getArgOperand(0));
959   Splat->takeName(&II);
960   return IC.replaceInstUsesWith(II, Splat);
961 }
962 
963 static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
964                                                         IntrinsicInst &II) {
965   LLVMContext &Ctx = II.getContext();
966 
967   // Check that the predicate is all active
968   auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
969   if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
970     return std::nullopt;
971 
972   const auto PTruePattern =
973       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
974   if (PTruePattern != AArch64SVEPredPattern::all)
975     return std::nullopt;
976 
977   // Check that we have a compare of zero..
978   auto *SplatValue =
979       dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
980   if (!SplatValue || !SplatValue->isZero())
981     return std::nullopt;
982 
983   // ..against a dupq
984   auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
985   if (!DupQLane ||
986       DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
987     return std::nullopt;
988 
989   // Where the dupq is a lane 0 replicate of a vector insert
990   if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
991     return std::nullopt;
992 
993   auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
994   if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
995     return std::nullopt;
996 
997   // Where the vector insert is a fixed constant vector insert into undef at
998   // index zero
999   if (!isa<UndefValue>(VecIns->getArgOperand(0)))
1000     return std::nullopt;
1001 
1002   if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
1003     return std::nullopt;
1004 
1005   auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
1006   if (!ConstVec)
1007     return std::nullopt;
1008 
1009   auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
1010   auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
1011   if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
1012     return std::nullopt;
1013 
1014   unsigned NumElts = VecTy->getNumElements();
1015   unsigned PredicateBits = 0;
1016 
1017   // Expand intrinsic operands to a 16-bit byte level predicate
1018   for (unsigned I = 0; I < NumElts; ++I) {
1019     auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
1020     if (!Arg)
1021       return std::nullopt;
1022     if (!Arg->isZero())
1023       PredicateBits |= 1 << (I * (16 / NumElts));
1024   }
1025 
1026   // If all bits are zero bail early with an empty predicate
1027   if (PredicateBits == 0) {
1028     auto *PFalse = Constant::getNullValue(II.getType());
1029     PFalse->takeName(&II);
1030     return IC.replaceInstUsesWith(II, PFalse);
1031   }
1032 
1033   // Calculate largest predicate type used (where byte predicate is largest)
1034   unsigned Mask = 8;
1035   for (unsigned I = 0; I < 16; ++I)
1036     if ((PredicateBits & (1 << I)) != 0)
1037       Mask |= (I % 8);
1038 
1039   unsigned PredSize = Mask & -Mask;
1040   auto *PredType = ScalableVectorType::get(
1041       Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
1042 
1043   // Ensure all relevant bits are set
1044   for (unsigned I = 0; I < 16; I += PredSize)
1045     if ((PredicateBits & (1 << I)) == 0)
1046       return std::nullopt;
1047 
1048   auto *PTruePat =
1049       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1050   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1051                                            {PredType}, {PTruePat});
1052   auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
1053       Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
1054   auto *ConvertFromSVBool =
1055       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
1056                                  {II.getType()}, {ConvertToSVBool});
1057 
1058   ConvertFromSVBool->takeName(&II);
1059   return IC.replaceInstUsesWith(II, ConvertFromSVBool);
1060 }
1061 
1062 static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
1063                                                        IntrinsicInst &II) {
1064   Value *Pg = II.getArgOperand(0);
1065   Value *Vec = II.getArgOperand(1);
1066   auto IntrinsicID = II.getIntrinsicID();
1067   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
1068 
1069   // lastX(splat(X)) --> X
1070   if (auto *SplatVal = getSplatValue(Vec))
1071     return IC.replaceInstUsesWith(II, SplatVal);
1072 
1073   // If x and/or y is a splat value then:
1074   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
1075   Value *LHS, *RHS;
1076   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
1077     if (isSplatValue(LHS) || isSplatValue(RHS)) {
1078       auto *OldBinOp = cast<BinaryOperator>(Vec);
1079       auto OpC = OldBinOp->getOpcode();
1080       auto *NewLHS =
1081           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
1082       auto *NewRHS =
1083           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
1084       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
1085           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
1086       return IC.replaceInstUsesWith(II, NewBinOp);
1087     }
1088   }
1089 
1090   auto *C = dyn_cast<Constant>(Pg);
1091   if (IsAfter && C && C->isNullValue()) {
1092     // The intrinsic is extracting lane 0 so use an extract instead.
1093     auto *IdxTy = Type::getInt64Ty(II.getContext());
1094     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
1095     Extract->insertBefore(&II);
1096     Extract->takeName(&II);
1097     return IC.replaceInstUsesWith(II, Extract);
1098   }
1099 
1100   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
1101   if (!IntrPG)
1102     return std::nullopt;
1103 
1104   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1105     return std::nullopt;
1106 
1107   const auto PTruePattern =
1108       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
1109 
1110   // Can the intrinsic's predicate be converted to a known constant index?
1111   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
1112   if (!MinNumElts)
1113     return std::nullopt;
1114 
1115   unsigned Idx = MinNumElts - 1;
1116   // Increment the index if extracting the element after the last active
1117   // predicate element.
1118   if (IsAfter)
1119     ++Idx;
1120 
1121   // Ignore extracts whose index is larger than the known minimum vector
1122   // length. NOTE: This is an artificial constraint where we prefer to
1123   // maintain what the user asked for until an alternative is proven faster.
1124   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
1125   if (Idx >= PgVTy->getMinNumElements())
1126     return std::nullopt;
1127 
1128   // The intrinsic is extracting a fixed lane so use an extract instead.
1129   auto *IdxTy = Type::getInt64Ty(II.getContext());
1130   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
1131   Extract->insertBefore(&II);
1132   Extract->takeName(&II);
1133   return IC.replaceInstUsesWith(II, Extract);
1134 }
1135 
1136 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1137                                                            IntrinsicInst &II) {
1138   // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1139   // integer variant across a variety of micro-architectures. Replace scalar
1140   // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1141   // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1142   // depending on the micro-architecture, but has been observed as generally
1143   // being faster, particularly when the CLAST[AB] op is a loop-carried
1144   // dependency.
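  // For example (illustrative), an i32 clastb is rewritten as: bitcast the
  // fallback and the vector to f32, call the f32 clastb, then bitcast the
  // scalar result back to i32.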
1145   Value *Pg = II.getArgOperand(0);
1146   Value *Fallback = II.getArgOperand(1);
1147   Value *Vec = II.getArgOperand(2);
1148   Type *Ty = II.getType();
1149 
1150   if (!Ty->isIntegerTy())
1151     return std::nullopt;
1152 
1153   Type *FPTy;
1154   switch (cast<IntegerType>(Ty)->getBitWidth()) {
1155   default:
1156     return std::nullopt;
1157   case 16:
1158     FPTy = IC.Builder.getHalfTy();
1159     break;
1160   case 32:
1161     FPTy = IC.Builder.getFloatTy();
1162     break;
1163   case 64:
1164     FPTy = IC.Builder.getDoubleTy();
1165     break;
1166   }
1167 
1168   Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
1169   auto *FPVTy = VectorType::get(
1170       FPTy, cast<VectorType>(Vec->getType())->getElementCount());
1171   Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
1172   auto *FPII = IC.Builder.CreateIntrinsic(
1173       II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
1174   Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
1175   return IC.replaceInstUsesWith(II, FPIItoInt);
1176 }
1177 
1178 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
1179                                                      IntrinsicInst &II) {
1180   LLVMContext &Ctx = II.getContext();
1181   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
1182   // can work with RDFFR_PP for ptest elimination.
1183   auto *AllPat =
1184       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1185   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1186                                            {II.getType()}, {AllPat});
1187   auto *RDFFR =
1188       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
1189   RDFFR->takeName(&II);
1190   return IC.replaceInstUsesWith(II, RDFFR);
1191 }
1192 
1193 static std::optional<Instruction *>
1194 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
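  // For example (illustrative), with NumElts == 4 (e.g. cntw): the "all"
  // pattern becomes "vscale * 4", and the "vl2" pattern becomes the constant
  // 2 because at least 4 elements are always available.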
1195   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
1196 
1197   if (Pattern == AArch64SVEPredPattern::all) {
1198     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
1199     auto *VScale = IC.Builder.CreateVScale(StepVal);
1200     VScale->takeName(&II);
1201     return IC.replaceInstUsesWith(II, VScale);
1202   }
1203 
1204   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
1205 
1206   return MinNumElts && NumElts >= MinNumElts
1207              ? std::optional<Instruction *>(IC.replaceInstUsesWith(
1208                    II, ConstantInt::get(II.getType(), MinNumElts)))
1209              : std::nullopt;
1210 }
1211 
1212 static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
1213                                                         IntrinsicInst &II) {
1214   Value *PgVal = II.getArgOperand(0);
1215   Value *OpVal = II.getArgOperand(1);
1216 
1217   // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
1218   // Later optimizations prefer this form.
1219   if (PgVal == OpVal &&
1220       (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
1221        II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
1222     Value *Ops[] = {PgVal, OpVal};
1223     Type *Tys[] = {PgVal->getType()};
1224 
1225     auto *PTest =
1226         IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
1227     PTest->takeName(&II);
1228 
1229     return IC.replaceInstUsesWith(II, PTest);
1230   }
1231 
1232   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
1233   IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
1234 
1235   if (!Pg || !Op)
1236     return std::nullopt;
1237 
1238   Intrinsic::ID OpIID = Op->getIntrinsicID();
1239 
1240   if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
1241       OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
1242       Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
1243     Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
1244     Type *Tys[] = {Pg->getArgOperand(0)->getType()};
1245 
1246     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1247 
1248     PTest->takeName(&II);
1249     return IC.replaceInstUsesWith(II, PTest);
1250   }
1251 
1252   // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X).
1253   // Later optimizations may rewrite sequence to use the flag-setting variant
1254   // of instruction X to remove PTEST.
1255   if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
1256       ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
1257        (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
1258        (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
1259        (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
1260        (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
1261        (OpIID == Intrinsic::aarch64_sve_and_z) ||
1262        (OpIID == Intrinsic::aarch64_sve_bic_z) ||
1263        (OpIID == Intrinsic::aarch64_sve_eor_z) ||
1264        (OpIID == Intrinsic::aarch64_sve_nand_z) ||
1265        (OpIID == Intrinsic::aarch64_sve_nor_z) ||
1266        (OpIID == Intrinsic::aarch64_sve_orn_z) ||
1267        (OpIID == Intrinsic::aarch64_sve_orr_z))) {
1268     Value *Ops[] = {Pg->getArgOperand(0), Pg};
1269     Type *Tys[] = {Pg->getType()};
1270 
1271     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1272     PTest->takeName(&II);
1273 
1274     return IC.replaceInstUsesWith(II, PTest);
1275   }
1276 
1277   return std::nullopt;
1278 }
1279 
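// For example (illustrative): with MulOpc = fmul, FuseOpc = fmla and
// MergeIntoAddendOp = true, fadd(p, a, fmul(p, b, c)) is fused into
// fmla(p, a, b, c) when the fmul has a single use and the fast-math flags
// on both calls match and allow contraction.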
1280 template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
1281 static std::optional<Instruction *>
1282 instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
1283                                   bool MergeIntoAddendOp) {
1284   Value *P = II.getOperand(0);
1285   Value *MulOp0, *MulOp1, *AddendOp, *Mul;
1286   if (MergeIntoAddendOp) {
1287     AddendOp = II.getOperand(1);
1288     Mul = II.getOperand(2);
1289   } else {
1290     AddendOp = II.getOperand(2);
1291     Mul = II.getOperand(1);
1292   }
1293 
1294   if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
1295                                       m_Value(MulOp1))))
1296     return std::nullopt;
1297 
1298   if (!Mul->hasOneUse())
1299     return std::nullopt;
1300 
1301   Instruction *FMFSource = nullptr;
1302   if (II.getType()->isFPOrFPVectorTy()) {
1303     llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
1304     // Stop the combine when the flags on the inputs differ in case dropping
1305     // flags would lead to us missing out on more beneficial optimizations.
1306     if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
1307       return std::nullopt;
1308     if (!FAddFlags.allowContract())
1309       return std::nullopt;
1310     FMFSource = &II;
1311   }
1312 
1313   CallInst *Res;
1314   if (MergeIntoAddendOp)
1315     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1316                                      {P, AddendOp, MulOp0, MulOp1}, FMFSource);
1317   else
1318     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1319                                      {P, MulOp0, MulOp1, AddendOp}, FMFSource);
1320 
1321   return IC.replaceInstUsesWith(II, Res);
1322 }
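     // Illustrative sketch of the fusion above (hypothetical names; intrinsic
     // type suffixes omitted). When instantiated by the FAdd combine below,
     // with matching contract-enabled fast-math flags and a single-use
     // multiply,
     //   %m = call @llvm.aarch64.sve.fmul(%pg, %b, %c)
     //   %r = call @llvm.aarch64.sve.fadd(%pg, %a, %m)
     // becomes
     //   %r = call @llvm.aarch64.sve.fmla(%pg, %a, %b, %c)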
1323 
1324 static std::optional<Instruction *>
1325 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1326   Value *Pred = II.getOperand(0);
1327   Value *PtrOp = II.getOperand(1);
1328   Type *VecTy = II.getType();
1329 
1330   if (isAllActivePredicate(Pred)) {
1331     LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
1332     Load->copyMetadata(II);
1333     return IC.replaceInstUsesWith(II, Load);
1334   }
1335 
1336   CallInst *MaskedLoad =
1337       IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
1338                                   Pred, ConstantAggregateZero::get(VecTy));
1339   MaskedLoad->copyMetadata(II);
1340   return IC.replaceInstUsesWith(II, MaskedLoad);
1341 }
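     // Illustrative sketch (hypothetical names; type suffixes omitted). With
     // an all-active predicate,
     //   %v = call @llvm.aarch64.sve.ld1(%ptrue_all, ptr %p)
     // becomes a plain `load` from %p; otherwise it becomes a
     // @llvm.masked.load with a zeroinitializer passthru. The st1 combine
     // below handles the symmetric store case.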
1342 
1343 static std::optional<Instruction *>
1344 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1345   Value *VecOp = II.getOperand(0);
1346   Value *Pred = II.getOperand(1);
1347   Value *PtrOp = II.getOperand(2);
1348 
1349   if (isAllActivePredicate(Pred)) {
1350     StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
1351     Store->copyMetadata(II);
1352     return IC.eraseInstFromFunction(II);
1353   }
1354 
1355   CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
1356       VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
1357   MaskedStore->copyMetadata(II);
1358   return IC.eraseInstFromFunction(II);
1359 }
1360 
1361 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1362   switch (Intrinsic) {
1363   case Intrinsic::aarch64_sve_fmul_u:
1364     return Instruction::BinaryOps::FMul;
1365   case Intrinsic::aarch64_sve_fadd_u:
1366     return Instruction::BinaryOps::FAdd;
1367   case Intrinsic::aarch64_sve_fsub_u:
1368     return Instruction::BinaryOps::FSub;
1369   default:
1370     return Instruction::BinaryOpsEnd;
1371   }
1372 }
1373 
1374 static std::optional<Instruction *>
1375 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
1376   // Bail due to missing support for ISD::STRICT_ scalable vector operations.
1377   if (II.isStrictFP())
1378     return std::nullopt;
1379 
1380   auto *OpPredicate = II.getOperand(0);
1381   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1382   if (BinOpCode == Instruction::BinaryOpsEnd ||
1383       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1384                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1385     return std::nullopt;
1386   IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
1387   IC.Builder.setFastMathFlags(II.getFastMathFlags());
1388   auto BinOp =
1389       IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1390   return IC.replaceInstUsesWith(II, BinOp);
1391 }
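     // Illustrative sketch (hypothetical names; type suffixes omitted). For a
     // non-strict-FP call whose predicate is ptrue(all),
     //   %r = call @llvm.aarch64.sve.fmul.u(%ptrue_all, %a, %b)
     // becomes the plain IR instruction `%r = fmul %a, %b`, carrying the
     // call's fast-math flags.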
1392 
1393 // Canonicalise operations that take an all active predicate (e.g. sve.add ->
1394 // sve.add_u).
1395 static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
1396                                                             Intrinsic::ID IID) {
1397   auto *OpPredicate = II.getOperand(0);
1398   if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1399                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1400     return std::nullopt;
1401 
1402   auto *Mod = II.getModule();
1403   auto *NewDecl = Intrinsic::getDeclaration(Mod, IID, {II.getType()});
1404   II.setCalledFunction(NewDecl);
1405 
1406   return &II;
1407 }
1408 
1409 // Simplify operations where the predicate has all inactive lanes, or try to
1410 // replace them with the _u form when all lanes are active.
1411 static std::optional<Instruction *>
1412 instCombineSVEAllOrNoActive(InstCombiner &IC, IntrinsicInst &II,
1413                             Intrinsic::ID IID) {
1414   if (match(II.getOperand(0), m_ZeroInt())) {
1415     // The LLVM IR call has the form op(pred, op1, op2). The spec says to
1416     // return op1 when all lanes are inactive for the merging sv[func]_m forms.
1417     return IC.replaceInstUsesWith(II, II.getOperand(1));
1418   }
1419   return instCombineSVEAllActive(II, IID);
1420 }
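     // Illustrative sketch: for a merging intrinsic such as sve.fsub, a
     // ptrue(all) predicate retargets the call to its _u variant, while an
     // all-zero predicate folds the call to its first data operand (op1), as
     // described above.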
1421 
1422 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
1423                                                             IntrinsicInst &II) {
1424   if (auto II_U =
1425           instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_add_u))
1426     return II_U;
1427   if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1428                                                    Intrinsic::aarch64_sve_mla>(
1429           IC, II, true))
1430     return MLA;
1431   if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1432                                                    Intrinsic::aarch64_sve_mad>(
1433           IC, II, false))
1434     return MAD;
1435   return std::nullopt;
1436 }
1437 
1438 static std::optional<Instruction *>
1439 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
1440   if (auto II_U =
1441           instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fadd_u))
1442     return II_U;
1443   if (auto FMLA =
1444           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1445                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1446                                                                          true))
1447     return FMLA;
1448   if (auto FMAD =
1449           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1450                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1451                                                                          false))
1452     return FMAD;
1453   if (auto FMLA =
1454           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1455                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1456                                                                          true))
1457     return FMLA;
1458   return std::nullopt;
1459 }
1460 
1461 static std::optional<Instruction *>
1462 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
1463   if (auto FMLA =
1464           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1465                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1466                                                                          true))
1467     return FMLA;
1468   if (auto FMAD =
1469           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1470                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1471                                                                          false))
1472     return FMAD;
1473   if (auto FMLA_U =
1474           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1475                                             Intrinsic::aarch64_sve_fmla_u>(
1476               IC, II, true))
1477     return FMLA_U;
1478   return instCombineSVEVectorBinOp(IC, II);
1479 }
1480 
1481 static std::optional<Instruction *>
1482 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
1483   if (auto II_U =
1484           instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fsub_u))
1485     return II_U;
1486   if (auto FMLS =
1487           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1488                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1489                                                                          true))
1490     return FMLS;
1491   if (auto FMSB =
1492           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1493                                             Intrinsic::aarch64_sve_fnmsb>(
1494               IC, II, false))
1495     return FMSB;
1496   if (auto FMLS =
1497           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1498                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1499                                                                          true))
1500     return FMLS;
1501   return std::nullopt;
1502 }
1503 
1504 static std::optional<Instruction *>
1505 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
1506   if (auto FMLS =
1507           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1508                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1509                                                                          true))
1510     return FMLS;
1511   if (auto FMSB =
1512           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1513                                             Intrinsic::aarch64_sve_fnmsb>(
1514               IC, II, false))
1515     return FMSB;
1516   if (auto FMLS_U =
1517           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1518                                             Intrinsic::aarch64_sve_fmls_u>(
1519               IC, II, true))
1520     return FMLS_U;
1521   return instCombineSVEVectorBinOp(IC, II);
1522 }
1523 
1524 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
1525                                                             IntrinsicInst &II) {
1526   if (auto II_U =
1527           instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sub_u))
1528     return II_U;
1529   if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1530                                                    Intrinsic::aarch64_sve_mls>(
1531           IC, II, true))
1532     return MLS;
1533   return std::nullopt;
1534 }
1535 
1536 static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
1537                                                             IntrinsicInst &II,
1538                                                             Intrinsic::ID IID) {
1539   auto *OpPredicate = II.getOperand(0);
1540   auto *OpMultiplicand = II.getOperand(1);
1541   auto *OpMultiplier = II.getOperand(2);
1542 
1543   // Return true if a given value is a unit splat (1 or 1.0), false otherwise.
1544   auto IsUnitSplat = [](auto *I) {
1545     auto *SplatValue = getSplatValue(I);
1546     if (!SplatValue)
1547       return false;
1548     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1549   };
1550 
1551   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
1552   // with a unit splat value, false otherwise.
1553   auto IsUnitDup = [](auto *I) {
1554     auto *IntrI = dyn_cast<IntrinsicInst>(I);
1555     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
1556       return false;
1557 
1558     auto *SplatValue = IntrI->getOperand(2);
1559     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1560   };
1561 
1562   if (IsUnitSplat(OpMultiplier)) {
1563     // [f]mul pg %n, (dupx 1) => %n
1564     OpMultiplicand->takeName(&II);
1565     return IC.replaceInstUsesWith(II, OpMultiplicand);
1566   } else if (IsUnitDup(OpMultiplier)) {
1567     // [f]mul pg %n, (dup pg 1) => %n
1568     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
1569     auto *DupPg = DupInst->getOperand(1);
1570     // TODO: this is naive. The optimization is still valid if DupPg
1571     // 'encompasses' OpPredicate, not only if they're the same predicate.
1572     if (OpPredicate == DupPg) {
1573       OpMultiplicand->takeName(&II);
1574       return IC.replaceInstUsesWith(II, OpMultiplicand);
1575     }
1576   }
1577 
1578   return instCombineSVEVectorBinOp(IC, II);
1579 }
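     // Illustrative sketch of the unit-splat folds above (hypothetical names):
     //   %one = call @llvm.aarch64.sve.dup.x(float 1.0)
     //   %r   = call @llvm.aarch64.sve.fmul(%pg, %n, %one)
     // folds to %n; multiplies that do not match either unit-splat pattern
     // fall through to instCombineSVEVectorBinOp above.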
1580 
1581 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1582                                                          IntrinsicInst &II) {
1583   Value *UnpackArg = II.getArgOperand(0);
1584   auto *RetTy = cast<ScalableVectorType>(II.getType());
1585   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1586                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1587 
1588   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1589   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1590   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1591     ScalarArg =
1592         IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1593     Value *NewVal =
1594         IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1595     NewVal->takeName(&II);
1596     return IC.replaceInstUsesWith(II, NewVal);
1597   }
1598 
1599   return std::nullopt;
1600 }
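     // Illustrative sketch of the unpack combine above (hypothetical names):
     // if %s is a splat of i16 %x, then
     //   %lo = call <vscale x 4 x i32> @llvm.aarch64.sve.uunpklo(<vscale x 8 x i16> %s)
     // becomes a splat of (zext i16 %x to i32); the signed unpacks use sext.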
1601 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
1602                                                       IntrinsicInst &II) {
1603   auto *OpVal = II.getOperand(0);
1604   auto *OpIndices = II.getOperand(1);
1605   VectorType *VTy = cast<VectorType>(II.getType());
1606 
1607   // Check whether OpIndices is a constant splat value smaller than the minimum
1608   // element count of the result.
1609   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
1610   if (!SplatValue ||
1611       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1612     return std::nullopt;
1613 
1614   // Convert sve_tbl(OpVal, sve_dup_x(SplatValue)) to
1615   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
1616   auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
1617   auto *VectorSplat =
1618       IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
1619 
1620   VectorSplat->takeName(&II);
1621   return IC.replaceInstUsesWith(II, VectorSplat);
1622 }
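     // Illustrative sketch of the tbl combine above: if the index vector is a
     // constant splat of 2 (and 2 is below the minimum element count of the
     // result), sve.tbl(%v, splat(2)) becomes a splat of
     // extractelement(%v, 2).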
1623 
1624 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1625                                                       IntrinsicInst &II) {
1626   // zip1(uzp1(A, B), uzp2(A, B)) --> A
1627   // zip2(uzp1(A, B), uzp2(A, B)) --> B
1628   Value *A, *B;
1629   if (match(II.getArgOperand(0),
1630             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1631       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1632                                      m_Specific(A), m_Specific(B))))
1633     return IC.replaceInstUsesWith(
1634         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1635 
1636   return std::nullopt;
1637 }
1638 
1639 static std::optional<Instruction *>
1640 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
1641   Value *Mask = II.getOperand(0);
1642   Value *BasePtr = II.getOperand(1);
1643   Value *Index = II.getOperand(2);
1644   Type *Ty = II.getType();
1645   Value *PassThru = ConstantAggregateZero::get(Ty);
1646 
1647   // Contiguous gather => masked load.
1648   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1649   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1650   Value *IndexBase;
1651   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1652                        m_Value(IndexBase), m_SpecificInt(1)))) {
1653     Align Alignment =
1654         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1655 
1656     Type *VecPtrTy = PointerType::getUnqual(Ty);
1657     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1658                                       BasePtr, IndexBase);
1659     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1660     CallInst *MaskedLoad =
1661         IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1662     MaskedLoad->takeName(&II);
1663     return IC.replaceInstUsesWith(II, MaskedLoad);
1664   }
1665 
1666   return std::nullopt;
1667 }
1668 
1669 static std::optional<Instruction *>
1670 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
1671   Value *Val = II.getOperand(0);
1672   Value *Mask = II.getOperand(1);
1673   Value *BasePtr = II.getOperand(2);
1674   Value *Index = II.getOperand(3);
1675   Type *Ty = Val->getType();
1676 
1677   // Contiguous scatter => masked store.
1678   // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1679   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1680   Value *IndexBase;
1681   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1682                        m_Value(IndexBase), m_SpecificInt(1)))) {
1683     Align Alignment =
1684         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1685 
1686     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1687                                       BasePtr, IndexBase);
1688     Type *VecPtrTy = PointerType::getUnqual(Ty);
1689     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1690 
1691     (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1692 
1693     return IC.eraseInstFromFunction(II);
1694   }
1695 
1696   return std::nullopt;
1697 }
1698 
1699 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1700                                                        IntrinsicInst &II) {
1701   Type *Int32Ty = IC.Builder.getInt32Ty();
1702   Value *Pred = II.getOperand(0);
1703   Value *Vec = II.getOperand(1);
1704   Value *DivVec = II.getOperand(2);
1705 
1706   Value *SplatValue = getSplatValue(DivVec);
1707   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1708   if (!SplatConstantInt)
1709     return std::nullopt;
1710   APInt Divisor = SplatConstantInt->getValue();
1711 
1712   if (Divisor.isPowerOf2()) {
1713     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1714     auto ASRD = IC.Builder.CreateIntrinsic(
1715         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1716     return IC.replaceInstUsesWith(II, ASRD);
1717   }
1718   if (Divisor.isNegatedPowerOf2()) {
1719     Divisor.negate();
1720     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1721     auto ASRD = IC.Builder.CreateIntrinsic(
1722         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1723     auto NEG = IC.Builder.CreateIntrinsic(
1724         Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
1725     return IC.replaceInstUsesWith(II, NEG);
1726   }
1727 
1728   return std::nullopt;
1729 }
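     // Worked example for the sdiv combine above: a divisor splat of 8 (2^3)
     // gives asrd(%pg, %vec, 3); a divisor splat of -8 gives the same asrd
     // followed by a predicated negate of its result.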
1730 
1731 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
1732   size_t VecSize = Vec.size();
1733   if (VecSize == 1)
1734     return true;
1735   if (!isPowerOf2_64(VecSize))
1736     return false;
1737   size_t HalfVecSize = VecSize / 2;
1738 
1739   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
1740        RHS != Vec.end(); LHS++, RHS++) {
1741     if (*LHS != nullptr && *RHS != nullptr) {
1742       if (*LHS == *RHS)
1743         continue;
1744       else
1745         return false;
1746     }
1747     if (!AllowPoison)
1748       return false;
1749     if (*LHS == nullptr && *RHS != nullptr)
1750       *LHS = *RHS;
1751   }
1752 
1753   Vec.resize(HalfVecSize);
1754   SimplifyValuePattern(Vec, AllowPoison);
1755   return true;
1756 }
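     // Worked examples for SimplifyValuePattern: (A, B, A, B) halves to
     // (A, B); with AllowPoison, (null, B, A, B) fills the missing lane from
     // the second half and also halves to (A, B); (A, B, B, A) fails because
     // the two halves differ element-wise.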
1757 
1758 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
1759 // to dupqlane(f64(C)), where C is A concatenated with B.
1760 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
1761                                                            IntrinsicInst &II) {
1762   Value *CurrentInsertElt = nullptr, *Default = nullptr;
1763   if (!match(II.getOperand(0),
1764              m_Intrinsic<Intrinsic::vector_insert>(
1765                  m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
1766       !isa<FixedVectorType>(CurrentInsertElt->getType()))
1767     return std::nullopt;
1768   auto IIScalableTy = cast<ScalableVectorType>(II.getType());
1769 
1770   // Insert the scalars into a container ordered by InsertElement index
1771   SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
1772   while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
1773     auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
1774     Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
1775     CurrentInsertElt = InsertElt->getOperand(0);
1776   }
1777 
1778   bool AllowPoison =
1779       isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
1780   if (!SimplifyValuePattern(Elts, AllowPoison))
1781     return std::nullopt;
1782 
1783   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
1784   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
1785   for (size_t I = 0; I < Elts.size(); I++) {
1786     if (Elts[I] == nullptr)
1787       continue;
1788     InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
1789                                                     IC.Builder.getInt64(I));
1790   }
1791   if (InsertEltChain == nullptr)
1792     return std::nullopt;
1793 
1794   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
1795   // value or (f16 a, f16 b) as one i32 value. This requires the InsertSubvector
1796   // to be bitcast to a type wide enough to fit the sequence, splatted, and then
1797   // narrowed back to the original type.
1798   unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
1799   unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
1800                                  IIScalableTy->getMinNumElements() /
1801                                  PatternWidth;
1802 
1803   IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
1804   auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
1805   auto *WideShuffleMaskTy =
1806       ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);
1807 
1808   auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0));
1809   auto InsertSubvector = IC.Builder.CreateInsertVector(
1810       II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
1811   auto WideBitcast =
1812       IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
1813   auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
1814   auto WideShuffle = IC.Builder.CreateShuffleVector(
1815       WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
1816   auto NarrowBitcast =
1817       IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
1818 
1819   return IC.replaceInstUsesWith(II, NarrowBitcast);
1820 }
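     // Illustrative sketch of the dupqlane combine above: a <vscale x 8 x half>
     // dupq built from the repeating pair (f16 a, f16 b) is rewritten as a
     // splat of the 32-bit value holding (a, b), via a bitcast to
     // <vscale x 4 x i32>, a zero-mask shufflevector splat, and a bitcast back
     // to the original type.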
1821 
1822 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
1823                                                         IntrinsicInst &II) {
1824   Value *A = II.getArgOperand(0);
1825   Value *B = II.getArgOperand(1);
1826   if (A == B)
1827     return IC.replaceInstUsesWith(II, A);
1828 
1829   return std::nullopt;
1830 }
1831 
1832 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
1833                                                         IntrinsicInst &II) {
1834   Value *Pred = II.getOperand(0);
1835   Value *Vec = II.getOperand(1);
1836   Value *Shift = II.getOperand(2);
1837 
1838   // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
1839   Value *AbsPred, *MergedValue;
1840   if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
1841                       m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
1842       !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
1843                       m_Value(MergedValue), m_Value(AbsPred), m_Value())))
1844 
1845     return std::nullopt;
1846 
1847   // Transform is valid if any of the following are true:
1848   // * The ABS merge value is an undef or non-negative
1849   // * The ABS predicate is all active
1850   // * The ABS predicate and the SRSHL predicates are the same
1851   if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
1852       AbsPred != Pred && !isAllActivePredicate(AbsPred))
1853     return std::nullopt;
1854 
1855   // Only valid when the shift amount is non-negative, otherwise the rounding
1856   // behaviour of SRSHL cannot be ignored.
1857   if (!match(Shift, m_NonNegative()))
1858     return std::nullopt;
1859 
1860   auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
1861                                         {II.getType()}, {Pred, Vec, Shift});
1862 
1863   return IC.replaceInstUsesWith(II, LSL);
1864 }
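     // Illustrative sketch of the srshl combine above (hypothetical names).
     // Given a non-negative shift splat,
     //   %abs = call @llvm.aarch64.sve.abs(undef, %pg, %x)
     //   %r   = call @llvm.aarch64.sve.srshl(%pg, %abs, splat(2))
     // becomes
     //   %r   = call @llvm.aarch64.sve.lsl(%pg, %abs, splat(2))
     // since %abs is known non-negative, so SRSHL's rounding is irrelevant.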
1865 
1866 std::optional<Instruction *>
1867 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
1868                                      IntrinsicInst &II) const {
1869   Intrinsic::ID IID = II.getIntrinsicID();
1870   switch (IID) {
1871   default:
1872     break;
1873   case Intrinsic::aarch64_neon_fmaxnm:
1874   case Intrinsic::aarch64_neon_fminnm:
1875     return instCombineMaxMinNM(IC, II);
1876   case Intrinsic::aarch64_sve_convert_from_svbool:
1877     return instCombineConvertFromSVBool(IC, II);
1878   case Intrinsic::aarch64_sve_dup:
1879     return instCombineSVEDup(IC, II);
1880   case Intrinsic::aarch64_sve_dup_x:
1881     return instCombineSVEDupX(IC, II);
1882   case Intrinsic::aarch64_sve_cmpne:
1883   case Intrinsic::aarch64_sve_cmpne_wide:
1884     return instCombineSVECmpNE(IC, II);
1885   case Intrinsic::aarch64_sve_rdffr:
1886     return instCombineRDFFR(IC, II);
1887   case Intrinsic::aarch64_sve_lasta:
1888   case Intrinsic::aarch64_sve_lastb:
1889     return instCombineSVELast(IC, II);
1890   case Intrinsic::aarch64_sve_clasta_n:
1891   case Intrinsic::aarch64_sve_clastb_n:
1892     return instCombineSVECondLast(IC, II);
1893   case Intrinsic::aarch64_sve_cntd:
1894     return instCombineSVECntElts(IC, II, 2);
1895   case Intrinsic::aarch64_sve_cntw:
1896     return instCombineSVECntElts(IC, II, 4);
1897   case Intrinsic::aarch64_sve_cnth:
1898     return instCombineSVECntElts(IC, II, 8);
1899   case Intrinsic::aarch64_sve_cntb:
1900     return instCombineSVECntElts(IC, II, 16);
1901   case Intrinsic::aarch64_sve_ptest_any:
1902   case Intrinsic::aarch64_sve_ptest_first:
1903   case Intrinsic::aarch64_sve_ptest_last:
1904     return instCombineSVEPTest(IC, II);
1905   case Intrinsic::aarch64_sve_fabd:
1906     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fabd_u);
1907   case Intrinsic::aarch64_sve_fadd:
1908     return instCombineSVEVectorFAdd(IC, II);
1909   case Intrinsic::aarch64_sve_fadd_u:
1910     return instCombineSVEVectorFAddU(IC, II);
1911   case Intrinsic::aarch64_sve_fdiv:
1912     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fdiv_u);
1913   case Intrinsic::aarch64_sve_fmax:
1914     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmax_u);
1915   case Intrinsic::aarch64_sve_fmaxnm:
1916     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmaxnm_u);
1917   case Intrinsic::aarch64_sve_fmin:
1918     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmin_u);
1919   case Intrinsic::aarch64_sve_fminnm:
1920     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fminnm_u);
1921   case Intrinsic::aarch64_sve_fmla:
1922     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmla_u);
1923   case Intrinsic::aarch64_sve_fmls:
1924     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmls_u);
1925   case Intrinsic::aarch64_sve_fmul:
1926     if (auto II_U =
1927             instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmul_u))
1928       return II_U;
1929     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
1930   case Intrinsic::aarch64_sve_fmul_u:
1931     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
1932   case Intrinsic::aarch64_sve_fmulx:
1933     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmulx_u);
1934   case Intrinsic::aarch64_sve_fnmla:
1935     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmla_u);
1936   case Intrinsic::aarch64_sve_fnmls:
1937     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmls_u);
1938   case Intrinsic::aarch64_sve_fsub:
1939     return instCombineSVEVectorFSub(IC, II);
1940   case Intrinsic::aarch64_sve_fsub_u:
1941     return instCombineSVEVectorFSubU(IC, II);
1942   case Intrinsic::aarch64_sve_add:
1943     return instCombineSVEVectorAdd(IC, II);
1944   case Intrinsic::aarch64_sve_add_u:
1945     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1946                                              Intrinsic::aarch64_sve_mla_u>(
1947         IC, II, true);
1948   case Intrinsic::aarch64_sve_mla:
1949     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mla_u);
1950   case Intrinsic::aarch64_sve_mls:
1951     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mls_u);
1952   case Intrinsic::aarch64_sve_mul:
1953     if (auto II_U =
1954             instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mul_u))
1955       return II_U;
1956     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
1957   case Intrinsic::aarch64_sve_mul_u:
1958     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
1959   case Intrinsic::aarch64_sve_sabd:
1960     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sabd_u);
1961   case Intrinsic::aarch64_sve_smax:
1962     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smax_u);
1963   case Intrinsic::aarch64_sve_smin:
1964     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smin_u);
1965   case Intrinsic::aarch64_sve_smulh:
1966     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smulh_u);
1967   case Intrinsic::aarch64_sve_sub:
1968     return instCombineSVEVectorSub(IC, II);
1969   case Intrinsic::aarch64_sve_sub_u:
1970     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1971                                              Intrinsic::aarch64_sve_mls_u>(
1972         IC, II, true);
1973   case Intrinsic::aarch64_sve_uabd:
1974     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uabd_u);
1975   case Intrinsic::aarch64_sve_umax:
1976     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umax_u);
1977   case Intrinsic::aarch64_sve_umin:
1978     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umin_u);
1979   case Intrinsic::aarch64_sve_umulh:
1980     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umulh_u);
1981   case Intrinsic::aarch64_sve_asr:
1982     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_asr_u);
1983   case Intrinsic::aarch64_sve_lsl:
1984     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsl_u);
1985   case Intrinsic::aarch64_sve_lsr:
1986     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsr_u);
1987   case Intrinsic::aarch64_sve_and:
1988     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_and_u);
1989   case Intrinsic::aarch64_sve_bic:
1990     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_bic_u);
1991   case Intrinsic::aarch64_sve_eor:
1992     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_eor_u);
1993   case Intrinsic::aarch64_sve_orr:
1994     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_orr_u);
1995   case Intrinsic::aarch64_sve_sqsub:
1996     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sqsub_u);
1997   case Intrinsic::aarch64_sve_uqsub:
1998     return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uqsub_u);
1999   case Intrinsic::aarch64_sve_tbl:
2000     return instCombineSVETBL(IC, II);
2001   case Intrinsic::aarch64_sve_uunpkhi:
2002   case Intrinsic::aarch64_sve_uunpklo:
2003   case Intrinsic::aarch64_sve_sunpkhi:
2004   case Intrinsic::aarch64_sve_sunpklo:
2005     return instCombineSVEUnpack(IC, II);
2006   case Intrinsic::aarch64_sve_zip1:
2007   case Intrinsic::aarch64_sve_zip2:
2008     return instCombineSVEZip(IC, II);
2009   case Intrinsic::aarch64_sve_ld1_gather_index:
2010     return instCombineLD1GatherIndex(IC, II);
2011   case Intrinsic::aarch64_sve_st1_scatter_index:
2012     return instCombineST1ScatterIndex(IC, II);
2013   case Intrinsic::aarch64_sve_ld1:
2014     return instCombineSVELD1(IC, II, DL);
2015   case Intrinsic::aarch64_sve_st1:
2016     return instCombineSVEST1(IC, II, DL);
2017   case Intrinsic::aarch64_sve_sdiv:
2018     return instCombineSVESDIV(IC, II);
2019   case Intrinsic::aarch64_sve_sel:
2020     return instCombineSVESel(IC, II);
2021   case Intrinsic::aarch64_sve_srshl:
2022     return instCombineSVESrshl(IC, II);
2023   case Intrinsic::aarch64_sve_dupq_lane:
2024     return instCombineSVEDupqLane(IC, II);
2025   }
2026 
2027   return std::nullopt;
2028 }
2029 
2030 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
2031     InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
2032     APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
2033     std::function<void(Instruction *, unsigned, APInt, APInt &)>
2034         SimplifyAndSetOp) const {
2035   switch (II.getIntrinsicID()) {
2036   default:
2037     break;
2038   case Intrinsic::aarch64_neon_fcvtxn:
2039   case Intrinsic::aarch64_neon_rshrn:
2040   case Intrinsic::aarch64_neon_sqrshrn:
2041   case Intrinsic::aarch64_neon_sqrshrun:
2042   case Intrinsic::aarch64_neon_sqshrn:
2043   case Intrinsic::aarch64_neon_sqshrun:
2044   case Intrinsic::aarch64_neon_sqxtn:
2045   case Intrinsic::aarch64_neon_sqxtun:
2046   case Intrinsic::aarch64_neon_uqrshrn:
2047   case Intrinsic::aarch64_neon_uqshrn:
2048   case Intrinsic::aarch64_neon_uqxtn:
2049     SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
2050     break;
2051   }
2052 
2053   return std::nullopt;
2054 }
2055 
2056 TypeSize
2057 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
2058   switch (K) {
2059   case TargetTransformInfo::RGK_Scalar:
2060     return TypeSize::getFixed(64);
2061   case TargetTransformInfo::RGK_FixedWidthVector:
2062     if (!ST->isNeonAvailable() && !EnableFixedwidthAutovecInStreamingMode)
2063       return TypeSize::getFixed(0);
2064 
2065     if (ST->hasSVE())
2066       return TypeSize::getFixed(
2067           std::max(ST->getMinSVEVectorSizeInBits(), 128u));
2068 
2069     return TypeSize::getFixed(ST->hasNEON() ? 128 : 0);
2070   case TargetTransformInfo::RGK_ScalableVector:
2071     if (!ST->isSVEAvailable() && !EnableScalableAutovecInStreamingMode)
2072       return TypeSize::getScalable(0);
2073 
2074     return TypeSize::getScalable(ST->hasSVE() ? 128 : 0);
2075   }
2076   llvm_unreachable("Unsupported register kind");
2077 }
2078 
2079 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
2080                                            ArrayRef<const Value *> Args,
2081                                            Type *SrcOverrideTy) {
2082   // A helper that returns a vector type from the given type. The number of
2083   // elements in type Ty determines the vector width.
2084   auto toVectorTy = [&](Type *ArgTy) {
2085     return VectorType::get(ArgTy->getScalarType(),
2086                            cast<VectorType>(DstTy)->getElementCount());
2087   };
2088 
2089   // Exit early if DstTy is not a vector type whose elements are one of [i16,
2090   // i32, i64]. SVE doesn't generally have the same set of instructions to
2091   // perform an extend with the add/sub/mul. There are SMULLB style
2092   // instructions, but they operate on top/bottom, requiring some sort of lane
2093   // interleaving to be used with zext/sext.
2094   unsigned DstEltSize = DstTy->getScalarSizeInBits();
2095   if (!useNeonVector(DstTy) || Args.size() != 2 ||
2096       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
2097     return false;
2098 
2099   // Determine if the operation has a widening variant. We consider both the
2100   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2101   // instructions.
2102   //
2103   // TODO: Add additional widening operations (e.g., shl, etc.) once we
2104   //       verify that their extending operands are eliminated during code
2105   //       generation.
2106   Type *SrcTy = SrcOverrideTy;
2107   switch (Opcode) {
2108   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2109   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2110     // The second operand needs to be an extend
2111     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
2112       if (!SrcTy)
2113         SrcTy =
2114             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2115     } else
2116       return false;
2117     break;
2118   case Instruction::Mul: { // SMULL(2), UMULL(2)
2119     // Both operands need to be extends of the same type.
2120     if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2121         (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2122       if (!SrcTy)
2123         SrcTy =
2124             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2125     } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2126       // If one of the operands is a Zext and the other has enough zero bits to
2127       // be treated as unsigned, we can still generate a umull, meaning the zext
2128       // is free.
2129       KnownBits Known =
2130           computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2131       if (Args[0]->getType()->getScalarSizeInBits() -
2132               Known.Zero.countLeadingOnes() >
2133           DstTy->getScalarSizeInBits() / 2)
2134         return false;
2135       if (!SrcTy)
2136         SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2137                                            DstTy->getScalarSizeInBits() / 2));
2138     } else
2139       return false;
2140     break;
2141   }
2142   default:
2143     return false;
2144   }
2145 
2146   // Legalize the destination type and ensure it can be used in a widening
2147   // operation.
2148   auto DstTyL = getTypeLegalizationCost(DstTy);
2149   if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
2150     return false;
2151 
2152   // Legalize the source type and ensure it can be used in a widening
2153   // operation.
2154   assert(SrcTy && "Expected some SrcTy");
2155   auto SrcTyL = getTypeLegalizationCost(SrcTy);
2156   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
2157   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
2158     return false;
2159 
2160   // Get the total number of vector elements in the legalized types.
2161   InstructionCost NumDstEls =
2162       DstTyL.first * DstTyL.second.getVectorMinNumElements();
2163   InstructionCost NumSrcEls =
2164       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
2165 
2166   // Return true if the legalized types have the same number of vector elements
2167   // and the destination element type size is twice that of the source type.
2168   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
2169 }
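     // Illustrative examples for isWideningInstruction (not exhaustive): an
     // add of <8 x i16> whose second operand is (zext <8 x i8> ... to <8 x i16>)
     // maps onto uaddw, and a mul of two such zexts maps onto umull, so the
     // extends can be treated as free by getCastInstrCost below.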
2170 
2171 // s/urhadd instructions implement the following pattern, making the
2172 // extends free:
2173 //   %x = add ((zext i8 -> i16), 1)
2174 //   %y = (zext i8 -> i16)
2175 //   trunc i16 (lshr (add %x, %y), 1) -> i8
2176 //
2177 bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
2178                                         Type *Src) {
2179   // The source should be a legal vector type.
2180   if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) ||
2181       (Src->isScalableTy() && !ST->hasSVE2()))
2182     return false;
2183 
2184   if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse())
2185     return false;
2186 
2187   // Look for trunc/shl/add before trying to match the pattern.
2188   // Look for trunc/lshr/add before trying to match the pattern.
2189   auto *AddUser =
2190       dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2191   if (AddUser && AddUser->getOpcode() == Instruction::Add)
2192     Add = AddUser;
2193 
2194   auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2195   if (!Shr || Shr->getOpcode() != Instruction::LShr)
2196     return false;
2197 
2198   auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser());
2199   if (!Trunc || Trunc->getOpcode() != Instruction::Trunc ||
2200       Src->getScalarSizeInBits() !=
2201           cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits())
2202     return false;
2203 
2204   // Try to match the whole pattern. Ext could be either the first or second
2205   // m_ZExtOrSExt matched.
2206   Instruction *Ex1, *Ex2;
2207   if (!(match(Add, m_c_Add(m_Instruction(Ex1),
2208                            m_c_Add(m_Instruction(Ex2), m_SpecificInt(1))))))
2209     return false;
2210 
2211   // Ensure both extends are of the same type
2212   if (match(Ex1, m_ZExtOrSExt(m_Value())) &&
2213       Ex1->getOpcode() == Ex2->getOpcode())
2214     return true;
2215 
2216   return false;
2217 }
2218 
2219 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2220                                                  Type *Src,
2221                                                  TTI::CastContextHint CCH,
2222                                                  TTI::TargetCostKind CostKind,
2223                                                  const Instruction *I) {
2224   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2225   assert(ISD && "Invalid opcode");
2226   // If the cast is observable, and it is used by a widening instruction (e.g.,
2227   // uaddl, saddw, etc.), it may be free.
2228   if (I && I->hasOneUser()) {
2229     auto *SingleUser = cast<Instruction>(*I->user_begin());
2230     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
2231     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
2232       // For adds, only count the second operand as free if both operands are
2233       // extends but not the same operation (i.e. both operands are not free in
2234       // add(sext, zext)).
2235       if (SingleUser->getOpcode() == Instruction::Add) {
2236         if (I == SingleUser->getOperand(1) ||
2237             (isa<CastInst>(SingleUser->getOperand(1)) &&
2238              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
2239           return 0;
2240       } else // Others are free so long as isWideningInstruction returned true.
2241         return 0;
2242     }
2243 
2244     // The cast will be free for the s/urhadd instructions
2245     if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
2246         isExtPartOfAvgExpr(SingleUser, Dst, Src))
2247       return 0;
2248   }
2249 
2250   // TODO: Allow non-throughput costs that aren't binary.
2251   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
2252     if (CostKind != TTI::TCK_RecipThroughput)
2253       return Cost == 0 ? 0 : 1;
2254     return Cost;
2255   };
2256 
2257   EVT SrcTy = TLI->getValueType(DL, Src);
2258   EVT DstTy = TLI->getValueType(DL, Dst);
2259 
2260   if (!SrcTy.isSimple() || !DstTy.isSimple())
2261     return AdjustCost(
2262         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2263 
2264   static const TypeConversionCostTblEntry
2265   ConversionTbl[] = {
2266     { ISD::TRUNCATE, MVT::v2i8,   MVT::v2i64,  1},  // xtn
2267     { ISD::TRUNCATE, MVT::v2i16,  MVT::v2i64,  1},  // xtn
2268     { ISD::TRUNCATE, MVT::v2i32,  MVT::v2i64,  1},  // xtn
2269     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i32,  1},  // xtn
2270     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i64,  3},  // 2 xtn + 1 uzp1
2271     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i32,  1},  // xtn
2272     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i64,  2},  // 1 uzp1 + 1 xtn
2273     { ISD::TRUNCATE, MVT::v4i32,  MVT::v4i64,  1},  // 1 uzp1
2274     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i16,  1},  // 1 xtn
2275     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i32,  2},  // 1 uzp1 + 1 xtn
2276     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i64,  4},  // 3 x uzp1 + xtn
2277     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i32,  1},  // 1 uzp1
2278     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i64,  3},  // 3 x uzp1
2279     { ISD::TRUNCATE, MVT::v8i32,  MVT::v8i64,  2},  // 2 x uzp1
2280     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i16, 1},  // uzp1
2281     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i32, 3},  // (2 + 1) x uzp1
2282     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i64, 7},  // (4 + 2 + 1) x uzp1
2283     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2},  // 2 x uzp1
2284     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6},  // (4 + 2) x uzp1
2285     { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4},  // 4 x uzp1
2286 
2287     // Truncations on nxvmiN
2288     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
2289     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
2290     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
2291     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
2292     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
2293     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
2294     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
2295     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
2296     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
2297     { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
2298     { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
2299     { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
2300     { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
2301     { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
2302     { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
2303     { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
2304 
2305     // The number of shll instructions for the extension.
2306     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2307     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2308     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2309     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2310     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2311     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2312     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2313     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2314     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2315     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2316     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2317     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2318     { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2319     { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2320     { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2321     { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2322 
2323     // LowerVectorINT_TO_FP:
2324     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2325     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2326     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2327     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2328     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2329     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2330 
2331     // Complex: to v2f32
2332     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2333     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2334     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2335     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2336     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2337     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2338 
2339     // Complex: to v4f32
2340     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
2341     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2342     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
2343     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2344 
2345     // Complex: to v8f32
2346     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2347     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2348     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2349     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2350 
2351     // Complex: to v16f32
2352     { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2353     { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2354 
2355     // Complex: to v2f64
2356     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2357     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2358     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2359     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2360     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2361     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2362 
2363     // Complex: to v4f64
2364     { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2365     { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2366 
2367     // LowerVectorFP_TO_INT
2368     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
2369     { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
2370     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
2371     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
2372     { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
2373     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
2374 
2375     // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
2376     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
2377     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
2378     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
2379     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
2380     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
2381     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },
2382 
2383     // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
2384     { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
2385     { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
2386     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
2387     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
2388 
2389     // Complex, from nxv2f32.
2390     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2391     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2392     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2393     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2394     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2395     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2396     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2397     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2398 
2399     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
2400     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
2401     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
2402     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
2403     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
2404     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
2405     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
2406 
2407     // Complex, from nxv2f64.
2408     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2409     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2410     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2411     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2412     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2413     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2414     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2415     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2416 
2417     // Complex, from nxv4f32.
2418     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2419     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2420     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2421     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2422     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2423     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2424     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2425     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2426 
2427     // Complex, from nxv8f64. Illegal -> illegal conversions not required.
2428     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2429     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2430     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2431     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2432 
2433     // Complex, from nxv4f64. Illegal -> illegal conversions not required.
2434     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2435     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2436     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2437     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2438     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2439     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2440 
2441     // Complex, from nxv8f32. Illegal -> illegal conversions not required.
2442     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2443     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2444     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2445     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2446 
2447     // Complex, from nxv8f16.
2448     { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2449     { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2450     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2451     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2452     { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2453     { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2454     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2455     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2456 
2457     // Complex, from nxv4f16.
2458     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2459     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2460     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2461     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2462     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2463     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2464     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2465     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2466 
2467     // Complex, from nxv2f16.
2468     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2469     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2470     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2471     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2472     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2473     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2474     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2475     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2476 
2477     // Truncate from nxvmf32 to nxvmf16.
2478     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
2479     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
2480     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
2481 
2482     // Truncate from nxvmf64 to nxvmf16.
2483     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
2484     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
2485     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
2486 
2487     // Truncate from nxvmf64 to nxvmf32.
2488     { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
2489     { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
2490     { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
2491 
2492     // Extend from nxvmf16 to nxvmf32.
2493     { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
2494     { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
2495     { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
2496 
2497     // Extend from nxvmf16 to nxvmf64.
2498     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
2499     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
2500     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
2501 
2502     // Extend from nxvmf32 to nxvmf64.
2503     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
2504     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
2505     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
2506 
2507     // Bitcasts from float to integer
2508     { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 },
2509     { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 },
2510     { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 },
2511 
2512     // Bitcasts from integer to float
2513     { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
2514     { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
2515     { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
2516 
2517     // Add cost for extending to illegal -too wide- scalable vectors.
2518     // zero/sign extend are implemented by multiple unpack operations,
2519     // where each operation has a cost of 1.
2520     { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2521     { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2522     { ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2523     { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2524     { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2525     { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2526 
2527     { ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2528     { ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2529     { ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2530     { ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2531     { ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2532     { ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2533   };
2534 
2535   // We have to estimate the cost of a fixed-length operation performed on
2536   // SVE registers, scaled by the number of SVE registers required to
2537   // represent the fixed-length type.
2538   EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
2539   if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
2540       SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
2541       ST->useSVEForFixedLengthVectors(WiderTy)) {
2542     std::pair<InstructionCost, MVT> LT =
2543         getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
2544     unsigned NumElements = AArch64::SVEBitsPerBlock /
2545                            LT.second.getVectorElementType().getSizeInBits();
2546     return AdjustCost(
2547         LT.first *
2548         getCastInstrCost(
2549             Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
2550             ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
2551             CostKind, I));
2552   }
2553 
2554   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
2555                                                  DstTy.getSimpleVT(),
2556                                                  SrcTy.getSimpleVT()))
2557     return AdjustCost(Entry->Cost);
2558 
2559   static const TypeConversionCostTblEntry FP16Tbl[] = {
2560       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
2561       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
2562       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
2563       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
2564       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
2565       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
2566       {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
2567       {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
2568       {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
2569       {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
2570       {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
2571       {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
2572       {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
2573       {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
2574       {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
2575       {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
2576       {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
2577       {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
2578       {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // ushll + ucvtf
2579       {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // sshll + scvtf
2580       {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
2581       {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
2582   };
2583 
2584   if (ST->hasFullFP16())
2585     if (const auto *Entry = ConvertCostTableLookup(
2586             FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
2587       return AdjustCost(Entry->Cost);
2588 
2589   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
2590       CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() &&
2591       TLI->getTypeAction(Src->getContext(), SrcTy) ==
2592           TargetLowering::TypePromoteInteger &&
2593       TLI->getTypeAction(Dst->getContext(), DstTy) ==
2594           TargetLowering::TypeSplitVector) {
2595     // The standard behaviour in the backend for these cases is to split the
2596     // extend up into two parts:
2597     //  1. Perform an extending load or masked load up to the legal type.
2598     //  2. Extend the loaded data to the final type.
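    // For example, a masked zero-extend from nxv8i8 to nxv8i64 would be costed
    // as a masked extend to the promoted legal type (nxv8i16) plus a normal
    // nxv8i16 -> nxv8i64 extend.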
2599     std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
2600     Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext());
2601     InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost(
2602         Opcode, LegalTy, Src, CCH, CostKind, I);
2603     InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost(
2604         Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I);
2605     return Part1 + Part2;
2606   }
2607 
2608   // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
2609   // but we also want to include the TTI::CastContextHint::Masked case.
2610   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
2611       CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() &&
2612       TLI->isTypeLegal(DstTy))
2613     CCH = TTI::CastContextHint::Normal;
2614 
2615   return AdjustCost(
2616       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2617 }
2618 
2619 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
2620                                                          Type *Dst,
2621                                                          VectorType *VecTy,
2622                                                          unsigned Index) {
2623 
2624   // Make sure we were given a valid extend opcode.
2625   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
2626          "Invalid opcode");
2627 
2628   // We are extending an element we extract from a vector, so the source type
2629   // of the extend is the element type of the vector.
2630   auto *Src = VecTy->getElementType();
2631 
2632   // Sign- and zero-extends are for integer types only.
2633   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
2634 
2635   // Get the cost for the extract. We compute the cost (if any) for the extend
2636   // below.
2637   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2638   InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
2639                                             CostKind, Index, nullptr, nullptr);
2640 
2641   // Legalize the types.
2642   auto VecLT = getTypeLegalizationCost(VecTy);
2643   auto DstVT = TLI->getValueType(DL, Dst);
2644   auto SrcVT = TLI->getValueType(DL, Src);
2645 
2646   // If the resulting type is still a vector and the destination type is legal,
2647   // we may get the extension for free. If not, get the default cost for the
2648   // extend.
2649   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
2650     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2651                                    CostKind);
2652 
2653   // The destination type should be larger than the element type. If not, get
2654   // the default cost for the extend.
2655   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
2656     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2657                                    CostKind);
2658 
2659   switch (Opcode) {
2660   default:
2661     llvm_unreachable("Opcode should be either SExt or ZExt");
2662 
2663   // For sign-extends, we only need a smov, which performs the extension
2664   // automatically.
2665   case Instruction::SExt:
2666     return Cost;
2667 
2668   // For zero-extends, the extend is performed automatically by a umov unless
2669   // the destination type is i64 and the element type is i8 or i16.
2670   case Instruction::ZExt:
2671     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
2672       return Cost;
2673   }
2674 
2675   // If we are unable to perform the extend for free, get the default cost.
2676   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2677                                  CostKind);
2678 }
2679 
2680 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
2681                                                TTI::TargetCostKind CostKind,
2682                                                const Instruction *I) {
2683   if (CostKind != TTI::TCK_RecipThroughput)
2684     return Opcode == Instruction::PHI ? 0 : 1;
2685   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
2686   // Branches are assumed to be predicted.
2687   return 0;
2688 }
2689 
2690 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
2691                                                          Type *Val,
2692                                                          unsigned Index,
2693                                                          bool HasRealUse) {
2694   assert(Val->isVectorTy() && "This must be a vector type");
2695 
2696   if (Index != -1U) {
2697     // Legalize the type.
2698     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
2699 
2700     // This type is legalized to a scalar type.
2701     if (!LT.second.isVector())
2702       return 0;
2703 
2704     // The type may be split. For fixed-width vectors we can normalize the
2705     // index to the new type.
2706     if (LT.second.isFixedLengthVector()) {
2707       unsigned Width = LT.second.getVectorNumElements();
2708       Index = Index % Width;
2709     }
2710 
2711     // The element at index zero is already inside the vector.
2712     // - For a physical (HasRealUse==true) insert-element or extract-element
2713     // instruction that extracts integers, an explicit FPR -> GPR move is
2714     // needed. So it has non-zero cost.
2715     // - For the remaining cases (a virtual instruction or a floating-point
2716     // element type), consider the instruction free.
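    // For example, extracting lane 0 of a <2 x double> is free, whereas
    // extracting lane 0 of a <4 x i32> with a real use still needs an
    // FPR -> GPR move and is charged the base cost below.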
2717     if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy()))
2718       return 0;
2719 
2720     // This recognises an LD1 (single-element structure to one lane of one
2721     // register) instruction. I.e., if this is an `insertelement` instruction
2722     // and its second operand is a load, then we will generate an LD1, which
2723     // is an expensive instruction.
2724     if (I && dyn_cast<LoadInst>(I->getOperand(1)))
2725       return ST->getVectorInsertExtractBaseCost() + 1;
2726 
2727     // i1 inserts and extracts will include an extra cset or cmp of the vector
2728     // value. Increase the cost by 1 to account for this.
2729     if (Val->getScalarSizeInBits() == 1)
2730       return ST->getVectorInsertExtractBaseCost() + 1;
2731 
2732     // FIXME:
2733     // If the extract-element and insert-element instructions could be
2734     // simplified away (e.g., could be combined into users by looking at use-def
2735     // context), they have no cost. This is not done in the first place for
2736     // compile-time considerations.
2737   }
2738 
2739   // All other insert/extracts cost this much.
2740   return ST->getVectorInsertExtractBaseCost();
2741 }
2742 
2743 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
2744                                                    TTI::TargetCostKind CostKind,
2745                                                    unsigned Index, Value *Op0,
2746                                                    Value *Op1) {
2747   bool HasRealUse =
2748       Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
2749   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
2750 }
2751 
2752 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
2753                                                    Type *Val,
2754                                                    TTI::TargetCostKind CostKind,
2755                                                    unsigned Index) {
2756   return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
2757 }
2758 
2759 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
2760     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
2761     TTI::TargetCostKind CostKind) {
2762   if (isa<ScalableVectorType>(Ty))
2763     return InstructionCost::getInvalid();
2764   if (Ty->getElementType()->isFloatingPointTy())
2765     return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
2766                                            CostKind);
2767   return DemandedElts.popcount() * (Insert + Extract) *
2768          ST->getVectorInsertExtractBaseCost();
2769 }
2770 
2771 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
2772     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
2773     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
2774     ArrayRef<const Value *> Args,
2775     const Instruction *CxtI) {
2776 
2777   // TODO: Handle more cost kinds.
2778   if (CostKind != TTI::TCK_RecipThroughput)
2779     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2780                                          Op2Info, Args, CxtI);
2781 
2782   // Legalize the type.
2783   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
2784   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2785 
2786   switch (ISD) {
2787   default:
2788     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2789                                          Op2Info);
2790   case ISD::SDIV:
2791     if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
2792       // On AArch64, scalar signed division by a power-of-two constant is
2793       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
2794       // The OperandValue properties may not be the same as those of the
2795       // previous operation; conservatively assume OP_None.
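      // For illustration, an i32 sdiv by 4 would expand to roughly:
      //   add  w8, w0, #3
      //   cmp  w0, #0
      //   csel w8, w8, w0, lt
      //   asr  w0, w8, #2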
2796       InstructionCost Cost = getArithmeticInstrCost(
2797           Instruction::Add, Ty, CostKind,
2798           Op1Info.getNoProps(), Op2Info.getNoProps());
2799       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
2800                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2801       Cost += getArithmeticInstrCost(
2802           Instruction::Select, Ty, CostKind,
2803           Op1Info.getNoProps(), Op2Info.getNoProps());
2804       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
2805                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2806       return Cost;
2807     }
2808     [[fallthrough]];
2809   case ISD::UDIV: {
2810     if (Op2Info.isConstant() && Op2Info.isUniform()) {
2811       auto VT = TLI->getValueType(DL, Ty);
2812       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
2813         // Vector signed division by a constant is expanded to the
2814         // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division
2815         // to MULHS + SUB + SRL + ADD + SRL.
2816         InstructionCost MulCost = getArithmeticInstrCost(
2817             Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2818         InstructionCost AddCost = getArithmeticInstrCost(
2819             Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2820         InstructionCost ShrCost = getArithmeticInstrCost(
2821             Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2822         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
2823       }
2824     }
2825 
2826     InstructionCost Cost = BaseT::getArithmeticInstrCost(
2827         Opcode, Ty, CostKind, Op1Info, Op2Info);
2828     if (Ty->isVectorTy()) {
2829       if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
2830         // If SDIV/UDIV operations are lowered using SVE, then the cost is
2831         // lower.
2832         if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty)
2833                                                 ->getPrimitiveSizeInBits()
2834                                                 .getFixedValue() < 128) {
2835           EVT VT = TLI->getValueType(DL, Ty);
2836           static const CostTblEntry DivTbl[]{
2837               {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
2838               {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
2839               {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
2840               {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
2841               {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
2842               {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
2843 
2844           const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
2845           if (nullptr != Entry)
2846             return Entry->Cost;
2847         }
2848         // For 8/16-bit elements, the cost is higher because the type
2849         // requires promotion and possibly splitting:
2850         if (LT.second.getScalarType() == MVT::i8)
2851           Cost *= 8;
2852         else if (LT.second.getScalarType() == MVT::i16)
2853           Cost *= 4;
2854         return Cost;
2855       } else {
2856         // If one of the operands is a uniform constant then the cost for each
2857         // element is the cost of insertion, extraction and division.
2858         // Insertion cost = 2, extraction cost = 2, division = cost of the
2859         // operation on the scalar type.
2860         if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
2861             (Op2Info.isConstant() && Op2Info.isUniform())) {
2862           if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
2863             InstructionCost DivCost = BaseT::getArithmeticInstrCost(
2864                 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
2865             return (4 + DivCost) * VTy->getNumElements();
2866           }
2867         }
2868         // On AArch64, without SVE, vector divisions are expanded
2869         // into scalar divisions of each pair of elements.
2870         Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
2871                                        CostKind, Op1Info, Op2Info);
2872         Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
2873                                        Op1Info, Op2Info);
2874       }
2875 
2876       // TODO: if one of the arguments is scalar, then it's not necessary to
2877       // double the cost of handling the vector elements.
2878       Cost += Cost;
2879     }
2880     return Cost;
2881   }
2882   case ISD::MUL:
2883     // When SVE is available, we can lower the v2i64 operation using
2884     // the SVE mul instruction, which has a lower cost.
2885     if (LT.second == MVT::v2i64 && ST->hasSVE())
2886       return LT.first;
2887 
2888     // When SVE is not available, there is no MUL.2d instruction,
2889     // which means mul <2 x i64> is expensive as elements are extracted
2890     // from the vectors and the muls scalarized.
2891     // As getScalarizationOverhead is a bit too pessimistic, we
2892     // estimate the cost for an i64 vector directly here, which is:
2893     // - four 2-cost i64 extracts,
2894     // - two 2-cost i64 inserts, and
2895     // - two 1-cost muls.
2896     // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
2897     // LT.first = 2 the cost is 28. If both operands are extensions it will not
2898     // need to scalarize so the cost can be cheaper (smull or umull).
2900     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2901       return LT.first;
2902     return LT.first * 14;
2903   case ISD::ADD:
2904   case ISD::XOR:
2905   case ISD::OR:
2906   case ISD::AND:
2907   case ISD::SRL:
2908   case ISD::SRA:
2909   case ISD::SHL:
2910     // These nodes are marked as 'custom' for combining purposes only.
2911     // We know that they are legal. See LowerAdd in ISelLowering.
2912     return LT.first;
2913 
2914   case ISD::FNEG:
2915   case ISD::FADD:
2916   case ISD::FSUB:
2917     // Increase the cost for half and bfloat types if not architecturally
2918     // supported.
2919     if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
2920         (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
2921       return 2 * LT.first;
2922     if (!Ty->getScalarType()->isFP128Ty())
2923       return LT.first;
2924     [[fallthrough]];
2925   case ISD::FMUL:
2926   case ISD::FDIV:
2927     // These nodes are marked as 'custom' just to lower them to SVE.
2928     // We know said lowering will incur no additional cost.
2929     if (!Ty->getScalarType()->isFP128Ty())
2930       return 2 * LT.first;
2931 
2932     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2933                                          Op2Info);
2934   }
2935 }
2936 
2937 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
2938                                                           ScalarEvolution *SE,
2939                                                           const SCEV *Ptr) {
2940   // Address computations in vectorized code with non-consecutive addresses will
2941   // likely result in more instructions compared to scalar code where the
2942   // computation can more often be merged into the index mode. The resulting
2943   // extra micro-ops can significantly decrease throughput.
2944   unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
2945   int MaxMergeDistance = 64;
2946 
2947   if (Ty->isVectorTy() && SE &&
2948       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
2949     return NumVectorInstToHideOverhead;
2950 
2951   // In many cases the address computation is not merged into the instruction
2952   // addressing mode.
2953   return 1;
2954 }
2955 
2956 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2957                                                    Type *CondTy,
2958                                                    CmpInst::Predicate VecPred,
2959                                                    TTI::TargetCostKind CostKind,
2960                                                    const Instruction *I) {
2961   // TODO: Handle other cost kinds.
2962   if (CostKind != TTI::TCK_RecipThroughput)
2963     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2964                                      I);
2965 
2966   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2967   // Some vector selects that are wider than the register width are not
2968   // lowered well.
2969   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
2970     // We would need this many instructions to hide the scalarization happening.
2971     const int AmortizationCost = 20;
2972 
2973     // If VecPred is not set, check if we can get a predicate from the context
2974     // instruction, if its type matches the requested ValTy.
2975     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
2976       CmpInst::Predicate CurrentPred;
2977       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
2978                             m_Value())))
2979         VecPred = CurrentPred;
2980     }
2981     // Check if we have a compare/select chain that can be lowered using
2982     // a (F)CMxx & BFI pair.
2983     if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2984         VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2985         VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2986         VecPred == CmpInst::FCMP_UNE) {
2987       static const auto ValidMinMaxTys = {
2988           MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2989           MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2990       static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2991 
2992       auto LT = getTypeLegalizationCost(ValTy);
2993       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2994           (ST->hasFullFP16() &&
2995            any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
2996         return LT.first;
2997     }
2998 
2999     static const TypeConversionCostTblEntry
3000     VectorSelectTbl[] = {
3001       { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
3002       { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
3003       { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
3004       { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
3005       { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
3006       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
3007       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
3008       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
3009       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
3010       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
3011       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
3012     };
3013 
3014     EVT SelCondTy = TLI->getValueType(DL, CondTy);
3015     EVT SelValTy = TLI->getValueType(DL, ValTy);
3016     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
3017       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
3018                                                      SelCondTy.getSimpleVT(),
3019                                                      SelValTy.getSimpleVT()))
3020         return Entry->Cost;
3021     }
3022   }
3023 
3024   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
3025     auto LT = getTypeLegalizationCost(ValTy);
3026     // Cost a v4f16 FCmp without FP16 support by converting to v4f32 and back.
3027     if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
3028       return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
3029   }
3030 
3031   // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
3032   // FIXME: This can apply to more conditions and add/sub if it can be shown to
3033   // be profitable.
3034   if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
3035       ICmpInst::isEquality(VecPred) &&
3036       TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
3037       match(I->getOperand(1), m_Zero()) &&
3038       match(I->getOperand(0), m_And(m_Value(), m_Value())))
3039     return 0;
3040 
3041   // The base case handles scalable vectors fine for now, since it treats the
3042   // cost as 1 * legalization cost.
3043   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
3044 }
3045 
3046 AArch64TTIImpl::TTI::MemCmpExpansionOptions
3047 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
3048   TTI::MemCmpExpansionOptions Options;
3049   if (ST->requiresStrictAlign()) {
3050     // TODO: Add cost modeling for strict align. Misaligned loads expand to
3051     // a bunch of instructions when strict align is enabled.
3052     return Options;
3053   }
3054   Options.AllowOverlappingLoads = true;
3055   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
3056   Options.NumLoadsPerBlock = Options.MaxNumLoads;
3057   // TODO: Though vector loads usually perform well on AArch64, on some targets
3058   // they may wake up the FP unit, which raises the power consumption. Perhaps
3059   // they could be used with no holds barred (-O3).
3060   Options.LoadSizes = {8, 4, 2, 1};
3061   Options.AllowedTailExpansions = {3, 5, 6};
3062   return Options;
3063 }
3064 
3065 bool AArch64TTIImpl::prefersVectorizedAddressing() const {
3066   return ST->hasSVE();
3067 }
3068 
3069 InstructionCost
3070 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
3071                                       Align Alignment, unsigned AddressSpace,
3072                                       TTI::TargetCostKind CostKind) {
3073   if (useNeonVector(Src))
3074     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
3075                                         CostKind);
3076   auto LT = getTypeLegalizationCost(Src);
3077   if (!LT.first.isValid())
3078     return InstructionCost::getInvalid();
3079 
3080   // The code-generator is currently not able to handle scalable vectors
3081   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3082   // it. This change will be removed when code-generation for these types is
3083   // sufficiently reliable.
3084   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
3085     return InstructionCost::getInvalid();
3086 
3087   return LT.first;
3088 }
3089 
3090 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
3091   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
3092 }
3093 
3094 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
3095     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
3096     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
3097   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
3098     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
3099                                          Alignment, CostKind, I);
3100   auto *VT = cast<VectorType>(DataTy);
3101   auto LT = getTypeLegalizationCost(DataTy);
3102   if (!LT.first.isValid())
3103     return InstructionCost::getInvalid();
3104 
3105   if (!LT.second.isVector() ||
3106       !isElementTypeLegalForScalableVector(VT->getElementType()))
3107     return InstructionCost::getInvalid();
3108 
3109   // The code-generator is currently not able to handle scalable vectors
3110   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3111   // it. This change will be removed when code-generation for these types is
3112   // sufficiently reliable.
3113   if (cast<VectorType>(DataTy)->getElementCount() ==
3114       ElementCount::getScalable(1))
3115     return InstructionCost::getInvalid();
3116 
3117   ElementCount LegalVF = LT.second.getVectorElementCount();
3118   InstructionCost MemOpCost =
3119       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
3120                       {TTI::OK_AnyValue, TTI::OP_None}, I);
3121   // Add on an overhead cost for using gathers/scatters.
3122   // TODO: At the moment this is applied unilaterally for all CPUs, but at some
3123   // point we may want a per-CPU overhead.
3124   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
3125   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
3126 }
3127 
3128 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
3129   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
3130 }
3131 
3132 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
3133                                                 MaybeAlign Alignment,
3134                                                 unsigned AddressSpace,
3135                                                 TTI::TargetCostKind CostKind,
3136                                                 TTI::OperandValueInfo OpInfo,
3137                                                 const Instruction *I) {
3138   EVT VT = TLI->getValueType(DL, Ty, true);
3139   // Type legalization can't handle structs
3140   if (VT == MVT::Other)
3141     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
3142                                   CostKind);
3143 
3144   auto LT = getTypeLegalizationCost(Ty);
3145   if (!LT.first.isValid())
3146     return InstructionCost::getInvalid();
3147 
3148   // The code-generator is currently not able to handle scalable vectors
3149   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3150   // it. This change will be removed when code-generation for these types is
3151   // sufficiently reliable.
3152   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3153     if (VTy->getElementCount() == ElementCount::getScalable(1))
3154       return InstructionCost::getInvalid();
3155 
3156   // TODO: consider latency as well for TCK_SizeAndLatency.
3157   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
3158     return LT.first;
3159 
3160   if (CostKind != TTI::TCK_RecipThroughput)
3161     return 1;
3162 
3163   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
3164       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
3165     // Unaligned stores are extremely inefficient. We don't split all
3166     // unaligned 128-bit stores because of the negative impact that has been
3167     // shown in practice on inlined block copy code.
3168     // We make such stores expensive so that we will only vectorize if there
3169     // are 6 other instructions getting vectorized.
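    // For example, a single misaligned 128-bit store (LT.first == 1) is costed
    // as 1 * 2 * 6 = 12.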
3170     const int AmortizationCost = 6;
3171 
3172     return LT.first * 2 * AmortizationCost;
3173   }
3174 
3175   // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
3176   if (Ty->isPtrOrPtrVectorTy())
3177     return LT.first;
3178 
3179   // Check truncating stores and extending loads.
3180   if (useNeonVector(Ty) &&
3181       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
3182     // v4i8 types are lowered to a scalar load/store and sshll/xtn.
3183     if (VT == MVT::v4i8)
3184       return 2;
3185     // Otherwise we need to scalarize.
3186     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
3187   }
3188 
3189   return LT.first;
3190 }
3191 
3192 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
3193     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
3194     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
3195     bool UseMaskForCond, bool UseMaskForGaps) {
3196   assert(Factor >= 2 && "Invalid interleave factor");
3197   auto *VecVTy = cast<VectorType>(VecTy);
3198 
3199   if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2))
3200     return InstructionCost::getInvalid();
3201 
3202   // Vectorization for masked interleaved accesses is only enabled for scalable
3203   // VF.
3204   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
3205     return InstructionCost::getInvalid();
3206 
3207   if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
3208     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
3209     auto *SubVecTy =
3210         VectorType::get(VecVTy->getElementType(),
3211                         VecVTy->getElementCount().divideCoefficientBy(Factor));
3212 
3213     // ldN/stN only support legal vector types that are 64 or 128 bits in size.
3214     // Accesses having vector types that are a multiple of 128 bits can be
3215     // matched to more than one ldN/stN instruction.
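    // For example, an interleaved group of <16 x i32> with Factor == 2 uses
    // <8 x i32> sub-vectors (256 bits), which would map to two ld2/st2
    // operations, for a cost of roughly 2 * 2 = 4.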
3216     bool UseScalable;
3217     if (MinElts % Factor == 0 &&
3218         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
3219       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
3220   }
3221 
3222   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
3223                                            Alignment, AddressSpace, CostKind,
3224                                            UseMaskForCond, UseMaskForGaps);
3225 }
3226 
3227 InstructionCost
3228 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
3229   InstructionCost Cost = 0;
3230   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3231   for (auto *I : Tys) {
3232     if (!I->isVectorTy())
3233       continue;
3234     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
3235         128)
3236       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
3237               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
3238   }
3239   return Cost;
3240 }
3241 
3242 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
3243   return ST->getMaxInterleaveFactor();
3244 }
3245 
3246 // For Falkor, we want to avoid having too many strided loads in a loop since
3247 // that can exhaust the HW prefetcher resources.  We adjust the unroller
3248 // MaxCount preference below to attempt to ensure unrolling doesn't create too
3249 // many strided loads.
3250 static void
3251 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3252                               TargetTransformInfo::UnrollingPreferences &UP) {
3253   enum { MaxStridedLoads = 7 };
3254   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
3255     int StridedLoads = 0;
3256     // FIXME? We could make this more precise by looking at the CFG and
3257     // e.g. not counting loads in each side of an if-then-else diamond.
3258     for (const auto BB : L->blocks()) {
3259       for (auto &I : *BB) {
3260         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
3261         if (!LMemI)
3262           continue;
3263 
3264         Value *PtrValue = LMemI->getPointerOperand();
3265         if (L->isLoopInvariant(PtrValue))
3266           continue;
3267 
3268         const SCEV *LSCEV = SE.getSCEV(PtrValue);
3269         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
3270         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
3271           continue;
3272 
3273         // FIXME? We could take pairing of unrolled load copies into account
3274         // by looking at the AddRec, but we would probably have to limit this
3275         // to loops with no stores or other memory optimization barriers.
3276         ++StridedLoads;
3277         // We've seen enough strided loads that seeing more won't make a
3278         // difference.
3279         if (StridedLoads > MaxStridedLoads / 2)
3280           return StridedLoads;
3281       }
3282     }
3283     return StridedLoads;
3284   };
3285 
3286   int StridedLoads = countStridedLoads(L, SE);
3287   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
3288                     << " strided loads\n");
3289   // Pick the largest power of 2 unroll count that won't result in too many
3290   // strided loads.
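  // For example, detecting 3 strided loads gives
  // UP.MaxCount = 1 << Log2_32(7 / 3) = 2.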
3291   if (StridedLoads) {
3292     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
3293     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
3294                       << UP.MaxCount << '\n');
3295   }
3296 }
3297 
3298 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3299                                              TTI::UnrollingPreferences &UP,
3300                                              OptimizationRemarkEmitter *ORE) {
3301   // Enable partial unrolling and runtime unrolling.
3302   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
3303 
3304   UP.UpperBound = true;
3305 
3306   // An inner loop is more likely to be hot, and the runtime check can be
3307   // hoisted out by the LICM pass, so the overhead is lower; try a larger
3308   // threshold to unroll more loops.
3309   if (L->getLoopDepth() > 1)
3310     UP.PartialThreshold *= 2;
3311 
3312   // Disable partial & runtime unrolling on -Os.
3313   UP.PartialOptSizeThreshold = 0;
3314 
3315   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
3316       EnableFalkorHWPFUnrollFix)
3317     getFalkorUnrollingPreferences(L, SE, UP);
3318 
3319   // Scan the loop: don't unroll loops with calls as this could prevent
3320   // inlining. Don't unroll vector loops either, as they don't benefit much from
3321   // unrolling.
3322   for (auto *BB : L->getBlocks()) {
3323     for (auto &I : *BB) {
3324       // Don't unroll vectorised loops.
3325       if (I.getType()->isVectorTy())
3326         return;
3327 
3328       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
3329         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
3330           if (!isLoweredToCall(F))
3331             continue;
3332         }
3333         return;
3334       }
3335     }
3336   }
3337 
3338   // Enable runtime unrolling for in-order models.
3339   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
3340   // checking for that case, we can ensure that the default behaviour is
3341   // unchanged.
3342   if (ST->getProcFamily() != AArch64Subtarget::Others &&
3343       !ST->getSchedModel().isOutOfOrder()) {
3344     UP.Runtime = true;
3345     UP.Partial = true;
3346     UP.UnrollRemainder = true;
3347     UP.DefaultUnrollRuntimeCount = 4;
3348 
3349     UP.UnrollAndJam = true;
3350     UP.UnrollAndJamInnerLoopThreshold = 60;
3351   }
3352 }
3353 
3354 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
3355                                            TTI::PeelingPreferences &PP) {
3356   BaseT::getPeelingPreferences(L, SE, PP);
3357 }
3358 
3359 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
3360                                                          Type *ExpectedType) {
3361   switch (Inst->getIntrinsicID()) {
3362   default:
3363     return nullptr;
3364   case Intrinsic::aarch64_neon_st2:
3365   case Intrinsic::aarch64_neon_st3:
3366   case Intrinsic::aarch64_neon_st4: {
3367     // Rebuild the stored value as a struct of the expected type.
3368     StructType *ST = dyn_cast<StructType>(ExpectedType);
3369     if (!ST)
3370       return nullptr;
3371     unsigned NumElts = Inst->arg_size() - 1;
3372     if (ST->getNumElements() != NumElts)
3373       return nullptr;
3374     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3375       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
3376         return nullptr;
3377     }
3378     Value *Res = PoisonValue::get(ExpectedType);
3379     IRBuilder<> Builder(Inst);
3380     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3381       Value *L = Inst->getArgOperand(i);
3382       Res = Builder.CreateInsertValue(Res, L, i);
3383     }
3384     return Res;
3385   }
3386   case Intrinsic::aarch64_neon_ld2:
3387   case Intrinsic::aarch64_neon_ld3:
3388   case Intrinsic::aarch64_neon_ld4:
3389     if (Inst->getType() == ExpectedType)
3390       return Inst;
3391     return nullptr;
3392   }
3393 }
3394 
3395 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
3396                                         MemIntrinsicInfo &Info) {
3397   switch (Inst->getIntrinsicID()) {
3398   default:
3399     break;
3400   case Intrinsic::aarch64_neon_ld2:
3401   case Intrinsic::aarch64_neon_ld3:
3402   case Intrinsic::aarch64_neon_ld4:
3403     Info.ReadMem = true;
3404     Info.WriteMem = false;
3405     Info.PtrVal = Inst->getArgOperand(0);
3406     break;
3407   case Intrinsic::aarch64_neon_st2:
3408   case Intrinsic::aarch64_neon_st3:
3409   case Intrinsic::aarch64_neon_st4:
3410     Info.ReadMem = false;
3411     Info.WriteMem = true;
3412     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
3413     break;
3414   }
3415 
3416   switch (Inst->getIntrinsicID()) {
3417   default:
3418     return false;
3419   case Intrinsic::aarch64_neon_ld2:
3420   case Intrinsic::aarch64_neon_st2:
3421     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
3422     break;
3423   case Intrinsic::aarch64_neon_ld3:
3424   case Intrinsic::aarch64_neon_st3:
3425     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
3426     break;
3427   case Intrinsic::aarch64_neon_ld4:
3428   case Intrinsic::aarch64_neon_st4:
3429     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
3430     break;
3431   }
3432   return true;
3433 }
3434 
3435 /// See if \p I should be considered for address type promotion. We check if \p
3436 /// I is a sext with the right type that is used in memory accesses. If it is
3437 /// used in a "complex" getelementptr, we allow it to be promoted without
3438 /// finding other sext instructions that sign extended the same initial value.
3439 /// A getelementptr is considered "complex" if it has more than 2 operands.
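/// For example, a sext feeding a getelementptr such as
/// getelementptr [16 x i32], ptr %p, i64 0, i64 %idx (3 operands) would allow
/// promotion without a common header.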
3440 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
3441     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
3442   bool Considerable = false;
3443   AllowPromotionWithoutCommonHeader = false;
3444   if (!isa<SExtInst>(&I))
3445     return false;
3446   Type *ConsideredSExtType =
3447       Type::getInt64Ty(I.getParent()->getParent()->getContext());
3448   if (I.getType() != ConsideredSExtType)
3449     return false;
3450   // See if the sext is the one with the right type and used in at least one
3451   // GetElementPtrInst.
3452   for (const User *U : I.users()) {
3453     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
3454       Considerable = true;
3455       // A getelementptr is considered "complex" if it has more than 2
3456       // operands. We will promote a SExt used in such a complex GEP, as we
3457       // expect some computation to be merged if it is done on 64 bits.
3458       if (GEPInst->getNumOperands() > 2) {
3459         AllowPromotionWithoutCommonHeader = true;
3460         break;
3461       }
3462     }
3463   }
3464   return Considerable;
3465 }
3466 
3467 bool AArch64TTIImpl::isLegalToVectorizeReduction(
3468     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
3469   if (!VF.isScalable())
3470     return true;
3471 
3472   Type *Ty = RdxDesc.getRecurrenceType();
3473   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
3474     return false;
3475 
3476   switch (RdxDesc.getRecurrenceKind()) {
3477   case RecurKind::Add:
3478   case RecurKind::FAdd:
3479   case RecurKind::And:
3480   case RecurKind::Or:
3481   case RecurKind::Xor:
3482   case RecurKind::SMin:
3483   case RecurKind::SMax:
3484   case RecurKind::UMin:
3485   case RecurKind::UMax:
3486   case RecurKind::FMin:
3487   case RecurKind::FMax:
3488   case RecurKind::FMulAdd:
3489   case RecurKind::IAnyOf:
3490   case RecurKind::FAnyOf:
3491     return true;
3492   default:
3493     return false;
3494   }
3495 }
3496 
3497 InstructionCost
3498 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
3499                                        FastMathFlags FMF,
3500                                        TTI::TargetCostKind CostKind) {
3501   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
3502 
3503   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
3504     return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);
3505 
3506   InstructionCost LegalizationCost = 0;
3507   if (LT.first > 1) {
3508     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
3509     IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
3510     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
3511   }
3512 
3513   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
3514 }
3515 
3516 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
3517     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
3518   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3519   InstructionCost LegalizationCost = 0;
3520   if (LT.first > 1) {
3521     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
3522     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
3523     LegalizationCost *= LT.first - 1;
3524   }
3525 
3526   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3527   assert(ISD && "Invalid opcode");
3528   // Add the final reduction cost for the legal horizontal reduction
3529   switch (ISD) {
3530   case ISD::ADD:
3531   case ISD::AND:
3532   case ISD::OR:
3533   case ISD::XOR:
3534   case ISD::FADD:
3535     return LegalizationCost + 2;
3536   default:
3537     return InstructionCost::getInvalid();
3538   }
3539 }
3540 
3541 InstructionCost
3542 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
3543                                            std::optional<FastMathFlags> FMF,
3544                                            TTI::TargetCostKind CostKind) {
3545   if (TTI::requiresOrderedReduction(FMF)) {
3546     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
3547       InstructionCost BaseCost =
3548           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3549       // Add on extra cost to reflect the extra overhead on some CPUs. We still
3550       // end up vectorizing for more computationally intensive loops.
3551       return BaseCost + FixedVTy->getNumElements();
3552     }
3553 
3554     if (Opcode != Instruction::FAdd)
3555       return InstructionCost::getInvalid();
3556 
3557     auto *VTy = cast<ScalableVectorType>(ValTy);
3558     InstructionCost Cost =
3559         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
3560     Cost *= getMaxNumElements(VTy->getElementCount());
3561     return Cost;
3562   }
3563 
3564   if (isa<ScalableVectorType>(ValTy))
3565     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
3566 
3567   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3568   MVT MTy = LT.second;
3569   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3570   assert(ISD && "Invalid opcode");
3571 
3572   // Horizontal adds can use the 'addv' instruction. We model the cost of these
3573   // instructions as twice a normal vector add, plus 1 for each legalization
3574   // step (LT.first). This is the only arithmetic vector reduction operation for
3575   // which we have an instruction.
3576   // OR, XOR and AND costs should match the codegen from:
3577   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
3578   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
3579   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
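  // For example, a v4i32 add reduction that legalizes in one step costs
  // (1 - 1) + 2 = 2, while a v8i32 one splits once and costs (2 - 1) + 2 = 3.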
3580   static const CostTblEntry CostTblNoPairwise[]{
3581       {ISD::ADD, MVT::v8i8,   2},
3582       {ISD::ADD, MVT::v16i8,  2},
3583       {ISD::ADD, MVT::v4i16,  2},
3584       {ISD::ADD, MVT::v8i16,  2},
3585       {ISD::ADD, MVT::v4i32,  2},
3586       {ISD::ADD, MVT::v2i64,  2},
3587       {ISD::OR,  MVT::v8i8,  15},
3588       {ISD::OR,  MVT::v16i8, 17},
3589       {ISD::OR,  MVT::v4i16,  7},
3590       {ISD::OR,  MVT::v8i16,  9},
3591       {ISD::OR,  MVT::v2i32,  3},
3592       {ISD::OR,  MVT::v4i32,  5},
3593       {ISD::OR,  MVT::v2i64,  3},
3594       {ISD::XOR, MVT::v8i8,  15},
3595       {ISD::XOR, MVT::v16i8, 17},
3596       {ISD::XOR, MVT::v4i16,  7},
3597       {ISD::XOR, MVT::v8i16,  9},
3598       {ISD::XOR, MVT::v2i32,  3},
3599       {ISD::XOR, MVT::v4i32,  5},
3600       {ISD::XOR, MVT::v2i64,  3},
3601       {ISD::AND, MVT::v8i8,  15},
3602       {ISD::AND, MVT::v16i8, 17},
3603       {ISD::AND, MVT::v4i16,  7},
3604       {ISD::AND, MVT::v8i16,  9},
3605       {ISD::AND, MVT::v2i32,  3},
3606       {ISD::AND, MVT::v4i32,  5},
3607       {ISD::AND, MVT::v2i64,  3},
3608   };
3609   switch (ISD) {
3610   default:
3611     break;
3612   case ISD::ADD:
3613     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
3614       return (LT.first - 1) + Entry->Cost;
3615     break;
3616   case ISD::XOR:
3617   case ISD::AND:
3618   case ISD::OR:
3619     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
3620     if (!Entry)
3621       break;
3622     auto *ValVTy = cast<FixedVectorType>(ValTy);
3623     if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
3624         isPowerOf2_32(ValVTy->getNumElements())) {
3625       InstructionCost ExtraCost = 0;
3626       if (LT.first != 1) {
3627         // Type needs to be split, so there is an extra cost of LT.first - 1
3628         // arithmetic ops.
3629         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
3630                                         MTy.getVectorNumElements());
3631         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
3632         ExtraCost *= LT.first - 1;
3633       }
3634       // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
3635       auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
3636       return Cost + ExtraCost;
3637     }
3638     break;
3639   }
3640   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3641 }
3642 
3643 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
3644   static const CostTblEntry ShuffleTbl[] = {
3645       { TTI::SK_Splice, MVT::nxv16i8,  1 },
3646       { TTI::SK_Splice, MVT::nxv8i16,  1 },
3647       { TTI::SK_Splice, MVT::nxv4i32,  1 },
3648       { TTI::SK_Splice, MVT::nxv2i64,  1 },
3649       { TTI::SK_Splice, MVT::nxv2f16,  1 },
3650       { TTI::SK_Splice, MVT::nxv4f16,  1 },
3651       { TTI::SK_Splice, MVT::nxv8f16,  1 },
3652       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
3653       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
3654       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
3655       { TTI::SK_Splice, MVT::nxv2f32,  1 },
3656       { TTI::SK_Splice, MVT::nxv4f32,  1 },
3657       { TTI::SK_Splice, MVT::nxv2f64,  1 },
3658   };
3659 
3660   // The code-generator is currently not able to handle scalable vectors
3661   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3662   // it. This change will be removed when code-generation for these types is
3663   // sufficiently reliable.
3664   if (Tp->getElementCount() == ElementCount::getScalable(1))
3665     return InstructionCost::getInvalid();
3666 
3667   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3668   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
3669   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3670   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
3671                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
3672                        : LT.second;
3673   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
3674   InstructionCost LegalizationCost = 0;
3675   if (Index < 0) {
3676     LegalizationCost =
3677         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
3678                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
3679         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
3680                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
3681   }
3682 
3683   // Predicated splices are promoted when lowering; see AArch64ISelLowering.cpp.
3684   // The cost is computed on the promoted type.
3685   if (LT.second.getScalarType() == MVT::i1) {
3686     LegalizationCost +=
3687         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
3688                          TTI::CastContextHint::None, CostKind) +
3689         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
3690                          TTI::CastContextHint::None, CostKind);
3691   }
3692   const auto *Entry =
3693       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
3694   assert(Entry && "Illegal Type for Splice");
3695   LegalizationCost += Entry->Cost;
3696   return LegalizationCost * LT.first;
3697 }
3698 
3699 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
3700                                                VectorType *Tp,
3701                                                ArrayRef<int> Mask,
3702                                                TTI::TargetCostKind CostKind,
3703                                                int Index, VectorType *SubTp,
3704                                                ArrayRef<const Value *> Args) {
3705   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3706   // If we have a Mask, and the LT is being legalized somehow, split the Mask
3707   // into smaller vectors and sum the cost of each shuffle.
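  // For example, a 32-element mask on <32 x i8> legalized to v16i8 would be
  // split into two 16-element sub-masks, each costed as a single- or two-source
  // shuffle below.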
3708   if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
3709       Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
3710       Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
3711     unsigned TpNumElts = Mask.size();
3712     unsigned LTNumElts = LT.second.getVectorNumElements();
3713     unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
3714     VectorType *NTp =
3715         VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
3716     InstructionCost Cost;
3717     for (unsigned N = 0; N < NumVecs; N++) {
3718       SmallVector<int> NMask;
3719       // Split the existing mask into chunks of size LTNumElts. Track the source
3720       // sub-vectors to ensure the result has at most 2 inputs.
3721       unsigned Source1, Source2;
3722       unsigned NumSources = 0;
3723       for (unsigned E = 0; E < LTNumElts; E++) {
3724         int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
3725                                                       : PoisonMaskElem;
3726         if (MaskElt < 0) {
3727           NMask.push_back(PoisonMaskElem);
3728           continue;
3729         }
3730 
3731         // Calculate which source from the input this comes from and whether it
3732         // is new to us.
3733         unsigned Source = MaskElt / LTNumElts;
3734         if (NumSources == 0) {
3735           Source1 = Source;
3736           NumSources = 1;
3737         } else if (NumSources == 1 && Source != Source1) {
3738           Source2 = Source;
3739           NumSources = 2;
3740         } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
3741           NumSources++;
3742         }
3743 
3744         // Add the element to the new mask. For the NumSources > 2 case these
3745         // values are not exact, but they are only used for the modular lane check below.
3746         if (Source == Source1)
3747           NMask.push_back(MaskElt % LTNumElts);
3748         else if (Source == Source2)
3749           NMask.push_back(MaskElt % LTNumElts + LTNumElts);
3750         else
3751           NMask.push_back(MaskElt % LTNumElts);
3752       }
3753       // If the sub-mask uses at most 2 input sub-vectors, re-cost it using getShuffleCost.
3754       // Otherwise assume a worst case of roughly one instruction per lane, one fewer if some lane is already in place.
3755       if (NumSources <= 2)
3756         Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
3757                                                : TTI::SK_PermuteTwoSrc,
3758                                NTp, NMask, CostKind, 0, nullptr, Args);
3759       else if (any_of(enumerate(NMask), [&](const auto &ME) {
3760                  return ME.value() % LTNumElts == ME.index();
3761                }))
3762         Cost += LTNumElts - 1;
3763       else
3764         Cost += LTNumElts;
3765     }
3766     return Cost;
3767   }
3768 
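       // Where possible, refine the shuffle kind from the actual mask (for
       // example, a two-source permute whose mask only reads one input), so the
       // lookups below can match a cheaper entry.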
3769   Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp);
3770 
3771   // Check for broadcast loads, which are supported by the LD1R instruction.
3772   // In terms of code-size, the shuffle vector is free when a load + dup is
3773   // folded into a LD1R; that is what we check and return here. For performance
3774   // and reciprocal throughput a LD1R is not completely free, so in that case
3775   // we fall through and return the cost of the broadcast below (i.e. 1 for
3776   // most/all types), modelling the load + dup sequence slightly higher
3777   // because LD1R is a high-latency instruction.
3778   if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) {
3779     bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
3780     if (IsLoad && LT.second.isVector() &&
3781         isLegalBroadcastLoad(Tp->getElementType(),
3782                              LT.second.getVectorElementCount()))
3783       return 0;
3784   }
3785 
3786   // If we have 4 elements for the shuffle and a Mask, get the cost straight
3787   // from the perfect shuffle tables.
3788   if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) &&
3789       (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) &&
3790       all_of(Mask, [](int E) { return E < 8; }))
3791     return getPerfectShuffleCost(Mask);
3792 
3793   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
3794       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
3795       Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
3796     static const CostTblEntry ShuffleTbl[] = {
3797         // Broadcast shuffle kinds can be performed with 'dup'.
3798         {TTI::SK_Broadcast, MVT::v8i8, 1},
3799         {TTI::SK_Broadcast, MVT::v16i8, 1},
3800         {TTI::SK_Broadcast, MVT::v4i16, 1},
3801         {TTI::SK_Broadcast, MVT::v8i16, 1},
3802         {TTI::SK_Broadcast, MVT::v2i32, 1},
3803         {TTI::SK_Broadcast, MVT::v4i32, 1},
3804         {TTI::SK_Broadcast, MVT::v2i64, 1},
3805         {TTI::SK_Broadcast, MVT::v4f16, 1},
3806         {TTI::SK_Broadcast, MVT::v8f16, 1},
3807         {TTI::SK_Broadcast, MVT::v2f32, 1},
3808         {TTI::SK_Broadcast, MVT::v4f32, 1},
3809         {TTI::SK_Broadcast, MVT::v2f64, 1},
3810         // Transpose shuffle kinds can be performed with 'trn1/trn2' and
3811         // 'zip1/zip2' instructions.
3812         {TTI::SK_Transpose, MVT::v8i8, 1},
3813         {TTI::SK_Transpose, MVT::v16i8, 1},
3814         {TTI::SK_Transpose, MVT::v4i16, 1},
3815         {TTI::SK_Transpose, MVT::v8i16, 1},
3816         {TTI::SK_Transpose, MVT::v2i32, 1},
3817         {TTI::SK_Transpose, MVT::v4i32, 1},
3818         {TTI::SK_Transpose, MVT::v2i64, 1},
3819         {TTI::SK_Transpose, MVT::v4f16, 1},
3820         {TTI::SK_Transpose, MVT::v8f16, 1},
3821         {TTI::SK_Transpose, MVT::v2f32, 1},
3822         {TTI::SK_Transpose, MVT::v4f32, 1},
3823         {TTI::SK_Transpose, MVT::v2f64, 1},
3824         // Select shuffle kinds.
3825         // TODO: handle vXi8/vXi16.
3826         {TTI::SK_Select, MVT::v2i32, 1}, // mov.
3827         {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
3828         {TTI::SK_Select, MVT::v2i64, 1}, // mov.
3829         {TTI::SK_Select, MVT::v2f32, 1}, // mov.
3830         {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
3831         {TTI::SK_Select, MVT::v2f64, 1}, // mov.
3832         // PermuteSingleSrc shuffle kinds.
3833         {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
3834         {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
3835         {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
3836         {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
3837         {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
3838         {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
3839         {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
3840         {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
3841         {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // perfectshuffle worst case.
3842         {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8},  // constpool + load + tbl
3843         {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8},  // constpool + load + tbl
3844         {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
3845         {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8},   // constpool + load + tbl
3846         {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8},  // constpool + load + tbl
3847         // Reverse can be lowered with `rev`.
3848         {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
3849         {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
3850         {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
3851         {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
3852         {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
3853         {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
3854         {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
3855         {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
3856         {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
3857         {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
3858         {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
3859         {TTI::SK_Reverse, MVT::v8i8, 1},  // REV64
3860         // Splice can all be lowered as `ext`.
3861         {TTI::SK_Splice, MVT::v2i32, 1},
3862         {TTI::SK_Splice, MVT::v4i32, 1},
3863         {TTI::SK_Splice, MVT::v2i64, 1},
3864         {TTI::SK_Splice, MVT::v2f32, 1},
3865         {TTI::SK_Splice, MVT::v4f32, 1},
3866         {TTI::SK_Splice, MVT::v2f64, 1},
3867         {TTI::SK_Splice, MVT::v8f16, 1},
3868         {TTI::SK_Splice, MVT::v8bf16, 1},
3869         {TTI::SK_Splice, MVT::v8i16, 1},
3870         {TTI::SK_Splice, MVT::v16i8, 1},
3871         {TTI::SK_Splice, MVT::v4bf16, 1},
3872         {TTI::SK_Splice, MVT::v4f16, 1},
3873         {TTI::SK_Splice, MVT::v4i16, 1},
3874         {TTI::SK_Splice, MVT::v8i8, 1},
3875         // Broadcast shuffle kinds for scalable vectors
3876         {TTI::SK_Broadcast, MVT::nxv16i8, 1},
3877         {TTI::SK_Broadcast, MVT::nxv8i16, 1},
3878         {TTI::SK_Broadcast, MVT::nxv4i32, 1},
3879         {TTI::SK_Broadcast, MVT::nxv2i64, 1},
3880         {TTI::SK_Broadcast, MVT::nxv2f16, 1},
3881         {TTI::SK_Broadcast, MVT::nxv4f16, 1},
3882         {TTI::SK_Broadcast, MVT::nxv8f16, 1},
3883         {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
3884         {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
3885         {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
3886         {TTI::SK_Broadcast, MVT::nxv2f32, 1},
3887         {TTI::SK_Broadcast, MVT::nxv4f32, 1},
3888         {TTI::SK_Broadcast, MVT::nxv2f64, 1},
3889         {TTI::SK_Broadcast, MVT::nxv16i1, 1},
3890         {TTI::SK_Broadcast, MVT::nxv8i1, 1},
3891         {TTI::SK_Broadcast, MVT::nxv4i1, 1},
3892         {TTI::SK_Broadcast, MVT::nxv2i1, 1},
3893         // Handle the cases for vector.reverse with scalable vectors
3894         {TTI::SK_Reverse, MVT::nxv16i8, 1},
3895         {TTI::SK_Reverse, MVT::nxv8i16, 1},
3896         {TTI::SK_Reverse, MVT::nxv4i32, 1},
3897         {TTI::SK_Reverse, MVT::nxv2i64, 1},
3898         {TTI::SK_Reverse, MVT::nxv2f16, 1},
3899         {TTI::SK_Reverse, MVT::nxv4f16, 1},
3900         {TTI::SK_Reverse, MVT::nxv8f16, 1},
3901         {TTI::SK_Reverse, MVT::nxv2bf16, 1},
3902         {TTI::SK_Reverse, MVT::nxv4bf16, 1},
3903         {TTI::SK_Reverse, MVT::nxv8bf16, 1},
3904         {TTI::SK_Reverse, MVT::nxv2f32, 1},
3905         {TTI::SK_Reverse, MVT::nxv4f32, 1},
3906         {TTI::SK_Reverse, MVT::nxv2f64, 1},
3907         {TTI::SK_Reverse, MVT::nxv16i1, 1},
3908         {TTI::SK_Reverse, MVT::nxv8i1, 1},
3909         {TTI::SK_Reverse, MVT::nxv4i1, 1},
3910         {TTI::SK_Reverse, MVT::nxv2i1, 1},
3911     };
3912     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
3913       return LT.first * Entry->Cost;
3914   }
3915 
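       // Scalable-vector splices are costed by getSpliceCost above, which also
       // accounts for the promotion of i1 predicate types.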
3916   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
3917     return getSpliceCost(Tp, Index);
3918 
3919   // Inserting a subvector can often be done with either a D, S or H register
3920   // move, so long as the inserted vector is "aligned".
3921   if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
3922       LT.second.getSizeInBits() <= 128 && SubTp) {
3923     std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
3924     if (SubLT.second.isVector()) {
3925       int NumElts = LT.second.getVectorNumElements();
3926       int NumSubElts = SubLT.second.getVectorNumElements();
3927       if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
3928         return SubLT.first;
3929     }
3930   }
3931 
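       // Everything else falls back to the generic cost model.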
3932   return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
3933 }
3934 
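     // Returns true if any load or store in the loop accesses memory through a
     // pointer with a negative (consecutive but decreasing) stride.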
3935 static bool containsDecreasingPointers(Loop *TheLoop,
3936                                        PredicatedScalarEvolution *PSE) {
3937   const DenseMap<Value *, const SCEV *> Strides;
3938   for (BasicBlock *BB : TheLoop->blocks()) {
3939     // Scan the instructions in the block and look for addresses that are
3940     // consecutive and decreasing.
3941     for (Instruction &I : *BB) {
3942       if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
3943         Value *Ptr = getLoadStorePointerOperand(&I);
3944         Type *AccessTy = getLoadStoreType(&I);
3945         if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
3946                          /*ShouldCheckWrap=*/false)
3947                 .value_or(0) < 0)
3948           return true;
3949       }
3950     }
3951   }
3952   return false;
3953 }
3954 
3955 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
3956   if (!ST->hasSVE())
3957     return false;
3958 
3959   // We don't currently support vectorisation with interleaving for SVE - with
3960   // such loops we're better off not using tail-folding. This gives us a chance
3961   // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
3962   if (TFI->IAI->hasGroups())
3963     return false;
3964 
3965   TailFoldingOpts Required = TailFoldingOpts::Disabled;
3966   if (TFI->LVL->getReductionVars().size())
3967     Required |= TailFoldingOpts::Reductions;
3968   if (TFI->LVL->getFixedOrderRecurrences().size())
3969     Required |= TailFoldingOpts::Recurrences;
3970 
3971   // We call this to discover whether any load/store pointers in the loop have
3972   // negative strides. This will require extra work to reverse the loop
3973   // predicate, which may be expensive.
3974   if (containsDecreasingPointers(TFI->LVL->getLoop(),
3975                                  TFI->LVL->getPredicatedScalarEvolution()))
3976     Required |= TailFoldingOpts::Reverse;
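       // If none of the features above are required, plain (simple) tail-folding
       // is sufficient for this loop.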
3977   if (Required == TailFoldingOpts::Disabled)
3978     Required |= TailFoldingOpts::Simple;
3979 
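       // Check that the requested -sve-tail-folding options (or the CPU's
       // defaults) cover everything this loop requires.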
3980   if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(),
3981                                       Required))
3982     return false;
3983 
3984   // Don't tail-fold for tight loops where we would be better off interleaving
3985   // with an unpredicated loop.
3986   unsigned NumInsns = 0;
3987   for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) {
3988     NumInsns += BB->sizeWithoutDebug();
3989   }
3990 
3991   // We expect 4 of these to be an IV PHI, IV add, IV compare and branch.
3992   return NumInsns >= SVETailFoldInsnThreshold;
3993 }
3994 
3995 InstructionCost
3996 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
3997                                      int64_t BaseOffset, bool HasBaseReg,
3998                                      int64_t Scale, unsigned AddrSpace) const {
3999   // Scaling factors are not free at all.
4000   // Operands                     | Rt Latency
4001   // -------------------------------------------
4002   // Rt, [Xn, Xm]                 | 4
4003   // -------------------------------------------
4004   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
4005   // Rt, [Xn, Wm, <extend> #imm]  |
4006   TargetLoweringBase::AddrMode AM;
4007   AM.BaseGV = BaseGV;
4008   AM.BaseOffs = BaseOffset;
4009   AM.HasBaseReg = HasBaseReg;
4010   AM.Scale = Scale;
4011   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
4012     // Scale represents reg2 * scale, so charge a cost of 1 when the scale
4013     // is neither 0 nor 1 (i.e. a genuinely scaled register is required).
4014     return AM.Scale != 0 && AM.Scale != 1;
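       // A negative return value signals that this addressing mode is not
       // legal for Ty at all.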
4015   return -1;
4016 }
4017