xref: /freebsd/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (revision 1db9f3b21e39176dd5b67cf8ac378633b172463e)
1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "AArch64PerfectShuffle.h"
12 #include "MCTargetDesc/AArch64AddressingModes.h"
13 #include "llvm/Analysis/IVDescriptors.h"
14 #include "llvm/Analysis/LoopInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/CodeGen/BasicTTIImpl.h"
17 #include "llvm/CodeGen/CostTable.h"
18 #include "llvm/CodeGen/TargetLowering.h"
19 #include "llvm/IR/IntrinsicInst.h"
20 #include "llvm/IR/Intrinsics.h"
21 #include "llvm/IR/IntrinsicsAArch64.h"
22 #include "llvm/IR/PatternMatch.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Transforms/InstCombine/InstCombiner.h"
25 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
26 #include <algorithm>
27 #include <optional>
28 using namespace llvm;
29 using namespace llvm::PatternMatch;
30 
31 #define DEBUG_TYPE "aarch64tti"
32 
33 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
34                                                cl::init(true), cl::Hidden);
35 
36 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
37                                            cl::Hidden);
38 
39 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
40                                             cl::init(10), cl::Hidden);
41 
42 static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
43                                                   cl::init(15), cl::Hidden);
44 
45 static cl::opt<unsigned>
46     NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
47                                cl::Hidden);
48 
49 static cl::opt<unsigned> CallPenaltyChangeSM(
50     "call-penalty-sm-change", cl::init(5), cl::Hidden,
51     cl::desc(
52         "Penalty of calling a function that requires a change to PSTATE.SM"));
53 
54 static cl::opt<unsigned> InlineCallPenaltyChangeSM(
55     "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
56     cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
57 
58 namespace {
59 class TailFoldingOption {
60   // These bitfields will only ever be set to something non-zero in operator=,
61   // when setting the -sve-tail-folding option. This option should always be of
62   // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here
63   // the form (default|simple|all|disabled)[+(Flag1|Flag2|etc)], where
64   // additional flags we're enabling, and DisableBits for those flags we're
65   // disabling. The default flag is tracked in the variable NeedsDefault, since
66   // at the time of setting the option we may not know what the default value
67   // for the CPU is.
68   TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
69   TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
70   TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
71 
72   // This value needs to be initialised to true in case the user does not
73   // explicitly set the -sve-tail-folding option.
74   bool NeedsDefault = true;
75 
76   void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
77 
78   void setNeedsDefault(bool V) { NeedsDefault = V; }
79 
80   void setEnableBit(TailFoldingOpts Bit) {
81     EnableBits |= Bit;
82     DisableBits &= ~Bit;
83   }
84 
85   void setDisableBit(TailFoldingOpts Bit) {
86     EnableBits &= ~Bit;
87     DisableBits |= Bit;
88   }
89 
90   TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
91     TailFoldingOpts Bits = TailFoldingOpts::Disabled;
92 
93     assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
94            "Initial bits should only include one of "
95            "(disabled|all|simple|default)");
96     Bits = NeedsDefault ? DefaultBits : InitialBits;
97     Bits |= EnableBits;
98     Bits &= ~DisableBits;
99 
100     return Bits;
101   }
102 
103   void reportError(std::string Opt) {
104     errs() << "invalid argument '" << Opt
105            << "' to -sve-tail-folding=; the option should be of the form\n"
106               "  (disabled|all|default|simple)[+(reductions|recurrences"
107               "|reverse|noreductions|norecurrences|noreverse)]\n";
108     report_fatal_error("Unrecognised tail-folding option");
109   }
110 
111 public:
112 
113   void operator=(const std::string &Val) {
114     // If the user explicitly sets -sve-tail-folding= then treat as an error.
115     if (Val.empty()) {
116       reportError("");
117       return;
118     }
119 
120     // Since the user is explicitly setting the option we don't automatically
121     // need the default unless they require it.
122     setNeedsDefault(false);
123 
124     SmallVector<StringRef, 4> TailFoldTypes;
125     StringRef(Val).split(TailFoldTypes, '+', -1, false);
126 
127     unsigned StartIdx = 1;
128     if (TailFoldTypes[0] == "disabled")
129       setInitialBits(TailFoldingOpts::Disabled);
130     else if (TailFoldTypes[0] == "all")
131       setInitialBits(TailFoldingOpts::All);
132     else if (TailFoldTypes[0] == "default")
133       setNeedsDefault(true);
134     else if (TailFoldTypes[0] == "simple")
135       setInitialBits(TailFoldingOpts::Simple);
136     else {
137       StartIdx = 0;
138       setInitialBits(TailFoldingOpts::Disabled);
139     }
140 
141     for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
142       if (TailFoldTypes[I] == "reductions")
143         setEnableBit(TailFoldingOpts::Reductions);
144       else if (TailFoldTypes[I] == "recurrences")
145         setEnableBit(TailFoldingOpts::Recurrences);
146       else if (TailFoldTypes[I] == "reverse")
147         setEnableBit(TailFoldingOpts::Reverse);
148       else if (TailFoldTypes[I] == "noreductions")
149         setDisableBit(TailFoldingOpts::Reductions);
150       else if (TailFoldTypes[I] == "norecurrences")
151         setDisableBit(TailFoldingOpts::Recurrences);
152       else if (TailFoldTypes[I] == "noreverse")
153         setDisableBit(TailFoldingOpts::Reverse);
154       else
155         reportError(Val);
156     }
157   }
158 
159   bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
160     return (getBits(DefaultBits) & Required) == Required;
161   }
162 };
163 } // namespace
164 
165 TailFoldingOption TailFoldingOptionLoc;
166 
167 cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
168     "sve-tail-folding",
169     cl::desc(
170         "Control the use of vectorisation using tail-folding for SVE where the"
171         " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
172         "\ndisabled      (Initial) No loop types will vectorize using "
173         "tail-folding"
174         "\ndefault       (Initial) Uses the default tail-folding settings for "
175         "the target CPU"
176         "\nall           (Initial) All legal loop types will vectorize using "
177         "tail-folding"
178         "\nsimple        (Initial) Use tail-folding for simple loops (not "
179         "reductions or recurrences)"
180         "\nreductions    Use tail-folding for loops containing reductions"
181         "\nnoreductions  Inverse of above"
182         "\nrecurrences   Use tail-folding for loops containing fixed order "
183         "recurrences"
184         "\nnorecurrences Inverse of above"
185         "\nreverse       Use tail-folding for loops requiring reversed "
186         "predicates"
187         "\nnoreverse     Inverse of above"),
188     cl::location(TailFoldingOptionLoc));
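// Illustrative usage (not part of the upstream source): a hypothetical
// invocation such as `-mllvm -sve-tail-folding=all+noreverse` is split on '+'
// by TailFoldingOption::operator=, so InitialBits becomes TailFoldingOpts::All
// and Reverse is recorded in DisableBits; a later getBits(DefaultBits) call
// then returns All & ~Reverse, and DefaultBits is ignored because NeedsDefault
// was cleared when the option was set explicitly.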
189 
190 // Experimental option that will only be fully functional when the
191 // code-generator is changed to use SVE instead of NEON for all fixed-width
192 // operations.
193 static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
194     "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
195 
196 // Experimental option that will only be fully functional when the cost-model
197 // and code-generator have been changed to avoid using scalable vector
198 // instructions that are not legal in streaming SVE mode.
199 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
200     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
201 
202 static bool isSMEABIRoutineCall(const CallInst &CI) {
203   const auto *F = CI.getCalledFunction();
204   return F && StringSwitch<bool>(F->getName())
205                   .Case("__arm_sme_state", true)
206                   .Case("__arm_tpidr2_save", true)
207                   .Case("__arm_tpidr2_restore", true)
208                   .Case("__arm_za_disable", true)
209                   .Default(false);
210 }
211 
212 /// Returns true if the function has explicit operations that can only be
213 /// lowered using incompatible instructions for the selected mode. This also
214 /// returns true if the function F may use or modify ZA state.
215 static bool hasPossibleIncompatibleOps(const Function *F) {
216   for (const BasicBlock &BB : *F) {
217     for (const Instruction &I : BB) {
218       // Be conservative for now and assume that any call to inline asm or to
219       // intrinsics could could result in non-streaming ops (e.g. calls to
220       // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
221       // all native LLVM instructions can be lowered to compatible instructions.
222       if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
223           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
224            isSMEABIRoutineCall(cast<CallInst>(I))))
225         return true;
226     }
227   }
228   return false;
229 }
230 
231 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
232                                          const Function *Callee) const {
233   SMEAttrs CallerAttrs(*Caller);
234   SMEAttrs CalleeAttrs(*Callee);
235   if (CalleeAttrs.hasNewZABody())
236     return false;
237 
238   if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
239       CallerAttrs.requiresSMChange(CalleeAttrs,
240                                    /*BodyOverridesInterface=*/true)) {
241     if (hasPossibleIncompatibleOps(Callee))
242       return false;
243   }
244 
245   const TargetMachine &TM = getTLI()->getTargetMachine();
246 
247   const FeatureBitset &CallerBits =
248       TM.getSubtargetImpl(*Caller)->getFeatureBits();
249   const FeatureBitset &CalleeBits =
250       TM.getSubtargetImpl(*Callee)->getFeatureBits();
251 
252   // Inline a callee if its target-features are a subset of the caller's
253   // target-features.
254   return (CallerBits & CalleeBits) == CalleeBits;
255 }
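// Illustrative example (not from the upstream source): with a caller compiled
// for +neon,+sve and a callee requiring only +neon, the subset check above
// permits inlining; a callee that additionally requires +sve2 is rejected
// because (CallerBits & CalleeBits) != CalleeBits.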
256 
257 bool AArch64TTIImpl::areTypesABICompatible(
258     const Function *Caller, const Function *Callee,
259     const ArrayRef<Type *> &Types) const {
260   if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
261     return false;
262 
263   // We need to ensure that argument promotion does not attempt to promote
264   // pointers to fixed-length vector types larger than 128 bits like
265   // <8 x float> (and pointers to aggregate types which have such fixed-length
266   // vector type members) into the values of the pointees. Such vector types
267   // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
268   // backend cannot lower such value arguments. The 128-bit fixed-length SVE
269   // types can be safely treated as 128-bit NEON types and they cannot be
270   // distinguished in IR.
271   if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
272         auto FVTy = dyn_cast<FixedVectorType>(Ty);
273         return FVTy &&
274                FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
275       }))
276     return false;
277 
278   return true;
279 }
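// Illustrative example (not from the upstream source): when SVE is used for
// fixed-length vectors, a promotion candidate involving a pointer to
// <8 x float> (256 bits) is rejected here, whereas a pointee of <4 x float>
// (128 bits) still passes because it can be treated as a NEON type.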
280 
281 unsigned
282 AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
283                                      unsigned DefaultCallPenalty) const {
284   // This function calculates a penalty for executing Call in F.
285   //
286   // There are two ways this function can be called:
287   // (1)  F:
288   //       call from F -> G (the call here is Call)
289   //
290   // For (1), Call.getCaller() == F, so it will always return a high cost if
291   // a streaming-mode change is required (thus promoting the need to inline the
292   // function)
293   //
294   // (2)  F:
295   //       call from F -> G (the call here is not Call)
296   //      G:
297   //       call from G -> H (the call here is Call)
298   //
299   // For (2), if after inlining the body of G into F the call to H requires a
300   // streaming-mode change, and the call to G from F would also require a
301   // streaming-mode change, then there is benefit to do the streaming-mode
302   // change only once and avoid inlining of G into F.
303   SMEAttrs FAttrs(*F);
304   SMEAttrs CalleeAttrs(Call);
305   if (FAttrs.requiresSMChange(CalleeAttrs)) {
306     if (F == Call.getCaller()) // (1)
307       return CallPenaltyChangeSM * DefaultCallPenalty;
308     if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
309       return InlineCallPenaltyChangeSM * DefaultCallPenalty;
310   }
311 
312   return DefaultCallPenalty;
313 }
314 
315 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
316     TargetTransformInfo::RegisterKind K) const {
317   assert(K != TargetTransformInfo::RGK_Scalar);
318   return (K == TargetTransformInfo::RGK_FixedWidthVector &&
319           ST->isNeonAvailable());
320 }
321 
322 /// Calculate the cost of materializing a 64-bit value. This helper
323 /// method might only calculate a fraction of a larger immediate. Therefore it
324 /// is valid to return a cost of ZERO.
325 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
326   // Check if the immediate can be encoded within an instruction.
327   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
328     return 0;
329 
330   if (Val < 0)
331     Val = ~Val;
332 
333   // Calculate how many moves we will need to materialize this constant.
334   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
335   AArch64_IMM::expandMOVImm(Val, 64, Insn);
336   return Insn.size();
337 }
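// Illustrative example (not from the upstream source): 0x0000ABCD00000000 can
// be built with a single shifted MOVZ, so this helper returns 1; a value whose
// expansion needs a MOVZ plus three MOVKs returns 4; zero and 64-bit logical
// immediates such as 0x00FF00FF00FF00FF return 0.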
338 
339 /// Calculate the cost of materializing the given constant.
340 InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
341                                               TTI::TargetCostKind CostKind) {
342   assert(Ty->isIntegerTy());
343 
344   unsigned BitSize = Ty->getPrimitiveSizeInBits();
345   if (BitSize == 0)
346     return ~0U;
347 
348   // Sign-extend all constants to a multiple of 64 bits.
349   APInt ImmVal = Imm;
350   if (BitSize & 0x3f)
351     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
352 
353   // Split the constant into 64-bit chunks and calculate the cost for each
354   // chunk.
355   InstructionCost Cost = 0;
356   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
357     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
358     int64_t Val = Tmp.getSExtValue();
359     Cost += getIntImmCost(Val);
360   }
361   // We need at least one instruction to materialize the constant.
362   return std::max<InstructionCost>(1, Cost);
363 }
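// Illustrative example (not from the upstream source): a 128-bit constant is
// split into two 64-bit chunks, each costed with the helper above; if every
// chunk is free (e.g. an all-zero value) the result is still clamped to 1
// because at least one instruction is needed to materialize the constant.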
364 
365 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
366                                                   const APInt &Imm, Type *Ty,
367                                                   TTI::TargetCostKind CostKind,
368                                                   Instruction *Inst) {
369   assert(Ty->isIntegerTy());
370 
371   unsigned BitSize = Ty->getPrimitiveSizeInBits();
372   // There is no cost model for constants with a bit size of 0. Return TCC_Free
373   // here, so that constant hoisting will ignore this constant.
374   if (BitSize == 0)
375     return TTI::TCC_Free;
376 
377   unsigned ImmIdx = ~0U;
378   switch (Opcode) {
379   default:
380     return TTI::TCC_Free;
381   case Instruction::GetElementPtr:
382     // Always hoist the base address of a GetElementPtr.
383     if (Idx == 0)
384       return 2 * TTI::TCC_Basic;
385     return TTI::TCC_Free;
386   case Instruction::Store:
387     ImmIdx = 0;
388     break;
389   case Instruction::Add:
390   case Instruction::Sub:
391   case Instruction::Mul:
392   case Instruction::UDiv:
393   case Instruction::SDiv:
394   case Instruction::URem:
395   case Instruction::SRem:
396   case Instruction::And:
397   case Instruction::Or:
398   case Instruction::Xor:
399   case Instruction::ICmp:
400     ImmIdx = 1;
401     break;
402   // Always return TCC_Free for the shift value of a shift instruction.
403   case Instruction::Shl:
404   case Instruction::LShr:
405   case Instruction::AShr:
406     if (Idx == 1)
407       return TTI::TCC_Free;
408     break;
409   case Instruction::Trunc:
410   case Instruction::ZExt:
411   case Instruction::SExt:
412   case Instruction::IntToPtr:
413   case Instruction::PtrToInt:
414   case Instruction::BitCast:
415   case Instruction::PHI:
416   case Instruction::Call:
417   case Instruction::Select:
418   case Instruction::Ret:
419   case Instruction::Load:
420     break;
421   }
422 
423   if (Idx == ImmIdx) {
424     int NumConstants = (BitSize + 63) / 64;
425     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
426     return (Cost <= NumConstants * TTI::TCC_Basic)
427                ? static_cast<int>(TTI::TCC_Free)
428                : Cost;
429   }
430   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
431 }
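// Illustrative example (not from the upstream source): for `add i64 %x, 42`
// the immediate sits at ImmIdx, its materialization cost (1) does not exceed
// NumConstants * TCC_Basic, and TCC_Free is returned so constant hoisting
// leaves the 42 in place; the base address of a GEP instead reports
// 2 * TCC_Basic, making it a candidate for hoisting.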
432 
433 InstructionCost
434 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
435                                     const APInt &Imm, Type *Ty,
436                                     TTI::TargetCostKind CostKind) {
437   assert(Ty->isIntegerTy());
438 
439   unsigned BitSize = Ty->getPrimitiveSizeInBits();
440   // There is no cost model for constants with a bit size of 0. Return TCC_Free
441   // here, so that constant hoisting will ignore this constant.
442   if (BitSize == 0)
443     return TTI::TCC_Free;
444 
445   // Most (all?) AArch64 intrinsics do not support folding immediates into the
446   // selected instruction, so we compute the materialization cost for the
447   // immediate directly.
448   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
449     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
450 
451   switch (IID) {
452   default:
453     return TTI::TCC_Free;
454   case Intrinsic::sadd_with_overflow:
455   case Intrinsic::uadd_with_overflow:
456   case Intrinsic::ssub_with_overflow:
457   case Intrinsic::usub_with_overflow:
458   case Intrinsic::smul_with_overflow:
459   case Intrinsic::umul_with_overflow:
460     if (Idx == 1) {
461       int NumConstants = (BitSize + 63) / 64;
462       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
463       return (Cost <= NumConstants * TTI::TCC_Basic)
464                  ? static_cast<int>(TTI::TCC_Free)
465                  : Cost;
466     }
467     break;
468   case Intrinsic::experimental_stackmap:
469     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
470       return TTI::TCC_Free;
471     break;
472   case Intrinsic::experimental_patchpoint_void:
473   case Intrinsic::experimental_patchpoint_i64:
474     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
475       return TTI::TCC_Free;
476     break;
477   case Intrinsic::experimental_gc_statepoint:
478     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
479       return TTI::TCC_Free;
480     break;
481   }
482   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
483 }
484 
485 TargetTransformInfo::PopcntSupportKind
486 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
487   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
488   if (TyWidth == 32 || TyWidth == 64)
489     return TTI::PSK_FastHardware;
490   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
491   return TTI::PSK_Software;
492 }
493 
494 InstructionCost
495 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
496                                       TTI::TargetCostKind CostKind) {
497   auto *RetTy = ICA.getReturnType();
498   switch (ICA.getID()) {
499   case Intrinsic::umin:
500   case Intrinsic::umax:
501   case Intrinsic::smin:
502   case Intrinsic::smax: {
503     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
504                                         MVT::v8i16, MVT::v2i32, MVT::v4i32,
505                                         MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
506                                         MVT::nxv2i64};
507     auto LT = getTypeLegalizationCost(RetTy);
508     // v2i64 types get converted to cmp+bif hence the cost of 2
509     if (LT.second == MVT::v2i64)
510       return LT.first * 2;
511     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
512       return LT.first;
513     break;
514   }
515   case Intrinsic::sadd_sat:
516   case Intrinsic::ssub_sat:
517   case Intrinsic::uadd_sat:
518   case Intrinsic::usub_sat: {
519     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
520                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
521                                      MVT::v2i64};
522     auto LT = getTypeLegalizationCost(RetTy);
523     // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
524     // need to extend the type, as it uses shr(qadd(shl, shl)).
525     unsigned Instrs =
526         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
527     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
528       return LT.first * Instrs;
529     break;
530   }
531   case Intrinsic::abs: {
532     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
533                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
534                                      MVT::v2i64};
535     auto LT = getTypeLegalizationCost(RetTy);
536     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
537       return LT.first;
538     break;
539   }
540   case Intrinsic::bswap: {
541     static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
542                                      MVT::v4i32, MVT::v2i64};
543     auto LT = getTypeLegalizationCost(RetTy);
544     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
545         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
546       return LT.first;
547     break;
548   }
549   case Intrinsic::experimental_stepvector: {
550     InstructionCost Cost = 1; // Cost of the `index' instruction
551     auto LT = getTypeLegalizationCost(RetTy);
552     // Legalisation of illegal vectors involves an `index' instruction plus
553     // (LT.first - 1) vector adds.
554     if (LT.first > 1) {
555       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
556       InstructionCost AddCost =
557           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
558       Cost += AddCost * (LT.first - 1);
559     }
560     return Cost;
561   }
562   case Intrinsic::bitreverse: {
563     static const CostTblEntry BitreverseTbl[] = {
564         {Intrinsic::bitreverse, MVT::i32, 1},
565         {Intrinsic::bitreverse, MVT::i64, 1},
566         {Intrinsic::bitreverse, MVT::v8i8, 1},
567         {Intrinsic::bitreverse, MVT::v16i8, 1},
568         {Intrinsic::bitreverse, MVT::v4i16, 2},
569         {Intrinsic::bitreverse, MVT::v8i16, 2},
570         {Intrinsic::bitreverse, MVT::v2i32, 2},
571         {Intrinsic::bitreverse, MVT::v4i32, 2},
572         {Intrinsic::bitreverse, MVT::v1i64, 2},
573         {Intrinsic::bitreverse, MVT::v2i64, 2},
574     };
575     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
576     const auto *Entry =
577         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
578     if (Entry) {
579       // The cost model uses the legal type (i32) that i8 and i16 are
580       // promoted to, plus 1 so that we match the actual lowering cost.
581       if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
582           TLI->getValueType(DL, RetTy, true) == MVT::i16)
583         return LegalisationCost.first * Entry->Cost + 1;
584 
585       return LegalisationCost.first * Entry->Cost;
586     }
587     break;
588   }
589   case Intrinsic::ctpop: {
590     if (!ST->hasNEON()) {
591       // 32-bit or 64-bit ctpop without NEON is 12 instructions.
592       return getTypeLegalizationCost(RetTy).first * 12;
593     }
594     static const CostTblEntry CtpopCostTbl[] = {
595         {ISD::CTPOP, MVT::v2i64, 4},
596         {ISD::CTPOP, MVT::v4i32, 3},
597         {ISD::CTPOP, MVT::v8i16, 2},
598         {ISD::CTPOP, MVT::v16i8, 1},
599         {ISD::CTPOP, MVT::i64,   4},
600         {ISD::CTPOP, MVT::v2i32, 3},
601         {ISD::CTPOP, MVT::v4i16, 2},
602         {ISD::CTPOP, MVT::v8i8,  1},
603         {ISD::CTPOP, MVT::i32,   5},
604     };
605     auto LT = getTypeLegalizationCost(RetTy);
606     MVT MTy = LT.second;
607     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
608       // Extra cost of +1 when illegal vector types are legalized by promoting
609       // the integer type.
610       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
611                                             RetTy->getScalarSizeInBits()
612                           ? 1
613                           : 0;
614       return LT.first * Entry->Cost + ExtraCost;
615     }
616     break;
617   }
618   case Intrinsic::sadd_with_overflow:
619   case Intrinsic::uadd_with_overflow:
620   case Intrinsic::ssub_with_overflow:
621   case Intrinsic::usub_with_overflow:
622   case Intrinsic::smul_with_overflow:
623   case Intrinsic::umul_with_overflow: {
624     static const CostTblEntry WithOverflowCostTbl[] = {
625         {Intrinsic::sadd_with_overflow, MVT::i8, 3},
626         {Intrinsic::uadd_with_overflow, MVT::i8, 3},
627         {Intrinsic::sadd_with_overflow, MVT::i16, 3},
628         {Intrinsic::uadd_with_overflow, MVT::i16, 3},
629         {Intrinsic::sadd_with_overflow, MVT::i32, 1},
630         {Intrinsic::uadd_with_overflow, MVT::i32, 1},
631         {Intrinsic::sadd_with_overflow, MVT::i64, 1},
632         {Intrinsic::uadd_with_overflow, MVT::i64, 1},
633         {Intrinsic::ssub_with_overflow, MVT::i8, 3},
634         {Intrinsic::usub_with_overflow, MVT::i8, 3},
635         {Intrinsic::ssub_with_overflow, MVT::i16, 3},
636         {Intrinsic::usub_with_overflow, MVT::i16, 3},
637         {Intrinsic::ssub_with_overflow, MVT::i32, 1},
638         {Intrinsic::usub_with_overflow, MVT::i32, 1},
639         {Intrinsic::ssub_with_overflow, MVT::i64, 1},
640         {Intrinsic::usub_with_overflow, MVT::i64, 1},
641         {Intrinsic::smul_with_overflow, MVT::i8, 5},
642         {Intrinsic::umul_with_overflow, MVT::i8, 4},
643         {Intrinsic::smul_with_overflow, MVT::i16, 5},
644         {Intrinsic::umul_with_overflow, MVT::i16, 4},
645         {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
646         {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
647         {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
648         {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
649     };
650     EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
651     if (MTy.isSimple())
652       if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
653                                               MTy.getSimpleVT()))
654         return Entry->Cost;
655     break;
656   }
657   case Intrinsic::fptosi_sat:
658   case Intrinsic::fptoui_sat: {
659     if (ICA.getArgTypes().empty())
660       break;
661     bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
662     auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
663     EVT MTy = TLI->getValueType(DL, RetTy);
664     // Check for the legal types, which are where the size of the input and the
665     // output are the same, or we are using cvt f64->i32 or f32->i64.
666     if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
667          LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
668          LT.second == MVT::v2f64) &&
669         (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
670          (LT.second == MVT::f64 && MTy == MVT::i32) ||
671          (LT.second == MVT::f32 && MTy == MVT::i64)))
672       return LT.first;
673     // Similarly for fp16 sizes
674     if (ST->hasFullFP16() &&
675         ((LT.second == MVT::f16 && MTy == MVT::i32) ||
676          ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
677           (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
678       return LT.first;
679 
680     // Otherwise we use a legal convert followed by a min+max
681     if ((LT.second.getScalarType() == MVT::f32 ||
682          LT.second.getScalarType() == MVT::f64 ||
683          (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
684         LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
685       Type *LegalTy =
686           Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
687       if (LT.second.isVector())
688         LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
689       InstructionCost Cost = 1;
690       IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
691                                     LegalTy, {LegalTy, LegalTy});
692       Cost += getIntrinsicInstrCost(Attrs1, CostKind);
693       IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
694                                     LegalTy, {LegalTy, LegalTy});
695       Cost += getIntrinsicInstrCost(Attrs2, CostKind);
696       return LT.first * Cost;
697     }
698     break;
699   }
700   case Intrinsic::fshl:
701   case Intrinsic::fshr: {
702     if (ICA.getArgs().empty())
703       break;
704 
705     // TODO: Add handling for fshl where third argument is not a constant.
706     const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
707     if (!OpInfoZ.isConstant())
708       break;
709 
710     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
711     if (OpInfoZ.isUniform()) {
712       // FIXME: The costs could be lower if the codegen is better.
713       static const CostTblEntry FshlTbl[] = {
714           {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
715           {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
716           {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
717           {Intrinsic::fshl, MVT::v8i8, 4},  {Intrinsic::fshl, MVT::v4i16, 4}};
718       // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
719       // to avoid having to duplicate the costs.
720       const auto *Entry =
721           CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
722       if (Entry)
723         return LegalisationCost.first * Entry->Cost;
724     }
725 
726     auto TyL = getTypeLegalizationCost(RetTy);
727     if (!RetTy->isIntegerTy())
728       break;
729 
730     // Estimate cost manually, as types like i8 and i16 will get promoted to
731     // i32 and CostTableLookup will ignore the extra conversion cost.
732     bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
733                        RetTy->getScalarSizeInBits() < 64) ||
734                       (RetTy->getScalarSizeInBits() % 64 != 0);
735     unsigned ExtraCost = HigherCost ? 1 : 0;
736     if (RetTy->getScalarSizeInBits() == 32 ||
737         RetTy->getScalarSizeInBits() == 64)
738       ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
739                      // extr instruction.
740     else if (HigherCost)
741       ExtraCost = 1;
742     else
743       break;
744     return TyL.first + ExtraCost;
745   }
746   default:
747     break;
748   }
749   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
750 }
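// Illustrative example (not from the upstream source): for
// `call <4 x i32> @llvm.smin.v4i32(<4 x i32> %a, <4 x i32> %b)` the return
// type legalizes to MVT::v4i32, which is in ValidMinMaxTys, so the cost is
// LT.first (1); a v2i64 smin instead costs LT.first * 2 to model the cmp+bif
// lowering noted above.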
751 
752 /// The function will remove redundant reinterpret casts in the presence
753 /// of control flow.
754 static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
755                                                    IntrinsicInst &II) {
756   SmallVector<Instruction *, 32> Worklist;
757   auto RequiredType = II.getType();
758 
759   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
760   assert(PN && "Expected Phi Node!");
761 
762   // Don't create a new Phi unless we can remove the old one.
763   if (!PN->hasOneUse())
764     return std::nullopt;
765 
766   for (Value *IncValPhi : PN->incoming_values()) {
767     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
768     if (!Reinterpret ||
769         Reinterpret->getIntrinsicID() !=
770             Intrinsic::aarch64_sve_convert_to_svbool ||
771         RequiredType != Reinterpret->getArgOperand(0)->getType())
772       return std::nullopt;
773   }
774 
775   // Create the new Phi
776   IC.Builder.SetInsertPoint(PN);
777   PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
778   Worklist.push_back(PN);
779 
780   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
781     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
782     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
783     Worklist.push_back(Reinterpret);
784   }
785 
786   // Cleanup Phi Node and reinterprets
787   return IC.replaceInstUsesWith(II, NPN);
788 }
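// Illustrative example (not from the upstream source): for a
// convert.from.svbool whose operand is a single-use <vscale x 16 x i1> phi in
// which every incoming value is convert.to.svbool of a <vscale x 4 x i1>, the
// helper above replaces the whole pattern with a <vscale x 4 x i1> phi over
// the original, unconverted incoming values.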
789 
790 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
791 // => (binop (pred) (from_svbool _) (from_svbool _))
792 //
793 // The above transformation eliminates a `to_svbool` in the predicate
794 // operand of bitwise operation `binop` by narrowing the vector width of
795 // the operation. For example, it would convert a `<vscale x 16 x i1>
796 // and` into a `<vscale x 4 x i1> and`. This is profitable because
797 // to_svbool must zero the new lanes during widening, whereas
798 // from_svbool is free.
799 static std::optional<Instruction *>
800 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
801   auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
802   if (!BinOp)
803     return std::nullopt;
804 
805   auto IntrinsicID = BinOp->getIntrinsicID();
806   switch (IntrinsicID) {
807   case Intrinsic::aarch64_sve_and_z:
808   case Intrinsic::aarch64_sve_bic_z:
809   case Intrinsic::aarch64_sve_eor_z:
810   case Intrinsic::aarch64_sve_nand_z:
811   case Intrinsic::aarch64_sve_nor_z:
812   case Intrinsic::aarch64_sve_orn_z:
813   case Intrinsic::aarch64_sve_orr_z:
814     break;
815   default:
816     return std::nullopt;
817   }
818 
819   auto BinOpPred = BinOp->getOperand(0);
820   auto BinOpOp1 = BinOp->getOperand(1);
821   auto BinOpOp2 = BinOp->getOperand(2);
822 
823   auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
824   if (!PredIntr ||
825       PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
826     return std::nullopt;
827 
828   auto PredOp = PredIntr->getOperand(0);
829   auto PredOpTy = cast<VectorType>(PredOp->getType());
830   if (PredOpTy != II.getType())
831     return std::nullopt;
832 
833   SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
834   auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
835       Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
836   NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
837   if (BinOpOp1 == BinOpOp2)
838     NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
839   else
840     NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
841         Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
842 
843   auto NarrowedBinOp =
844       IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
845   return IC.replaceInstUsesWith(II, NarrowedBinOp);
846 }
847 
848 static std::optional<Instruction *>
849 instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
850   // If the reinterpret instruction operand is a PHI Node
851   if (isa<PHINode>(II.getArgOperand(0)))
852     return processPhiNode(IC, II);
853 
854   if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
855     return BinOpCombine;
856 
857   // Ignore converts to/from svcount_t.
858   if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
859       isa<TargetExtType>(II.getType()))
860     return std::nullopt;
861 
862   SmallVector<Instruction *, 32> CandidatesForRemoval;
863   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
864 
865   const auto *IVTy = cast<VectorType>(II.getType());
866 
867   // Walk the chain of conversions.
868   while (Cursor) {
869     // If the type of the cursor has fewer lanes than the final result, zeroing
870     // must take place, which breaks the equivalence chain.
871     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
872     if (CursorVTy->getElementCount().getKnownMinValue() <
873         IVTy->getElementCount().getKnownMinValue())
874       break;
875 
876     // If the cursor has the same type as I, it is a viable replacement.
877     if (Cursor->getType() == IVTy)
878       EarliestReplacement = Cursor;
879 
880     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
881 
882     // If this is not an SVE conversion intrinsic, this is the end of the chain.
883     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
884                                   Intrinsic::aarch64_sve_convert_to_svbool ||
885                               IntrinsicCursor->getIntrinsicID() ==
886                                   Intrinsic::aarch64_sve_convert_from_svbool))
887       break;
888 
889     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
890     Cursor = IntrinsicCursor->getOperand(0);
891   }
892 
893   // If no viable replacement in the conversion chain was found, there is
894   // nothing to do.
895   if (!EarliestReplacement)
896     return std::nullopt;
897 
898   return IC.replaceInstUsesWith(II, EarliestReplacement);
899 }
900 
901 static bool isAllActivePredicate(Value *Pred) {
902   // Look through the convert.from.svbool(convert.to.svbool(...)) chain.
903   Value *UncastedPred;
904   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
905                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
906                           m_Value(UncastedPred)))))
907     // If the predicate has the same or fewer lanes than the uncasted
908     // predicate then we know the casting has no effect.
909     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
910         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
911       Pred = UncastedPred;
912 
913   return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
914                          m_ConstantInt<AArch64SVEPredPattern::all>()));
915 }
916 
917 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
918                                                       IntrinsicInst &II) {
919   // svsel(ptrue, x, y) => x
920   auto *OpPredicate = II.getOperand(0);
921   if (isAllActivePredicate(OpPredicate))
922     return IC.replaceInstUsesWith(II, II.getOperand(1));
923 
924   auto Select =
925       IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
926   return IC.replaceInstUsesWith(II, Select);
927 }
928 
929 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
930                                                       IntrinsicInst &II) {
931   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
932   if (!Pg)
933     return std::nullopt;
934 
935   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
936     return std::nullopt;
937 
938   const auto PTruePattern =
939       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
940   if (PTruePattern != AArch64SVEPredPattern::vl1)
941     return std::nullopt;
942 
943   // The intrinsic is inserting into lane zero so use an insert instead.
944   auto *IdxTy = Type::getInt64Ty(II.getContext());
945   auto *Insert = InsertElementInst::Create(
946       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
947   Insert->insertBefore(&II);
948   Insert->takeName(&II);
949 
950   return IC.replaceInstUsesWith(II, Insert);
951 }
952 
953 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
954                                                        IntrinsicInst &II) {
955   // Replace DupX with a regular IR splat.
956   auto *RetTy = cast<ScalableVectorType>(II.getType());
957   Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
958                                               II.getArgOperand(0));
959   Splat->takeName(&II);
960   return IC.replaceInstUsesWith(II, Splat);
961 }
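// Illustrative example (not from the upstream source): the combine above turns
//   %r = call <vscale x 4 x i32> @llvm.aarch64.sve.dup.x.nxv4i32(i32 %s)
// into the generic splat idiom (insertelement at lane 0 followed by a
// zero-mask shufflevector), which downstream IR passes understand better.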
962 
963 static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
964                                                         IntrinsicInst &II) {
965   LLVMContext &Ctx = II.getContext();
966 
967   // Check that the predicate is all active
968   auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
969   if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
970     return std::nullopt;
971 
972   const auto PTruePattern =
973       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
974   if (PTruePattern != AArch64SVEPredPattern::all)
975     return std::nullopt;
976 
977   // Check that we have a compare of zero..
978   auto *SplatValue =
979       dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
980   if (!SplatValue || !SplatValue->isZero())
981     return std::nullopt;
982 
983   // ..against a dupq
984   auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
985   if (!DupQLane ||
986       DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
987     return std::nullopt;
988 
989   // Where the dupq is a lane 0 replicate of a vector insert
990   if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
991     return std::nullopt;
992 
993   auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
994   if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
995     return std::nullopt;
996 
997   // Where the vector insert is a fixed constant vector insert into undef at
998   // index zero
999   if (!isa<UndefValue>(VecIns->getArgOperand(0)))
1000     return std::nullopt;
1001 
1002   if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
1003     return std::nullopt;
1004 
1005   auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
1006   if (!ConstVec)
1007     return std::nullopt;
1008 
1009   auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
1010   auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
1011   if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
1012     return std::nullopt;
1013 
1014   unsigned NumElts = VecTy->getNumElements();
1015   unsigned PredicateBits = 0;
1016 
1017   // Expand intrinsic operands to a 16-bit byte level predicate
1018   for (unsigned I = 0; I < NumElts; ++I) {
1019     auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
1020     if (!Arg)
1021       return std::nullopt;
1022     if (!Arg->isZero())
1023       PredicateBits |= 1 << (I * (16 / NumElts));
1024   }
1025 
1026   // If all bits are zero, bail out early with an empty predicate.
1027   if (PredicateBits == 0) {
1028     auto *PFalse = Constant::getNullValue(II.getType());
1029     PFalse->takeName(&II);
1030     return IC.replaceInstUsesWith(II, PFalse);
1031   }
1032 
1033   // Calculate largest predicate type used (where byte predicate is largest)
1034   unsigned Mask = 8;
1035   for (unsigned I = 0; I < 16; ++I)
1036     if ((PredicateBits & (1 << I)) != 0)
1037       Mask |= (I % 8);
1038 
1039   unsigned PredSize = Mask & -Mask;
1040   auto *PredType = ScalableVectorType::get(
1041       Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
1042 
1043   // Ensure all relevant bits are set
1044   for (unsigned I = 0; I < 16; I += PredSize)
1045     if ((PredicateBits & (1 << I)) == 0)
1046       return std::nullopt;
1047 
1048   auto *PTruePat =
1049       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1050   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1051                                            {PredType}, {PTruePat});
1052   auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
1053       Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
1054   auto *ConvertFromSVBool =
1055       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
1056                                  {II.getType()}, {ConvertToSVBool});
1057 
1058   ConvertFromSVBool->takeName(&II);
1059   return IC.replaceInstUsesWith(II, ConvertFromSVBool);
1060 }
1061 
1062 static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
1063                                                        IntrinsicInst &II) {
1064   Value *Pg = II.getArgOperand(0);
1065   Value *Vec = II.getArgOperand(1);
1066   auto IntrinsicID = II.getIntrinsicID();
1067   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
1068 
1069   // lastX(splat(X)) --> X
1070   if (auto *SplatVal = getSplatValue(Vec))
1071     return IC.replaceInstUsesWith(II, SplatVal);
1072 
1073   // If x and/or y is a splat value then:
1074   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
1075   Value *LHS, *RHS;
1076   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
1077     if (isSplatValue(LHS) || isSplatValue(RHS)) {
1078       auto *OldBinOp = cast<BinaryOperator>(Vec);
1079       auto OpC = OldBinOp->getOpcode();
1080       auto *NewLHS =
1081           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
1082       auto *NewRHS =
1083           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
1084       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
1085           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
1086       return IC.replaceInstUsesWith(II, NewBinOp);
1087     }
1088   }
1089 
1090   auto *C = dyn_cast<Constant>(Pg);
1091   if (IsAfter && C && C->isNullValue()) {
1092     // The intrinsic is extracting lane 0 so use an extract instead.
1093     auto *IdxTy = Type::getInt64Ty(II.getContext());
1094     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
1095     Extract->insertBefore(&II);
1096     Extract->takeName(&II);
1097     return IC.replaceInstUsesWith(II, Extract);
1098   }
1099 
1100   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
1101   if (!IntrPG)
1102     return std::nullopt;
1103 
1104   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1105     return std::nullopt;
1106 
1107   const auto PTruePattern =
1108       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
1109 
1110   // Can the intrinsic's predicate be converted to a known constant index?
1111   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
1112   if (!MinNumElts)
1113     return std::nullopt;
1114 
1115   unsigned Idx = MinNumElts - 1;
1116   // Increment the index if extracting the element after the last active
1117   // predicate element.
1118   if (IsAfter)
1119     ++Idx;
1120 
1121   // Ignore extracts whose index is larger than the known minimum vector
1122   // length. NOTE: This is an artificial constraint where we prefer to
1123   // maintain what the user asked for until an alternative is proven faster.
1124   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
1125   if (Idx >= PgVTy->getMinNumElements())
1126     return std::nullopt;
1127 
1128   // The intrinsic is extracting a fixed lane so use an extract instead.
1129   auto *IdxTy = Type::getInt64Ty(II.getContext());
1130   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
1131   Extract->insertBefore(&II);
1132   Extract->takeName(&II);
1133   return IC.replaceInstUsesWith(II, Extract);
1134 }
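// Illustrative example (not from the upstream source): `lastb` with a
// ptrue(vl4) predicate over an <vscale x 4 x i32> vector selects lane 3, so
// the combine above emits `extractelement <vscale x 4 x i32> %v, i64 3`; the
// corresponding `lasta` would need index 4, which is not below the known
// minimum vector length and is therefore left alone.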
1135 
1136 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1137                                                            IntrinsicInst &II) {
1138   // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1139   // integer variant across a variety of micro-architectures. Replace scalar
1140   // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1141   // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1142   // depending on the micro-architecture, but has been observed as generally
1143   // being faster, particularly when the CLAST[AB] op is a loop-carried
1144   // dependency.
1145   Value *Pg = II.getArgOperand(0);
1146   Value *Fallback = II.getArgOperand(1);
1147   Value *Vec = II.getArgOperand(2);
1148   Type *Ty = II.getType();
1149 
1150   if (!Ty->isIntegerTy())
1151     return std::nullopt;
1152 
1153   Type *FPTy;
1154   switch (cast<IntegerType>(Ty)->getBitWidth()) {
1155   default:
1156     return std::nullopt;
1157   case 16:
1158     FPTy = IC.Builder.getHalfTy();
1159     break;
1160   case 32:
1161     FPTy = IC.Builder.getFloatTy();
1162     break;
1163   case 64:
1164     FPTy = IC.Builder.getDoubleTy();
1165     break;
1166   }
1167 
1168   Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
1169   auto *FPVTy = VectorType::get(
1170       FPTy, cast<VectorType>(Vec->getType())->getElementCount());
1171   Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
1172   auto *FPII = IC.Builder.CreateIntrinsic(
1173       II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
1174   Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
1175   return IC.replaceInstUsesWith(II, FPIItoInt);
1176 }
1177 
1178 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
1179                                                      IntrinsicInst &II) {
1180   LLVMContext &Ctx = II.getContext();
1181   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
1182   // can work with RDFFR_PP for ptest elimination.
1183   auto *AllPat =
1184       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1185   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1186                                            {II.getType()}, {AllPat});
1187   auto *RDFFR =
1188       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
1189   RDFFR->takeName(&II);
1190   return IC.replaceInstUsesWith(II, RDFFR);
1191 }
1192 
1193 static std::optional<Instruction *>
1194 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
1195   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
1196 
1197   if (Pattern == AArch64SVEPredPattern::all) {
1198     Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
1199     auto *VScale = IC.Builder.CreateVScale(StepVal);
1200     VScale->takeName(&II);
1201     return IC.replaceInstUsesWith(II, VScale);
1202   }
1203 
1204   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
1205 
1206   return MinNumElts && NumElts >= MinNumElts
1207              ? std::optional<Instruction *>(IC.replaceInstUsesWith(
1208                    II, ConstantInt::get(II.getType(), MinNumElts)))
1209              : std::nullopt;
1210 }
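// Illustrative example (not from the upstream source): assuming NumElts is the
// per-128-bit element count supplied by the caller (e.g. 4 for cntw),
// `cntw(all)` becomes `4 * llvm.vscale()`, `cntw(vl2)` folds to the constant
// 2, and a pattern larger than NumElts such as vl128 is left unchanged.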
1211 
1212 static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
1213                                                         IntrinsicInst &II) {
1214   Value *PgVal = II.getArgOperand(0);
1215   Value *OpVal = II.getArgOperand(1);
1216 
1217   // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
1218   // Later optimizations prefer this form.
1219   if (PgVal == OpVal &&
1220       (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
1221        II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
1222     Value *Ops[] = {PgVal, OpVal};
1223     Type *Tys[] = {PgVal->getType()};
1224 
1225     auto *PTest =
1226         IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
1227     PTest->takeName(&II);
1228 
1229     return IC.replaceInstUsesWith(II, PTest);
1230   }
1231 
1232   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
1233   IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
1234 
1235   if (!Pg || !Op)
1236     return std::nullopt;
1237 
1238   Intrinsic::ID OpIID = Op->getIntrinsicID();
1239 
1240   if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
1241       OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
1242       Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
1243     Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
1244     Type *Tys[] = {Pg->getArgOperand(0)->getType()};
1245 
1246     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1247 
1248     PTest->takeName(&II);
1249     return IC.replaceInstUsesWith(II, PTest);
1250   }
1251 
1252   // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X).
1253   // Later optimizations may rewrite the sequence to use the flag-setting variant
1254   // of instruction X to remove PTEST.
1255   if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
1256       ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
1257        (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
1258        (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
1259        (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
1260        (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
1261        (OpIID == Intrinsic::aarch64_sve_and_z) ||
1262        (OpIID == Intrinsic::aarch64_sve_bic_z) ||
1263        (OpIID == Intrinsic::aarch64_sve_eor_z) ||
1264        (OpIID == Intrinsic::aarch64_sve_nand_z) ||
1265        (OpIID == Intrinsic::aarch64_sve_nor_z) ||
1266        (OpIID == Intrinsic::aarch64_sve_orn_z) ||
1267        (OpIID == Intrinsic::aarch64_sve_orr_z))) {
1268     Value *Ops[] = {Pg->getArgOperand(0), Pg};
1269     Type *Tys[] = {Pg->getType()};
1270 
1271     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
1272     PTest->takeName(&II);
1273 
1274     return IC.replaceInstUsesWith(II, PTest);
1275   }
1276 
1277   return std::nullopt;
1278 }
1279 
1280 template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
1281 static std::optional<Instruction *>
1282 instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
1283                                   bool MergeIntoAddendOp) {
1284   Value *P = II.getOperand(0);
1285   Value *MulOp0, *MulOp1, *AddendOp, *Mul;
1286   if (MergeIntoAddendOp) {
1287     AddendOp = II.getOperand(1);
1288     Mul = II.getOperand(2);
1289   } else {
1290     AddendOp = II.getOperand(2);
1291     Mul = II.getOperand(1);
1292   }
1293 
1294   if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
1295                                       m_Value(MulOp1))))
1296     return std::nullopt;
1297 
1298   if (!Mul->hasOneUse())
1299     return std::nullopt;
1300 
1301   Instruction *FMFSource = nullptr;
1302   if (II.getType()->isFPOrFPVectorTy()) {
1303     llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
1304     // Stop the combine when the flags on the inputs differ in case dropping
1305     // flags would lead to us missing out on more beneficial optimizations.
1306     if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
1307       return std::nullopt;
1308     if (!FAddFlags.allowContract())
1309       return std::nullopt;
1310     FMFSource = &II;
1311   }
1312 
1313   CallInst *Res;
1314   if (MergeIntoAddendOp)
1315     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1316                                      {P, AddendOp, MulOp0, MulOp1}, FMFSource);
1317   else
1318     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
1319                                      {P, MulOp0, MulOp1, AddendOp}, FMFSource);
1320 
1321   return IC.replaceInstUsesWith(II, Res);
1322 }
1323 
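// Lower sve.ld1 to a plain load when its governing predicate is known to be
// all active, and to a generic llvm.masked.load (with a zeroinitializer
// passthru) otherwise.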
1324 static std::optional<Instruction *>
1325 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1326   Value *Pred = II.getOperand(0);
1327   Value *PtrOp = II.getOperand(1);
1328   Type *VecTy = II.getType();
1329 
1330   if (isAllActivePredicate(Pred)) {
1331     LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
1332     Load->copyMetadata(II);
1333     return IC.replaceInstUsesWith(II, Load);
1334   }
1335 
1336   CallInst *MaskedLoad =
1337       IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
1338                                   Pred, ConstantAggregateZero::get(VecTy));
1339   MaskedLoad->copyMetadata(II);
1340   return IC.replaceInstUsesWith(II, MaskedLoad);
1341 }
1342 
1343 static std::optional<Instruction *>
1344 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
1345   Value *VecOp = II.getOperand(0);
1346   Value *Pred = II.getOperand(1);
1347   Value *PtrOp = II.getOperand(2);
1348 
1349   if (isAllActivePredicate(Pred)) {
1350     StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
1351     Store->copyMetadata(II);
1352     return IC.eraseInstFromFunction(II);
1353   }
1354 
1355   CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
1356       VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
1357   MaskedStore->copyMetadata(II);
1358   return IC.eraseInstFromFunction(II);
1359 }
1360 
1361 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
1362   switch (Intrinsic) {
1363   case Intrinsic::aarch64_sve_fmul_u:
1364     return Instruction::BinaryOps::FMul;
1365   case Intrinsic::aarch64_sve_fadd_u:
1366     return Instruction::BinaryOps::FAdd;
1367   case Intrinsic::aarch64_sve_fsub_u:
1368     return Instruction::BinaryOps::FSub;
1369   default:
1370     return Instruction::BinaryOpsEnd;
1371   }
1372 }
1373 
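// Replace an unpredicated (_u) floating-point intrinsic whose governing
// predicate is ptrue(all) with the corresponding plain IR binary operator,
// carrying over the call's fast-math flags. Illustrative sketch:
//   @llvm.aarch64.sve.fadd.u(ptrue(all), %a, %b)  -->  fadd %a, %b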
1374 static std::optional<Instruction *>
1375 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
1376   // Bail due to missing support for ISD::STRICT_ scalable vector operations.
1377   if (II.isStrictFP())
1378     return std::nullopt;
1379 
1380   auto *OpPredicate = II.getOperand(0);
1381   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
1382   if (BinOpCode == Instruction::BinaryOpsEnd ||
1383       !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1384                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1385     return std::nullopt;
1386   IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
1387   IC.Builder.setFastMathFlags(II.getFastMathFlags());
1388   auto BinOp =
1389       IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
1390   return IC.replaceInstUsesWith(II, BinOp);
1391 }
1392 
1393 // Canonicalise operations that take an all active predicate (e.g. sve.add ->
1394 // sve.add_u).
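// An illustrative sketch: with an all-active predicate the predicated and
// unpredicated forms agree on every lane, so a call such as
//   @llvm.aarch64.sve.add(ptrue(all), %a, %b)
// is re-targeted in place to @llvm.aarch64.sve.add.u with the same operands.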
1395 static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
1396                                                             Intrinsic::ID IID) {
1397   auto *OpPredicate = II.getOperand(0);
1398   if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
1399                               m_ConstantInt<AArch64SVEPredPattern::all>())))
1400     return std::nullopt;
1401 
1402   auto *Mod = II.getModule();
1403   auto *NewDecl = Intrinsic::getDeclaration(Mod, IID, {II.getType()});
1404   II.setCalledFunction(NewDecl);
1405 
1406   return &II;
1407 }
1408 
1409 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
1410                                                             IntrinsicInst &II) {
1411   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_add_u))
1412     return II_U;
1413   if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1414                                                    Intrinsic::aarch64_sve_mla>(
1415           IC, II, true))
1416     return MLA;
1417   if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1418                                                    Intrinsic::aarch64_sve_mad>(
1419           IC, II, false))
1420     return MAD;
1421   return std::nullopt;
1422 }
1423 
1424 static std::optional<Instruction *>
1425 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
1426   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fadd_u))
1427     return II_U;
1428   if (auto FMLA =
1429           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1430                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1431                                                                          true))
1432     return FMLA;
1433   if (auto FMAD =
1434           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1435                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1436                                                                          false))
1437     return FMAD;
1438   if (auto FMLA =
1439           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1440                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1441                                                                          true))
1442     return FMLA;
1443   return std::nullopt;
1444 }
1445 
1446 static std::optional<Instruction *>
1447 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
1448   if (auto FMLA =
1449           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1450                                             Intrinsic::aarch64_sve_fmla>(IC, II,
1451                                                                          true))
1452     return FMLA;
1453   if (auto FMAD =
1454           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1455                                             Intrinsic::aarch64_sve_fmad>(IC, II,
1456                                                                          false))
1457     return FMAD;
1458   if (auto FMLA_U =
1459           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1460                                             Intrinsic::aarch64_sve_fmla_u>(
1461               IC, II, true))
1462     return FMLA_U;
1463   return instCombineSVEVectorBinOp(IC, II);
1464 }
1465 
1466 static std::optional<Instruction *>
1467 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
1468   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fsub_u))
1469     return II_U;
1470   if (auto FMLS =
1471           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1472                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1473                                                                          true))
1474     return FMLS;
1475   if (auto FMSB =
1476           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1477                                             Intrinsic::aarch64_sve_fnmsb>(
1478               IC, II, false))
1479     return FMSB;
1480   if (auto FMLS =
1481           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1482                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1483                                                                          true))
1484     return FMLS;
1485   return std::nullopt;
1486 }
1487 
1488 static std::optional<Instruction *>
1489 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
1490   if (auto FMLS =
1491           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1492                                             Intrinsic::aarch64_sve_fmls>(IC, II,
1493                                                                          true))
1494     return FMLS;
1495   if (auto FMSB =
1496           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
1497                                             Intrinsic::aarch64_sve_fnmsb>(
1498               IC, II, false))
1499     return FMSB;
1500   if (auto FMLS_U =
1501           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
1502                                             Intrinsic::aarch64_sve_fmls_u>(
1503               IC, II, true))
1504     return FMLS_U;
1505   return instCombineSVEVectorBinOp(IC, II);
1506 }
1507 
1508 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
1509                                                             IntrinsicInst &II) {
1510   if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sub_u))
1511     return II_U;
1512   if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
1513                                                    Intrinsic::aarch64_sve_mls>(
1514           IC, II, true))
1515     return MLS;
1516   return std::nullopt;
1517 }
1518 
1519 static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
1520                                                             IntrinsicInst &II,
1521                                                             Intrinsic::ID IID) {
1522   auto *OpPredicate = II.getOperand(0);
1523   auto *OpMultiplicand = II.getOperand(1);
1524   auto *OpMultiplier = II.getOperand(2);
1525 
1526   // Canonicalise a non-_u intrinsic only.
1527   if (II.getIntrinsicID() != IID)
1528     if (auto II_U = instCombineSVEAllActive(II, IID))
1529       return II_U;
1530 
1531   // Return true if a given value is a splat of 1 or 1.0, false otherwise.
1532   auto IsUnitSplat = [](auto *I) {
1533     auto *SplatValue = getSplatValue(I);
1534     if (!SplatValue)
1535       return false;
1536     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1537   };
1538 
1539   // Return true if a given instruction is an aarch64_sve_dup intrinsic call
1540   // with a unit splat value, false otherwise.
1541   auto IsUnitDup = [](auto *I) {
1542     auto *IntrI = dyn_cast<IntrinsicInst>(I);
1543     if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
1544       return false;
1545 
1546     auto *SplatValue = IntrI->getOperand(2);
1547     return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
1548   };
1549 
1550   if (IsUnitSplat(OpMultiplier)) {
1551     // [f]mul pg %n, (dupx 1) => %n
1552     OpMultiplicand->takeName(&II);
1553     return IC.replaceInstUsesWith(II, OpMultiplicand);
1554   } else if (IsUnitDup(OpMultiplier)) {
1555     // [f]mul pg %n, (dup pg 1) => %n
1556     auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
1557     auto *DupPg = DupInst->getOperand(1);
1558     // TODO: this is naive. The optimization is still valid if DupPg
1559     // 'encompasses' OpPredicate, not only if they're the same predicate.
1560     if (OpPredicate == DupPg) {
1561       OpMultiplicand->takeName(&II);
1562       return IC.replaceInstUsesWith(II, OpMultiplicand);
1563     }
1564   }
1565 
1566   return instCombineSVEVectorBinOp(IC, II);
1567 }
1568 
1569 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
1570                                                          IntrinsicInst &II) {
1571   Value *UnpackArg = II.getArgOperand(0);
1572   auto *RetTy = cast<ScalableVectorType>(II.getType());
1573   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
1574                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
1575 
1576   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
1577   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
1578   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
1579     ScalarArg =
1580         IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
1581     Value *NewVal =
1582         IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
1583     NewVal->takeName(&II);
1584     return IC.replaceInstUsesWith(II, NewVal);
1585   }
1586 
1587   return std::nullopt;
1588 }
1589 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
1590                                                       IntrinsicInst &II) {
1591   auto *OpVal = II.getOperand(0);
1592   auto *OpIndices = II.getOperand(1);
1593   VectorType *VTy = cast<VectorType>(II.getType());
1594 
1595   // Check whether OpIndices is a constant splat value < minimal element count
1596   // of result.
1597   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
1598   if (!SplatValue ||
1599       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
1600     return std::nullopt;
1601 
1602   // Convert sve_tbl(OpVal, sve_dup_x(SplatValue)) to
1603   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
1604   auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
1605   auto *VectorSplat =
1606       IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
1607 
1608   VectorSplat->takeName(&II);
1609   return IC.replaceInstUsesWith(II, VectorSplat);
1610 }
1611 
1612 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
1613                                                       IntrinsicInst &II) {
1614   // zip1(uzp1(A, B), uzp2(A, B)) --> A
1615   // zip2(uzp1(A, B), uzp2(A, B)) --> B
1616   Value *A, *B;
1617   if (match(II.getArgOperand(0),
1618             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
1619       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
1620                                      m_Specific(A), m_Specific(B))))
1621     return IC.replaceInstUsesWith(
1622         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
1623 
1624   return std::nullopt;
1625 }
1626 
1627 static std::optional<Instruction *>
1628 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
1629   Value *Mask = II.getOperand(0);
1630   Value *BasePtr = II.getOperand(1);
1631   Value *Index = II.getOperand(2);
1632   Type *Ty = II.getType();
1633   Value *PassThru = ConstantAggregateZero::get(Ty);
1634 
1635   // Contiguous gather => masked load.
1636   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
1637   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
1638   Value *IndexBase;
1639   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1640                        m_Value(IndexBase), m_SpecificInt(1)))) {
1641     Align Alignment =
1642         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1643 
1644     Type *VecPtrTy = PointerType::getUnqual(Ty);
1645     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1646                                       BasePtr, IndexBase);
1647     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1648     CallInst *MaskedLoad =
1649         IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
1650     MaskedLoad->takeName(&II);
1651     return IC.replaceInstUsesWith(II, MaskedLoad);
1652   }
1653 
1654   return std::nullopt;
1655 }
1656 
1657 static std::optional<Instruction *>
1658 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
1659   Value *Val = II.getOperand(0);
1660   Value *Mask = II.getOperand(1);
1661   Value *BasePtr = II.getOperand(2);
1662   Value *Index = II.getOperand(3);
1663   Type *Ty = Val->getType();
1664 
1665   // Contiguous scatter => masked store.
1666   // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
1667   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
1668   Value *IndexBase;
1669   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
1670                        m_Value(IndexBase), m_SpecificInt(1)))) {
1671     Align Alignment =
1672         BasePtr->getPointerAlignment(II.getModule()->getDataLayout());
1673 
1674     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
1675                                       BasePtr, IndexBase);
1676     Type *VecPtrTy = PointerType::getUnqual(Ty);
1677     Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy);
1678 
1679     (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
1680 
1681     return IC.eraseInstFromFunction(II);
1682   }
1683 
1684   return std::nullopt;
1685 }
1686 
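// Fold a predicated signed divide by a power-of-two splat into sve.asrd
// (arithmetic shift right for divide); for a negated power of two the asrd
// result is additionally negated via sve.neg. Illustrative sketch:
//   sdiv(%pg, %x, splat(8))  -->  asrd(%pg, %x, 3)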
1687 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
1688                                                        IntrinsicInst &II) {
1689   Type *Int32Ty = IC.Builder.getInt32Ty();
1690   Value *Pred = II.getOperand(0);
1691   Value *Vec = II.getOperand(1);
1692   Value *DivVec = II.getOperand(2);
1693 
1694   Value *SplatValue = getSplatValue(DivVec);
1695   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
1696   if (!SplatConstantInt)
1697     return std::nullopt;
1698   APInt Divisor = SplatConstantInt->getValue();
1699 
1700   if (Divisor.isPowerOf2()) {
1701     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1702     auto ASRD = IC.Builder.CreateIntrinsic(
1703         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1704     return IC.replaceInstUsesWith(II, ASRD);
1705   }
1706   if (Divisor.isNegatedPowerOf2()) {
1707     Divisor.negate();
1708     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
1709     auto ASRD = IC.Builder.CreateIntrinsic(
1710         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
1711     auto NEG = IC.Builder.CreateIntrinsic(
1712         Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
1713     return IC.replaceInstUsesWith(II, NEG);
1714   }
1715 
1716   return std::nullopt;
1717 }
1718 
1719 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
1720   size_t VecSize = Vec.size();
1721   if (VecSize == 1)
1722     return true;
1723   if (!isPowerOf2_64(VecSize))
1724     return false;
1725   size_t HalfVecSize = VecSize / 2;
1726 
1727   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
1728        RHS != Vec.end(); LHS++, RHS++) {
1729     if (*LHS != nullptr && *RHS != nullptr) {
1730       if (*LHS == *RHS)
1731         continue;
1732       else
1733         return false;
1734     }
1735     if (!AllowPoison)
1736       return false;
1737     if (*LHS == nullptr && *RHS != nullptr)
1738       *LHS = *RHS;
1739   }
1740 
1741   Vec.resize(HalfVecSize);
1742   SimplifyValuePattern(Vec, AllowPoison);
1743   return true;
1744 }
1745 
1746 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
1747 // to dupqlane(f64(C)) where C is A concatenated with B
1748 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
1749                                                            IntrinsicInst &II) {
1750   Value *CurrentInsertElt = nullptr, *Default = nullptr;
1751   if (!match(II.getOperand(0),
1752              m_Intrinsic<Intrinsic::vector_insert>(
1753                  m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
1754       !isa<FixedVectorType>(CurrentInsertElt->getType()))
1755     return std::nullopt;
1756   auto IIScalableTy = cast<ScalableVectorType>(II.getType());
1757 
1758   // Insert the scalars into a container ordered by InsertElement index
1759   SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
1760   while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
1761     auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
1762     Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
1763     CurrentInsertElt = InsertElt->getOperand(0);
1764   }
1765 
1766   bool AllowPoison =
1767       isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
1768   if (!SimplifyValuePattern(Elts, AllowPoison))
1769     return std::nullopt;
1770 
1771   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
1772   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
1773   for (size_t I = 0; I < Elts.size(); I++) {
1774     if (Elts[I] == nullptr)
1775       continue;
1776     InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
1777                                                     IC.Builder.getInt64(I));
1778   }
1779   if (InsertEltChain == nullptr)
1780     return std::nullopt;
1781 
1782   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
1783   // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
1784   // be bitcast to a type wide enough to fit the sequence, be splatted, and then
1785   // be narrowed back to the original type.
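  // A worked example of the computation below (illustrative): for a
  // <vscale x 8 x half> result whose simplified sequence has 4 elements,
  // PatternWidth = 16 * 4 = 64 and PatternElementCount = 16 * 8 / 64 = 2,
  // so the sequence is splatted through a <vscale x 2 x i64> shuffle before
  // being bitcast back to the original type.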
1786   unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
1787   unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
1788                                  IIScalableTy->getMinNumElements() /
1789                                  PatternWidth;
1790 
1791   IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
1792   auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
1793   auto *WideShuffleMaskTy =
1794       ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);
1795 
1796   auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0));
1797   auto InsertSubvector = IC.Builder.CreateInsertVector(
1798       II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx);
1799   auto WideBitcast =
1800       IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
1801   auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
1802   auto WideShuffle = IC.Builder.CreateShuffleVector(
1803       WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
1804   auto NarrowBitcast =
1805       IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
1806 
1807   return IC.replaceInstUsesWith(II, NarrowBitcast);
1808 }
1809 
1810 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
1811                                                         IntrinsicInst &II) {
1812   Value *A = II.getArgOperand(0);
1813   Value *B = II.getArgOperand(1);
1814   if (A == B)
1815     return IC.replaceInstUsesWith(II, A);
1816 
1817   return std::nullopt;
1818 }
1819 
1820 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
1821                                                         IntrinsicInst &II) {
1822   Value *Pred = II.getOperand(0);
1823   Value *Vec = II.getOperand(1);
1824   Value *Shift = II.getOperand(2);
1825 
1826   // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
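  // Illustrative sketch of the fold (operand names are placeholders):
  //   %abs = @llvm.aarch64.sve.abs(%merge, %abs_pg, %x)
  //   %r   = @llvm.aarch64.sve.srshl(%pg, %abs, splat(2))
  // becomes, once the checks below pass,
  //   %r   = @llvm.aarch64.sve.lsl(%pg, %abs, splat(2))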
1827   Value *AbsPred, *MergedValue;
1828   if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
1829                       m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
1830       !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
1831                       m_Value(MergedValue), m_Value(AbsPred), m_Value())))
1832 
1833     return std::nullopt;
1834 
1835   // Transform is valid if any of the following are true:
1836   // * The ABS merge value is an undef or non-negative
1837   // * The ABS predicate is all active
1838   // * The ABS predicate and the SRSHL predicates are the same
1839   if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
1840       AbsPred != Pred && !isAllActivePredicate(AbsPred))
1841     return std::nullopt;
1842 
1843   // Only valid when the shift amount is non-negative, otherwise the rounding
1844   // behaviour of SRSHL cannot be ignored.
1845   if (!match(Shift, m_NonNegative()))
1846     return std::nullopt;
1847 
1848   auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
1849                                         {II.getType()}, {Pred, Vec, Shift});
1850 
1851   return IC.replaceInstUsesWith(II, LSL);
1852 }
1853 
1854 std::optional<Instruction *>
1855 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
1856                                      IntrinsicInst &II) const {
1857   Intrinsic::ID IID = II.getIntrinsicID();
1858   switch (IID) {
1859   default:
1860     break;
1861   case Intrinsic::aarch64_neon_fmaxnm:
1862   case Intrinsic::aarch64_neon_fminnm:
1863     return instCombineMaxMinNM(IC, II);
1864   case Intrinsic::aarch64_sve_convert_from_svbool:
1865     return instCombineConvertFromSVBool(IC, II);
1866   case Intrinsic::aarch64_sve_dup:
1867     return instCombineSVEDup(IC, II);
1868   case Intrinsic::aarch64_sve_dup_x:
1869     return instCombineSVEDupX(IC, II);
1870   case Intrinsic::aarch64_sve_cmpne:
1871   case Intrinsic::aarch64_sve_cmpne_wide:
1872     return instCombineSVECmpNE(IC, II);
1873   case Intrinsic::aarch64_sve_rdffr:
1874     return instCombineRDFFR(IC, II);
1875   case Intrinsic::aarch64_sve_lasta:
1876   case Intrinsic::aarch64_sve_lastb:
1877     return instCombineSVELast(IC, II);
1878   case Intrinsic::aarch64_sve_clasta_n:
1879   case Intrinsic::aarch64_sve_clastb_n:
1880     return instCombineSVECondLast(IC, II);
1881   case Intrinsic::aarch64_sve_cntd:
1882     return instCombineSVECntElts(IC, II, 2);
1883   case Intrinsic::aarch64_sve_cntw:
1884     return instCombineSVECntElts(IC, II, 4);
1885   case Intrinsic::aarch64_sve_cnth:
1886     return instCombineSVECntElts(IC, II, 8);
1887   case Intrinsic::aarch64_sve_cntb:
1888     return instCombineSVECntElts(IC, II, 16);
1889   case Intrinsic::aarch64_sve_ptest_any:
1890   case Intrinsic::aarch64_sve_ptest_first:
1891   case Intrinsic::aarch64_sve_ptest_last:
1892     return instCombineSVEPTest(IC, II);
1893   case Intrinsic::aarch64_sve_fabd:
1894     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fabd_u);
1895   case Intrinsic::aarch64_sve_fadd:
1896     return instCombineSVEVectorFAdd(IC, II);
1897   case Intrinsic::aarch64_sve_fadd_u:
1898     return instCombineSVEVectorFAddU(IC, II);
1899   case Intrinsic::aarch64_sve_fdiv:
1900     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fdiv_u);
1901   case Intrinsic::aarch64_sve_fmax:
1902     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmax_u);
1903   case Intrinsic::aarch64_sve_fmaxnm:
1904     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmaxnm_u);
1905   case Intrinsic::aarch64_sve_fmin:
1906     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmin_u);
1907   case Intrinsic::aarch64_sve_fminnm:
1908     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fminnm_u);
1909   case Intrinsic::aarch64_sve_fmla:
1910     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmla_u);
1911   case Intrinsic::aarch64_sve_fmls:
1912     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmls_u);
1913   case Intrinsic::aarch64_sve_fmul:
1914   case Intrinsic::aarch64_sve_fmul_u:
1915     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u);
1916   case Intrinsic::aarch64_sve_fmulx:
1917     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmulx_u);
1918   case Intrinsic::aarch64_sve_fnmla:
1919     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmla_u);
1920   case Intrinsic::aarch64_sve_fnmls:
1921     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmls_u);
1922   case Intrinsic::aarch64_sve_fsub:
1923     return instCombineSVEVectorFSub(IC, II);
1924   case Intrinsic::aarch64_sve_fsub_u:
1925     return instCombineSVEVectorFSubU(IC, II);
1926   case Intrinsic::aarch64_sve_add:
1927     return instCombineSVEVectorAdd(IC, II);
1928   case Intrinsic::aarch64_sve_add_u:
1929     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1930                                              Intrinsic::aarch64_sve_mla_u>(
1931         IC, II, true);
1932   case Intrinsic::aarch64_sve_mla:
1933     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mla_u);
1934   case Intrinsic::aarch64_sve_mls:
1935     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mls_u);
1936   case Intrinsic::aarch64_sve_mul:
1937   case Intrinsic::aarch64_sve_mul_u:
1938     return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u);
1939   case Intrinsic::aarch64_sve_sabd:
1940     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sabd_u);
1941   case Intrinsic::aarch64_sve_smax:
1942     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smax_u);
1943   case Intrinsic::aarch64_sve_smin:
1944     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smin_u);
1945   case Intrinsic::aarch64_sve_smulh:
1946     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smulh_u);
1947   case Intrinsic::aarch64_sve_sub:
1948     return instCombineSVEVectorSub(IC, II);
1949   case Intrinsic::aarch64_sve_sub_u:
1950     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
1951                                              Intrinsic::aarch64_sve_mls_u>(
1952         IC, II, true);
1953   case Intrinsic::aarch64_sve_uabd:
1954     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uabd_u);
1955   case Intrinsic::aarch64_sve_umax:
1956     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umax_u);
1957   case Intrinsic::aarch64_sve_umin:
1958     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umin_u);
1959   case Intrinsic::aarch64_sve_umulh:
1960     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umulh_u);
1961   case Intrinsic::aarch64_sve_asr:
1962     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_asr_u);
1963   case Intrinsic::aarch64_sve_lsl:
1964     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsl_u);
1965   case Intrinsic::aarch64_sve_lsr:
1966     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsr_u);
1967   case Intrinsic::aarch64_sve_and:
1968     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_and_u);
1969   case Intrinsic::aarch64_sve_bic:
1970     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_bic_u);
1971   case Intrinsic::aarch64_sve_eor:
1972     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_eor_u);
1973   case Intrinsic::aarch64_sve_orr:
1974     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_orr_u);
1975   case Intrinsic::aarch64_sve_sqsub:
1976     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sqsub_u);
1977   case Intrinsic::aarch64_sve_uqsub:
1978     return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uqsub_u);
1979   case Intrinsic::aarch64_sve_tbl:
1980     return instCombineSVETBL(IC, II);
1981   case Intrinsic::aarch64_sve_uunpkhi:
1982   case Intrinsic::aarch64_sve_uunpklo:
1983   case Intrinsic::aarch64_sve_sunpkhi:
1984   case Intrinsic::aarch64_sve_sunpklo:
1985     return instCombineSVEUnpack(IC, II);
1986   case Intrinsic::aarch64_sve_zip1:
1987   case Intrinsic::aarch64_sve_zip2:
1988     return instCombineSVEZip(IC, II);
1989   case Intrinsic::aarch64_sve_ld1_gather_index:
1990     return instCombineLD1GatherIndex(IC, II);
1991   case Intrinsic::aarch64_sve_st1_scatter_index:
1992     return instCombineST1ScatterIndex(IC, II);
1993   case Intrinsic::aarch64_sve_ld1:
1994     return instCombineSVELD1(IC, II, DL);
1995   case Intrinsic::aarch64_sve_st1:
1996     return instCombineSVEST1(IC, II, DL);
1997   case Intrinsic::aarch64_sve_sdiv:
1998     return instCombineSVESDIV(IC, II);
1999   case Intrinsic::aarch64_sve_sel:
2000     return instCombineSVESel(IC, II);
2001   case Intrinsic::aarch64_sve_srshl:
2002     return instCombineSVESrshl(IC, II);
2003   case Intrinsic::aarch64_sve_dupq_lane:
2004     return instCombineSVEDupqLane(IC, II);
2005   }
2006 
2007   return std::nullopt;
2008 }
2009 
2010 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
2011     InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
2012     APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
2013     std::function<void(Instruction *, unsigned, APInt, APInt &)>
2014         SimplifyAndSetOp) const {
2015   switch (II.getIntrinsicID()) {
2016   default:
2017     break;
2018   case Intrinsic::aarch64_neon_fcvtxn:
2019   case Intrinsic::aarch64_neon_rshrn:
2020   case Intrinsic::aarch64_neon_sqrshrn:
2021   case Intrinsic::aarch64_neon_sqrshrun:
2022   case Intrinsic::aarch64_neon_sqshrn:
2023   case Intrinsic::aarch64_neon_sqshrun:
2024   case Intrinsic::aarch64_neon_sqxtn:
2025   case Intrinsic::aarch64_neon_sqxtun:
2026   case Intrinsic::aarch64_neon_uqrshrn:
2027   case Intrinsic::aarch64_neon_uqshrn:
2028   case Intrinsic::aarch64_neon_uqxtn:
2029     SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
2030     break;
2031   }
2032 
2033   return std::nullopt;
2034 }
2035 
2036 TypeSize
2037 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
2038   switch (K) {
2039   case TargetTransformInfo::RGK_Scalar:
2040     return TypeSize::getFixed(64);
2041   case TargetTransformInfo::RGK_FixedWidthVector:
2042     if (!ST->isNeonAvailable() && !EnableFixedwidthAutovecInStreamingMode)
2043       return TypeSize::getFixed(0);
2044 
2045     if (ST->hasSVE())
2046       return TypeSize::getFixed(
2047           std::max(ST->getMinSVEVectorSizeInBits(), 128u));
2048 
2049     return TypeSize::getFixed(ST->hasNEON() ? 128 : 0);
2050   case TargetTransformInfo::RGK_ScalableVector:
2051     if (!ST->isSVEAvailable() && !EnableScalableAutovecInStreamingMode)
2052       return TypeSize::getScalable(0);
2053 
2054     return TypeSize::getScalable(ST->hasSVE() ? 128 : 0);
2055   }
2056   llvm_unreachable("Unsupported register kind");
2057 }
2058 
2059 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
2060                                            ArrayRef<const Value *> Args,
2061                                            Type *SrcOverrideTy) {
2062   // A helper that returns a vector type whose scalar type is taken from the
2063   // given type and whose element count is taken from DstTy.
2064   auto toVectorTy = [&](Type *ArgTy) {
2065     return VectorType::get(ArgTy->getScalarType(),
2066                            cast<VectorType>(DstTy)->getElementCount());
2067   };
2068 
2069   // Exit early if DstTy is not a vector type whose elements are one of [i16,
2070   // i32, i64]. SVE doesn't generally have the same set of instructions to
2071   // perform an extend with the add/sub/mul. There are SMULLB style
2072   // instructions, but they operate on top/bottom, requiring some sort of lane
2073   // interleaving to be used with zext/sext.
2074   unsigned DstEltSize = DstTy->getScalarSizeInBits();
2075   if (!useNeonVector(DstTy) || Args.size() != 2 ||
2076       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
2077     return false;
2078 
2079   // Determine if the operation has a widening variant. We consider both the
2080   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2081   // instructions.
2082   //
2083   // TODO: Add additional widening operations (e.g., shl, etc.) once we
2084   //       verify that their extending operands are eliminated during code
2085   //       generation.
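  // Illustrative IR shapes (not exhaustive) and the instructions they map to:
  //   add %a, (zext <8 x i8> %b to <8 x i16>)                       -> uaddw
  //   add (zext <8 x i8> %a ...), (zext <8 x i8> %b to <8 x i16>)   -> uaddl
  //   mul (sext <4 x i16> %a ...), (sext <4 x i16> %b to <4 x i32>) -> smull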
2086   Type *SrcTy = SrcOverrideTy;
2087   switch (Opcode) {
2088   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2089   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2090     // The second operand needs to be an extend.
2091     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
2092       if (!SrcTy)
2093         SrcTy =
2094             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2095     } else
2096       return false;
2097     break;
2098   case Instruction::Mul: { // SMULL(2), UMULL(2)
2099     // Both operands need to be extends of the same type.
2100     if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2101         (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2102       if (!SrcTy)
2103         SrcTy =
2104             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2105     } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2106       // If one of the operands is a Zext and the other has enough zero bits to
2107       // be treated as unsigned, we can still generate a umull, meaning the zext
2108       // is free.
2109       KnownBits Known =
2110           computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2111       if (Args[0]->getType()->getScalarSizeInBits() -
2112               Known.Zero.countLeadingOnes() >
2113           DstTy->getScalarSizeInBits() / 2)
2114         return false;
2115       if (!SrcTy)
2116         SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2117                                            DstTy->getScalarSizeInBits() / 2));
2118     } else
2119       return false;
2120     break;
2121   }
2122   default:
2123     return false;
2124   }
2125 
2126   // Legalize the destination type and ensure it can be used in a widening
2127   // operation.
2128   auto DstTyL = getTypeLegalizationCost(DstTy);
2129   if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
2130     return false;
2131 
2132   // Legalize the source type and ensure it can be used in a widening
2133   // operation.
2134   assert(SrcTy && "Expected some SrcTy");
2135   auto SrcTyL = getTypeLegalizationCost(SrcTy);
2136   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
2137   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
2138     return false;
2139 
2140   // Get the total number of vector elements in the legalized types.
2141   InstructionCost NumDstEls =
2142       DstTyL.first * DstTyL.second.getVectorMinNumElements();
2143   InstructionCost NumSrcEls =
2144       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
2145 
2146   // Return true if the legalized types have the same number of vector elements
2147   // and the destination element type size is twice that of the source type.
2148   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
2149 }
2150 
2151 // s/urhadd instructions implement the following pattern, making the
2152 // extends free:
2153 //   %x = add ((zext i8 -> i16), 1)
2154 //   %y = (zext i8 -> i16)
2155 //   trunc i16 (lshr (add %x, %y), 1) -> i8
2156 //
2157 bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
2158                                         Type *Src) {
2159   // The source should be a legal vector type.
2160   if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) ||
2161       (Src->isScalableTy() && !ST->hasSVE2()))
2162     return false;
2163 
2164   if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse())
2165     return false;
2166 
2167   // Look for trunc/shl/add before trying to match the pattern.
2168   const Instruction *Add = ExtUser;
2169   auto *AddUser =
2170       dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2171   if (AddUser && AddUser->getOpcode() == Instruction::Add)
2172     Add = AddUser;
2173 
2174   auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
2175   if (!Shr || Shr->getOpcode() != Instruction::LShr)
2176     return false;
2177 
2178   auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser());
2179   if (!Trunc || Trunc->getOpcode() != Instruction::Trunc ||
2180       Src->getScalarSizeInBits() !=
2181           cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits())
2182     return false;
2183 
2184   // Try to match the whole pattern. Ext could be either the first or second
2185   // m_ZExtOrSExt matched.
2186   Instruction *Ex1, *Ex2;
2187   if (!(match(Add, m_c_Add(m_Instruction(Ex1),
2188                            m_c_Add(m_Instruction(Ex2), m_SpecificInt(1))))))
2189     return false;
2190 
2191   // Ensure both are the same kind of extend (both sext or both zext).
2192   if (match(Ex1, m_ZExtOrSExt(m_Value())) &&
2193       Ex1->getOpcode() == Ex2->getOpcode())
2194     return true;
2195 
2196   return false;
2197 }
2198 
2199 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
2200                                                  Type *Src,
2201                                                  TTI::CastContextHint CCH,
2202                                                  TTI::TargetCostKind CostKind,
2203                                                  const Instruction *I) {
2204   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2205   assert(ISD && "Invalid opcode");
2206   // If the cast is observable, and it is used by a widening instruction (e.g.,
2207   // uaddl, saddw, etc.), it may be free.
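  // For example (an illustrative case rather than an exhaustive rule), in
  //   %e = zext <8 x i8> %b to <8 x i16>
  //   %r = add <8 x i16> %a, %e
  // the extend folds into a single uaddw, so the zext is costed as free.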
2208   if (I && I->hasOneUser()) {
2209     auto *SingleUser = cast<Instruction>(*I->user_begin());
2210     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
2211     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
2212       // For adds only count the second operand as free if both operands are
2213       // extends but not the same operation. (i.e. both operands are not free in
2214       // add(sext, zext)).
2215       if (SingleUser->getOpcode() == Instruction::Add) {
2216         if (I == SingleUser->getOperand(1) ||
2217             (isa<CastInst>(SingleUser->getOperand(1)) &&
2218              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
2219           return 0;
2220       } else // Others are free so long as isWideningInstruction returned true.
2221         return 0;
2222     }
2223 
2224     // The cast will be free for the s/urhadd instructions
2225     if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
2226         isExtPartOfAvgExpr(SingleUser, Dst, Src))
2227       return 0;
2228   }
2229 
2230   // TODO: Allow non-throughput costs that aren't binary.
2231   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
2232     if (CostKind != TTI::TCK_RecipThroughput)
2233       return Cost == 0 ? 0 : 1;
2234     return Cost;
2235   };
2236 
2237   EVT SrcTy = TLI->getValueType(DL, Src);
2238   EVT DstTy = TLI->getValueType(DL, Dst);
2239 
2240   if (!SrcTy.isSimple() || !DstTy.isSimple())
2241     return AdjustCost(
2242         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2243 
2244   static const TypeConversionCostTblEntry
2245   ConversionTbl[] = {
2246     { ISD::TRUNCATE, MVT::v2i8,   MVT::v2i64,  1},  // xtn
2247     { ISD::TRUNCATE, MVT::v2i16,  MVT::v2i64,  1},  // xtn
2248     { ISD::TRUNCATE, MVT::v2i32,  MVT::v2i64,  1},  // xtn
2249     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i32,  1},  // xtn
2250     { ISD::TRUNCATE, MVT::v4i8,   MVT::v4i64,  3},  // 2 xtn + 1 uzp1
2251     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i32,  1},  // xtn
2252     { ISD::TRUNCATE, MVT::v4i16,  MVT::v4i64,  2},  // 1 uzp1 + 1 xtn
2253     { ISD::TRUNCATE, MVT::v4i32,  MVT::v4i64,  1},  // 1 uzp1
2254     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i16,  1},  // 1 xtn
2255     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i32,  2},  // 1 uzp1 + 1 xtn
2256     { ISD::TRUNCATE, MVT::v8i8,   MVT::v8i64,  4},  // 3 x uzp1 + xtn
2257     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i32,  1},  // 1 uzp1
2258     { ISD::TRUNCATE, MVT::v8i16,  MVT::v8i64,  3},  // 3 x uzp1
2259     { ISD::TRUNCATE, MVT::v8i32,  MVT::v8i64,  2},  // 2 x uzp1
2260     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i16, 1},  // uzp1
2261     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i32, 3},  // (2 + 1) x uzp1
2262     { ISD::TRUNCATE, MVT::v16i8,  MVT::v16i64, 7},  // (4 + 2 + 1) x uzp1
2263     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2},  // 2 x uzp1
2264     { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6},  // (4 + 2) x uzp1
2265     { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4},  // 4 x uzp1
2266 
2267     // Truncations on nxvmiN
2268     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 },
2269     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 },
2270     { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 },
2271     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 },
2272     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 },
2273     { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 },
2274     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 },
2275     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 },
2276     { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 },
2277     { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 },
2278     { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 },
2279     { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 },
2280     { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 },
2281     { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 },
2282     { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 },
2283     { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 },
2284 
2285     // The number of shll instructions for the extension.
2286     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2287     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i16, 3 },
2288     { ISD::SIGN_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2289     { ISD::ZERO_EXTEND, MVT::v4i64,  MVT::v4i32, 2 },
2290     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2291     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i8,  3 },
2292     { ISD::SIGN_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2293     { ISD::ZERO_EXTEND, MVT::v8i32,  MVT::v8i16, 2 },
2294     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2295     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i8,  7 },
2296     { ISD::SIGN_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2297     { ISD::ZERO_EXTEND, MVT::v8i64,  MVT::v8i16, 6 },
2298     { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2299     { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 },
2300     { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2301     { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 },
2302 
2303     // LowerVectorINT_TO_FP:
2304     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2305     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2306     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2307     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 },
2308     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 },
2309     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 },
2310 
2311     // Complex: to v2f32
2312     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2313     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2314     { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2315     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8,  3 },
2316     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 },
2317     { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 },
2318 
2319     // Complex: to v4f32
2320     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8,  4 },
2321     { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2322     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8,  3 },
2323     { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 },
2324 
2325     // Complex: to v8f32
2326     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2327     { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2328     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8,  10 },
2329     { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 },
2330 
2331     // Complex: to v16f32
2332     { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2333     { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 },
2334 
2335     // Complex: to v2f64
2336     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2337     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2338     { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2339     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8,  4 },
2340     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 },
2341     { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 },
2342 
2343     // Complex: to v4f64
2344     { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2345     { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32,  4 },
2346 
2347     // LowerVectorFP_TO_INT
2348     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 },
2349     { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 },
2350     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 },
2351     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 },
2352     { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 },
2353     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 },
2354 
2355     // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
2356     { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 },
2357     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 },
2358     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f32, 1 },
2359     { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 },
2360     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 },
2361     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f32, 1 },
2362 
2363     // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
2364     { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 },
2365     { ISD::FP_TO_SINT, MVT::v4i8,  MVT::v4f32, 2 },
2366     { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 },
2367     { ISD::FP_TO_UINT, MVT::v4i8,  MVT::v4f32, 2 },
2368 
2369     // Complex, from nxv2f32.
2370     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2371     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2372     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2373     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2374     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 },
2375     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 },
2376     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 },
2377     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f32, 1 },
2378 
2379     // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
2380     { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 },
2381     { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 },
2382     { ISD::FP_TO_SINT, MVT::v2i8,  MVT::v2f64, 2 },
2383     { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 },
2384     { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 },
2385     { ISD::FP_TO_UINT, MVT::v2i8,  MVT::v2f64, 2 },
2386 
2387     // Complex, from nxv2f64.
2388     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2389     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2390     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2391     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2392     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 },
2393     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 },
2394     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 },
2395     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f64, 1 },
2396 
2397     // Complex, from nxv4f32.
2398     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2399     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2400     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2401     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2402     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 },
2403     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 },
2404     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 },
2405     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f32, 1 },
2406 
2407     // Complex, from nxv8f64. Illegal -> illegal conversions not required.
2408     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2409     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2410     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 },
2411     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f64, 7 },
2412 
2413     // Complex, from nxv4f64. Illegal -> illegal conversions not required.
2414     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2415     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2416     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2417     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 },
2418     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 },
2419     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f64, 3 },
2420 
2421     // Complex, from nxv8f32. Illegal -> illegal conversions not required.
2422     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2423     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2424     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 },
2425     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f32, 3 },
2426 
2427     // Complex, from nxv8f16.
2428     { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2429     { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2430     { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2431     { ISD::FP_TO_SINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2432     { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 },
2433     { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 },
2434     { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 },
2435     { ISD::FP_TO_UINT, MVT::nxv8i8,  MVT::nxv8f16, 1 },
2436 
2437     // Complex, from nxv4f16.
2438     { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2439     { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2440     { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2441     { ISD::FP_TO_SINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2442     { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 },
2443     { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 },
2444     { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 },
2445     { ISD::FP_TO_UINT, MVT::nxv4i8,  MVT::nxv4f16, 1 },
2446 
2447     // Complex, from nxv2f16.
2448     { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2449     { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2450     { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2451     { ISD::FP_TO_SINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2452     { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 },
2453     { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 },
2454     { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 },
2455     { ISD::FP_TO_UINT, MVT::nxv2i8,  MVT::nxv2f16, 1 },
2456 
2457     // Truncate from nxvmf32 to nxvmf16.
2458     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 },
2459     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 },
2460     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 },
2461 
2462     // Truncate from nxvmf64 to nxvmf16.
2463     { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 },
2464     { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 },
2465     { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 },
2466 
2467     // Truncate from nxvmf64 to nxvmf32.
2468     { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 },
2469     { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 },
2470     { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 },
2471 
2472     // Extend from nxvmf16 to nxvmf32.
2473     { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
2474     { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
2475     { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
2476 
2477     // Extend from nxvmf16 to nxvmf64.
2478     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
2479     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
2480     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
2481 
2482     // Extend from nxvmf32 to nxvmf64.
2483     { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
2484     { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
2485     { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
2486 
2487     // Bitcasts from float to integer
2488     { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 },
2489     { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 },
2490     { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 },
2491 
2492     // Bitcasts from integer to float
2493     { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 },
2494     { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 },
2495     { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 },
2496 
2497     // Add cost for extending to illegal (too wide) scalable vectors.
2498     // Zero/sign extends are implemented by multiple unpack operations,
2499     // where each operation has a cost of 1.
2500     { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2501     { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2502     { ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2503     { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2504     { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2505     { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2506 
2507     { ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
2508     { ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
2509     { ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
2510     { ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
2511     { ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
2512     { ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
2513   };
2514 
2515   // Estimate the cost of a fixed-length cast performed via SVE registers by
2516   // scaling the cost of the equivalent scalable cast by the number of SVE
2517   // registers required to hold the fixed-length type.
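  // For instance (an illustrative sketch of the arithmetic): if the wider
  // fixed-length type legalizes to two registers of 16-bit elements,
  // LT.first == 2 and NumElements == 128 / 16 == 8, so the returned cost is
  // twice the cost of the same cast on the corresponding <vscale x 8 x ...>
  // types.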
2518   EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
2519   if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
2520       SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
2521       ST->useSVEForFixedLengthVectors(WiderTy)) {
2522     std::pair<InstructionCost, MVT> LT =
2523         getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
2524     unsigned NumElements = AArch64::SVEBitsPerBlock /
2525                            LT.second.getVectorElementType().getSizeInBits();
2526     return AdjustCost(
2527         LT.first *
2528         getCastInstrCost(
2529             Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
2530             ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
2531             CostKind, I));
2532   }
2533 
2534   if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD,
2535                                                  DstTy.getSimpleVT(),
2536                                                  SrcTy.getSimpleVT()))
2537     return AdjustCost(Entry->Cost);
2538 
2539   static const TypeConversionCostTblEntry FP16Tbl[] = {
2540       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
2541       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
2542       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
2543       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
2544       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
2545       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
2546       {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
2547       {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
2548       {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
2549       {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
2550       {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
2551       {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
2552       {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
2553       {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
2554       {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
2555       {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
2556       {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
2557       {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
2558       {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // ushll + ucvtf
2559       {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // sshll + scvtf
2560       {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
2561       {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
2562   };
2563 
2564   if (ST->hasFullFP16())
2565     if (const auto *Entry = ConvertCostTableLookup(
2566             FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
2567       return AdjustCost(Entry->Cost);
2568 
2569   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
2570       CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() &&
2571       TLI->getTypeAction(Src->getContext(), SrcTy) ==
2572           TargetLowering::TypePromoteInteger &&
2573       TLI->getTypeAction(Dst->getContext(), DstTy) ==
2574           TargetLowering::TypeSplitVector) {
2575     // The standard behaviour in the backend for these cases is to split the
2576     // extend up into two parts:
2577     //  1. Perform an extending load or masked load up to the legal type.
2578     //  2. Extend the loaded data to the final type.
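    // For example (an illustrative assumption): zero-extending a masked load
    // of <vscale x 4 x i8> to <vscale x 4 x i64> is costed as an extending
    // load up to the promoted <vscale x 4 x i32> plus an extend from
    // <vscale x 4 x i32> to <vscale x 4 x i64>.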
2579     std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
2580     Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext());
2581     InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost(
2582         Opcode, LegalTy, Src, CCH, CostKind, I);
2583     InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost(
2584         Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I);
2585     return Part1 + Part2;
2586   }
2587 
2588   // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
2589   // but we also want to include the TTI::CastContextHint::Masked case.
2590   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
2591       CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() &&
2592       TLI->isTypeLegal(DstTy))
2593     CCH = TTI::CastContextHint::Normal;
2594 
2595   return AdjustCost(
2596       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
2597 }
2598 
2599 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode,
2600                                                          Type *Dst,
2601                                                          VectorType *VecTy,
2602                                                          unsigned Index) {
2603 
2604   // Make sure we were given a valid extend opcode.
2605   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
2606          "Invalid opcode");
2607 
2608   // We are extending an element we extract from a vector, so the source type
2609   // of the extend is the element type of the vector.
2610   auto *Src = VecTy->getElementType();
2611 
2612   // Sign- and zero-extends are for integer types only.
2613   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
2614 
2615   // Get the cost for the extract. We compute the cost (if any) for the extend
2616   // below.
2617   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
2618   InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
2619                                             CostKind, Index, nullptr, nullptr);
2620 
2621   // Legalize the types.
2622   auto VecLT = getTypeLegalizationCost(VecTy);
2623   auto DstVT = TLI->getValueType(DL, Dst);
2624   auto SrcVT = TLI->getValueType(DL, Src);
2625 
2626   // If the resulting type is still a vector and the destination type is legal,
2627   // we may get the extension for free. If not, get the default cost for the
2628   // extend.
2629   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
2630     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2631                                    CostKind);
2632 
2633   // The destination type should be larger than the element type. If not, get
2634   // the default cost for the extend.
2635   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
2636     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2637                                    CostKind);
2638 
2639   switch (Opcode) {
2640   default:
2641     llvm_unreachable("Opcode should be either SExt or ZExt");
2642 
2643   // For sign-extends, we only need a smov, which performs the extension
2644   // automatically.
2645   case Instruction::SExt:
2646     return Cost;
2647 
2648   // For zero-extends, the extend is performed automatically by a umov unless
2649   // the destination type is i64 and the element type is i8 or i16.
2650   case Instruction::ZExt:
2651     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
2652       return Cost;
2653   }
2654 
2655   // If we are unable to perform the extend for free, get the default cost.
2656   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
2657                                  CostKind);
2658 }
2659 
2660 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
2661                                                TTI::TargetCostKind CostKind,
2662                                                const Instruction *I) {
2663   if (CostKind != TTI::TCK_RecipThroughput)
2664     return Opcode == Instruction::PHI ? 0 : 1;
2665   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
2666   // Branches are assumed to be predicted.
2667   return 0;
2668 }
2669 
2670 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I,
2671                                                          Type *Val,
2672                                                          unsigned Index,
2673                                                          bool HasRealUse) {
2674   assert(Val->isVectorTy() && "This must be a vector type");
2675 
2676   if (Index != -1U) {
2677     // Legalize the type.
2678     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
2679 
2680     // This type is legalized to a scalar type.
2681     if (!LT.second.isVector())
2682       return 0;
2683 
2684     // The type may be split. For fixed-width vectors we can normalize the
2685     // index to the new type.
2686     if (LT.second.isFixedLengthVector()) {
2687       unsigned Width = LT.second.getVectorNumElements();
2688       Index = Index % Width;
2689     }
2690 
2691     // The element at index zero is already inside the vector.
2692     // - For a physical (HasRealUse==true) insert-element or extract-element
2693     // instruction that extracts integers, an explicit FPR -> GPR move is
2694     // needed. So it has non-zero cost.
2695     // - For all other cases (a virtual instruction, or a floating-point
2696     // element type), consider the instruction free.
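    // For example, extracting lane 0 of a <4 x i32> into a general-purpose
    // register still needs an FPR -> GPR move, whereas extracting lane 0 of a
    // <4 x float> stays in an FPR and is treated as free.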
2697     if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy()))
2698       return 0;
2699 
2700     // This is recognising an LD1 (load one single-element structure to one
2701     // lane of one register) instruction. I.e., if this is an `insertelement`
2702     // instruction, and its second operand is a load, then we will generate a
2703     // LD1, which is an expensive instruction.
2704     if (I && isa<LoadInst>(I->getOperand(1)))
2705       return ST->getVectorInsertExtractBaseCost() + 1;
2706 
2707     // i1 inserts and extracts will include an extra cset or cmp of the vector
2708     // value. Increase the cost by 1 to account for this.
2709     if (Val->getScalarSizeInBits() == 1)
2710       return ST->getVectorInsertExtractBaseCost() + 1;
2711 
2712     // FIXME:
2713     // If the extract-element and insert-element instructions could be
2714     // simplified away (e.g., could be combined into users by looking at use-def
2715     // context), they have no cost. This is not done in the first place for
2716     // compile-time considerations.
2717   }
2718 
2719   // All other insert/extracts cost this much.
2720   return ST->getVectorInsertExtractBaseCost();
2721 }
2722 
2723 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
2724                                                    TTI::TargetCostKind CostKind,
2725                                                    unsigned Index, Value *Op0,
2726                                                    Value *Op1) {
2727   bool HasRealUse =
2728       Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
2729   return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
2730 }
2731 
2732 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
2733                                                    Type *Val,
2734                                                    TTI::TargetCostKind CostKind,
2735                                                    unsigned Index) {
2736   return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
2737 }
2738 
2739 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
2740     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
2741     TTI::TargetCostKind CostKind) {
2742   if (isa<ScalableVectorType>(Ty))
2743     return InstructionCost::getInvalid();
2744   if (Ty->getElementType()->isFloatingPointTy())
2745     return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
2746                                            CostKind);
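  // Otherwise each demanded element is costed as one insert and/or one
  // extract at the subtarget's base vector insert/extract cost.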
2747   return DemandedElts.popcount() * (Insert + Extract) *
2748          ST->getVectorInsertExtractBaseCost();
2749 }
2750 
2751 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
2752     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
2753     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
2754     ArrayRef<const Value *> Args,
2755     const Instruction *CxtI) {
2756 
2757   // TODO: Handle more cost kinds.
2758   if (CostKind != TTI::TCK_RecipThroughput)
2759     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2760                                          Op2Info, Args, CxtI);
2761 
2762   // Legalize the type.
2763   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
2764   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2765 
2766   switch (ISD) {
2767   default:
2768     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2769                                          Op2Info);
2770   case ISD::SDIV:
2771     if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
2772       // On AArch64, scalar signed division by a power-of-two constant is
2773       // normally expanded to the sequence ADD + CMP + SELECT + SRA.
2774       // The OperandValue properties may not be the same as those of the
2775       // previous operation; conservatively assume OP_None.
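      // The CMP in the sequence is approximated below by the cost of a Sub.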
2776       InstructionCost Cost = getArithmeticInstrCost(
2777           Instruction::Add, Ty, CostKind,
2778           Op1Info.getNoProps(), Op2Info.getNoProps());
2779       Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
2780                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2781       Cost += getArithmeticInstrCost(
2782           Instruction::Select, Ty, CostKind,
2783           Op1Info.getNoProps(), Op2Info.getNoProps());
2784       Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
2785                                      Op1Info.getNoProps(), Op2Info.getNoProps());
2786       return Cost;
2787     }
2788     [[fallthrough]];
2789   case ISD::UDIV: {
2790     if (Op2Info.isConstant() && Op2Info.isUniform()) {
2791       auto VT = TLI->getValueType(DL, Ty);
2792       if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
2793         // Vector signed division by a constant is expanded to the sequence
2794         // MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division to
2795         // MULHU + SUB + SRL + ADD + SRL.
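        // Both sequences are costed below as two multiplies, two adds and two
        // shifts, plus 1.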
2796         InstructionCost MulCost = getArithmeticInstrCost(
2797             Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2798         InstructionCost AddCost = getArithmeticInstrCost(
2799             Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2800         InstructionCost ShrCost = getArithmeticInstrCost(
2801             Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps());
2802         return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1;
2803       }
2804     }
2805 
2806     InstructionCost Cost = BaseT::getArithmeticInstrCost(
2807         Opcode, Ty, CostKind, Op1Info, Op2Info);
2808     if (Ty->isVectorTy()) {
2809       if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
2810         // If SDIV/UDIV operations are lowered using SVE, then we can have
2811         // lower costs.
2812         if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty)
2813                                                 ->getPrimitiveSizeInBits()
2814                                                 .getFixedValue() < 128) {
2815           EVT VT = TLI->getValueType(DL, Ty);
2816           static const CostTblEntry DivTbl[]{
2817               {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
2818               {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
2819               {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
2820               {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
2821               {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
2822               {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
2823 
2824           const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
2825           if (nullptr != Entry)
2826             return Entry->Cost;
2827         }
2828         // For 8/16-bit elements, the cost is higher because the type
2829         // requires promotion and possibly splitting:
2830         if (LT.second.getScalarType() == MVT::i8)
2831           Cost *= 8;
2832         else if (LT.second.getScalarType() == MVT::i16)
2833           Cost *= 4;
2834         return Cost;
2835       } else {
2836         // If one of the operands is a uniform constant then the cost for each
2837         // element is the cost of insertion, extraction and division.
2838         // Insertion cost = 2, extraction cost = 2, division = cost of the
2839         // operation on the scalar type.
2840         if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
2841             (Op2Info.isConstant() && Op2Info.isUniform())) {
2842           if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
2843             InstructionCost DivCost = BaseT::getArithmeticInstrCost(
2844                 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
2845             return (4 + DivCost) * VTy->getNumElements();
2846           }
2847         }
2848         // On AArch64, without SVE, vector divisions are expanded
2849         // into scalar divisions of each pair of elements.
2850         Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty,
2851                                        CostKind, Op1Info, Op2Info);
2852         Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind,
2853                                        Op1Info, Op2Info);
2854       }
2855 
2856       // TODO: if one of the arguments is scalar, then it's not necessary to
2857       // double the cost of handling the vector elements.
2858       Cost += Cost;
2859     }
2860     return Cost;
2861   }
2862   case ISD::MUL:
2863     // When SVE is available, we can lower the v2i64 operation using
2864     // the SVE mul instruction, which has a lower cost.
2865     if (LT.second == MVT::v2i64 && ST->hasSVE())
2866       return LT.first;
2867 
2868     // When SVE is not available, there is no MUL.2d instruction,
2869     // which means mul <2 x i64> is expensive as elements are extracted
2870     // from the vectors and the muls scalarized.
2871     // As getScalarizationOverhead is a bit too pessimistic, we
2872     // estimate the cost for a i64 vector directly here, which is:
2873     // estimate the cost for an i64 vector directly here, which is:
2874     // - two 2-cost i64 inserts, and
2875     // - two 1-cost muls.
2876     // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
2877     // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
2878     // need to scalarize so the cost can be cheaper (smull or umull).
2880     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2881       return LT.first;
2882     return LT.first * 14;
2883   case ISD::ADD:
2884   case ISD::XOR:
2885   case ISD::OR:
2886   case ISD::AND:
2887   case ISD::SRL:
2888   case ISD::SRA:
2889   case ISD::SHL:
2890     // These nodes are marked as 'custom' for combining purposes only.
2891     // We know that they are legal. See LowerAdd in ISelLowering.
2892     return LT.first;
2893 
2894   case ISD::FNEG:
2895   case ISD::FADD:
2896   case ISD::FSUB:
2897     // Increase the cost for half and bfloat types if not architecturally
2898     // supported.
2899     if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
2900         (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
2901       return 2 * LT.first;
2902     if (!Ty->getScalarType()->isFP128Ty())
2903       return LT.first;
2904     [[fallthrough]];
2905   case ISD::FMUL:
2906   case ISD::FDIV:
2907     // These nodes are marked as 'custom' just to lower them to SVE.
2908     // We know said lowering will incur no additional cost.
2909     if (!Ty->getScalarType()->isFP128Ty())
2910       return 2 * LT.first;
2911 
2912     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2913                                          Op2Info);
2914   }
2915 }
2916 
2917 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
2918                                                           ScalarEvolution *SE,
2919                                                           const SCEV *Ptr) {
2920   // Address computations in vectorized code with non-consecutive addresses will
2921   // likely result in more instructions compared to scalar code where the
2922   // computation can more often be merged into the index mode. The resulting
2923   // extra micro-ops can significantly decrease throughput.
2924   unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
2925   int MaxMergeDistance = 64;
2926 
2927   if (Ty->isVectorTy() && SE &&
2928       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
2929     return NumVectorInstToHideOverhead;
2930 
2931   // In many cases the address computation is not merged into the instruction
2932   // addressing mode.
2933   return 1;
2934 }
2935 
2936 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2937                                                    Type *CondTy,
2938                                                    CmpInst::Predicate VecPred,
2939                                                    TTI::TargetCostKind CostKind,
2940                                                    const Instruction *I) {
2941   // TODO: Handle other cost kinds.
2942   if (CostKind != TTI::TCK_RecipThroughput)
2943     return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2944                                      I);
2945 
2946   int ISD = TLI->InstructionOpcodeToISD(Opcode);
2947   // Some vector selects that are wider than the register width are not
2948   // lowered well.
2949   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
2950     // We would need this many instructions to hide the scalarization happening.
2951     const int AmortizationCost = 20;
2952 
2953     // If VecPred is not set, check if we can get a predicate from the context
2954     // instruction, if its type matches the requested ValTy.
2955     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
2956       CmpInst::Predicate CurrentPred;
2957       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
2958                             m_Value())))
2959         VecPred = CurrentPred;
2960     }
2961     // Check if we have a compare/select chain that can be lowered using
2962     // a (F)CMxx & BFI pair.
2963     if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2964         VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2965         VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2966         VecPred == CmpInst::FCMP_UNE) {
2967       static const auto ValidMinMaxTys = {
2968           MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2969           MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2970       static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2971 
2972       auto LT = getTypeLegalizationCost(ValTy);
2973       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2974           (ST->hasFullFP16() &&
2975            any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
2976         return LT.first;
2977     }
2978 
2979     static const TypeConversionCostTblEntry
2980     VectorSelectTbl[] = {
2981       { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
2982       { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
2983       { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
2984       { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
2985       { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
2986       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
2987       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
2988       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
2989       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
2990       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
2991       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
2992     };
2993 
2994     EVT SelCondTy = TLI->getValueType(DL, CondTy);
2995     EVT SelValTy = TLI->getValueType(DL, ValTy);
2996     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
2997       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
2998                                                      SelCondTy.getSimpleVT(),
2999                                                      SelValTy.getSimpleVT()))
3000         return Entry->Cost;
3001     }
3002   }
3003 
3004   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
3005     auto LT = getTypeLegalizationCost(ValTy);
3006     // Cost v4f16 FCmp without FP16 support via converting to v4f32 and back.
3007     if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
3008       return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
3009   }
3010 
3011   // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
3012   // FIXME: This can apply to more conditions and add/sub if it can be shown to
3013   // be profitable.
3014   if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
3015       ICmpInst::isEquality(VecPred) &&
3016       TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
3017       match(I->getOperand(1), m_Zero()) &&
3018       match(I->getOperand(0), m_And(m_Value(), m_Value())))
3019     return 0;
3020 
3021   // The base case handles scalable vectors fine for now, since it treats the
3022   // cost as 1 * legalization cost.
3023   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
3024 }
3025 
3026 AArch64TTIImpl::TTI::MemCmpExpansionOptions
3027 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
3028   TTI::MemCmpExpansionOptions Options;
3029   if (ST->requiresStrictAlign()) {
3030     // TODO: Add cost modeling for strict align. Misaligned loads expand to
3031     // a bunch of instructions when strict align is enabled.
3032     return Options;
3033   }
3034   Options.AllowOverlappingLoads = true;
3035   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
3036   Options.NumLoadsPerBlock = Options.MaxNumLoads;
3037   // TODO: Though vector loads usually perform well on AArch64, on some targets
3038   // they may wake up the FP unit, which raises the power consumption. Perhaps
3039   // they could be used with no holds barred (-O3).
3040   Options.LoadSizes = {8, 4, 2, 1};
3041   Options.AllowedTailExpansions = {3, 5, 6};
3042   return Options;
3043 }
3044 
3045 bool AArch64TTIImpl::prefersVectorizedAddressing() const {
3046   return ST->hasSVE();
3047 }
3048 
3049 InstructionCost
3050 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
3051                                       Align Alignment, unsigned AddressSpace,
3052                                       TTI::TargetCostKind CostKind) {
3053   if (useNeonVector(Src))
3054     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
3055                                         CostKind);
3056   auto LT = getTypeLegalizationCost(Src);
3057   if (!LT.first.isValid())
3058     return InstructionCost::getInvalid();
3059 
3060   // The code-generator is currently not able to handle scalable vectors
3061   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3062   // it. This change will be removed when code-generation for these types is
3063   // sufficiently reliable.
3064   if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1))
3065     return InstructionCost::getInvalid();
3066 
3067   return LT.first;
3068 }
3069 
3070 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) {
3071   return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead;
3072 }
3073 
3074 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
3075     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
3076     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) {
3077   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
3078     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
3079                                          Alignment, CostKind, I);
3080   auto *VT = cast<VectorType>(DataTy);
3081   auto LT = getTypeLegalizationCost(DataTy);
3082   if (!LT.first.isValid())
3083     return InstructionCost::getInvalid();
3084 
3085   if (!LT.second.isVector() ||
3086       !isElementTypeLegalForScalableVector(VT->getElementType()))
3087     return InstructionCost::getInvalid();
3088 
3089   // The code-generator is currently not able to handle scalable vectors
3090   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3091   // it. This change will be removed when code-generation for these types is
3092   // sufficiently reliable.
3093   if (cast<VectorType>(DataTy)->getElementCount() ==
3094       ElementCount::getScalable(1))
3095     return InstructionCost::getInvalid();
3096 
3097   ElementCount LegalVF = LT.second.getVectorElementCount();
3098   InstructionCost MemOpCost =
3099       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
3100                       {TTI::OK_AnyValue, TTI::OP_None}, I);
3101   // Add on an overhead cost for using gathers/scatters.
3102   // TODO: At the moment this is applied unilaterally for all CPUs, but at some
3103   // point we may want a per-CPU overhead.
3104   MemOpCost *= getSVEGatherScatterOverhead(Opcode);
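  // The total cost is the per-element memory op cost (scaled by the
  // gather/scatter overhead above) multiplied by the maximum number of
  // elements of the legalized type, and by LT.first if the type is split.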
3105   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
3106 }
3107 
3108 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
3109   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
3110 }
3111 
3112 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
3113                                                 MaybeAlign Alignment,
3114                                                 unsigned AddressSpace,
3115                                                 TTI::TargetCostKind CostKind,
3116                                                 TTI::OperandValueInfo OpInfo,
3117                                                 const Instruction *I) {
3118   EVT VT = TLI->getValueType(DL, Ty, true);
3119   // Type legalization can't handle structs
3120   if (VT == MVT::Other)
3121     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
3122                                   CostKind);
3123 
3124   auto LT = getTypeLegalizationCost(Ty);
3125   if (!LT.first.isValid())
3126     return InstructionCost::getInvalid();
3127 
3128   // The code-generator is currently not able to handle scalable vectors
3129   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3130   // it. This change will be removed when code-generation for these types is
3131   // sufficiently reliable.
3132   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3133     if (VTy->getElementCount() == ElementCount::getScalable(1))
3134       return InstructionCost::getInvalid();
3135 
3136   // TODO: consider latency as well for TCK_SizeAndLatency.
3137   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
3138     return LT.first;
3139 
3140   if (CostKind != TTI::TCK_RecipThroughput)
3141     return 1;
3142 
3143   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
3144       LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
3145     // Unaligned stores are extremely inefficient. We don't split all
3146     // unaligned 128-bit stores because of the negative impact that has shown in
3147     // practice on inlined block copy code.
3148     // We make such stores expensive so that we will only vectorize if there
3149     // are 6 other instructions getting vectorized.
3150     const int AmortizationCost = 6;
3151 
3152     return LT.first * 2 * AmortizationCost;
3153   }
3154 
3155   // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
3156   if (Ty->isPtrOrPtrVectorTy())
3157     return LT.first;
3158 
3159   // Check truncating stores and extending loads.
3160   if (useNeonVector(Ty) &&
3161       Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
3162     // v4i8 types are lowered to a scalar load/store and sshll/xtn.
3163     if (VT == MVT::v4i8)
3164       return 2;
3165     // Otherwise we need to scalarize.
3166     return cast<FixedVectorType>(Ty)->getNumElements() * 2;
3167   }
3168 
3169   return LT.first;
3170 }
3171 
3172 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
3173     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
3174     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
3175     bool UseMaskForCond, bool UseMaskForGaps) {
3176   assert(Factor >= 2 && "Invalid interleave factor");
3177   auto *VecVTy = cast<VectorType>(VecTy);
3178 
3179   if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2))
3180     return InstructionCost::getInvalid();
3181 
3182   // Vectorization for masked interleaved accesses is only enabled for scalable
3183   // VF.
3184   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
3185     return InstructionCost::getInvalid();
3186 
3187   if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
3188     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
3189     auto *SubVecTy =
3190         VectorType::get(VecVTy->getElementType(),
3191                         VecVTy->getElementCount().divideCoefficientBy(Factor));
3192 
3193     // ldN/stN only support legal vector types of size 64 or 128 in bits.
3194     // Accesses having vector types that are a multiple of 128 bits can be
3195     // matched to more than one ldN/stN instruction.
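    // For example (illustrative): a fixed <16 x i32> access with Factor == 2
    // uses <8 x i32> sub-vectors, each spanning two 128-bit registers, so the
    // cost below is Factor * 2 ldN/stN accesses.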
3196     bool UseScalable;
3197     if (MinElts % Factor == 0 &&
3198         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
3199       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
3200   }
3201 
3202   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
3203                                            Alignment, AddressSpace, CostKind,
3204                                            UseMaskForCond, UseMaskForGaps);
3205 }
3206 
3207 InstructionCost
3208 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) {
3209   InstructionCost Cost = 0;
3210   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3211   for (auto *I : Tys) {
3212     if (!I->isVectorTy())
3213       continue;
3214     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
3215         128)
3216       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
3217               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
3218   }
3219   return Cost;
3220 }
3221 
3222 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) {
3223   return ST->getMaxInterleaveFactor();
3224 }
3225 
3226 // For Falkor, we want to avoid having too many strided loads in a loop since
3227 // that can exhaust the HW prefetcher resources.  We adjust the unroller
3228 // MaxCount preference below to attempt to ensure unrolling doesn't create too
3229 // many strided loads.
3230 static void
3231 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3232                               TargetTransformInfo::UnrollingPreferences &UP) {
3233   enum { MaxStridedLoads = 7 };
3234   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
3235     int StridedLoads = 0;
3236     // FIXME? We could make this more precise by looking at the CFG and
3237     // e.g. not counting loads in each side of an if-then-else diamond.
3238     for (const auto BB : L->blocks()) {
3239       for (auto &I : *BB) {
3240         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
3241         if (!LMemI)
3242           continue;
3243 
3244         Value *PtrValue = LMemI->getPointerOperand();
3245         if (L->isLoopInvariant(PtrValue))
3246           continue;
3247 
3248         const SCEV *LSCEV = SE.getSCEV(PtrValue);
3249         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
3250         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
3251           continue;
3252 
3253         // FIXME? We could take pairing of unrolled load copies into account
3254         // by looking at the AddRec, but we would probably have to limit this
3255         // to loops with no stores or other memory optimization barriers.
3256         ++StridedLoads;
3257         // We've seen enough strided loads that seeing more won't make a
3258         // difference.
3259         if (StridedLoads > MaxStridedLoads / 2)
3260           return StridedLoads;
3261       }
3262     }
3263     return StridedLoads;
3264   };
3265 
3266   int StridedLoads = countStridedLoads(L, SE);
3267   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
3268                     << " strided loads\n");
3269   // Pick the largest power of 2 unroll count that won't result in too many
3270   // strided loads.
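  // For example, with MaxStridedLoads == 7, one detected strided load gives
  // MaxCount = 1 << Log2_32(7) = 4, and two give MaxCount = 1 << Log2_32(3) = 2.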
3271   if (StridedLoads) {
3272     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
3273     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
3274                       << UP.MaxCount << '\n');
3275   }
3276 }
3277 
3278 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
3279                                              TTI::UnrollingPreferences &UP,
3280                                              OptimizationRemarkEmitter *ORE) {
3281   // Enable partial unrolling and runtime unrolling.
3282   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
3283 
3284   UP.UpperBound = true;
3285 
3286   // An inner loop is more likely to be hot, and the runtime check can be
3287   // hoisted out by the LICM pass, so the overhead is lower; try a larger
3288   // threshold to unroll more loops.
3289   if (L->getLoopDepth() > 1)
3290     UP.PartialThreshold *= 2;
3291 
3292   // Disable partial & runtime unrolling on -Os.
3293   UP.PartialOptSizeThreshold = 0;
3294 
3295   if (ST->getProcFamily() == AArch64Subtarget::Falkor &&
3296       EnableFalkorHWPFUnrollFix)
3297     getFalkorUnrollingPreferences(L, SE, UP);
3298 
3299   // Scan the loop: don't unroll loops with calls as this could prevent
3300   // inlining. Don't unroll vector loops either, as they don't benefit much from
3301   // unrolling.
3302   for (auto *BB : L->getBlocks()) {
3303     for (auto &I : *BB) {
3304       // Don't unroll vectorised loops.
3305       if (I.getType()->isVectorTy())
3306         return;
3307 
3308       if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
3309         if (const Function *F = cast<CallBase>(I).getCalledFunction()) {
3310           if (!isLoweredToCall(F))
3311             continue;
3312         }
3313         return;
3314       }
3315     }
3316   }
3317 
3318   // Enable runtime unrolling for in-order models.
3319   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
3320   // checking for that case, we can ensure that the default behaviour is
3321   // unchanged.
3322   if (ST->getProcFamily() != AArch64Subtarget::Others &&
3323       !ST->getSchedModel().isOutOfOrder()) {
3324     UP.Runtime = true;
3325     UP.Partial = true;
3326     UP.UnrollRemainder = true;
3327     UP.DefaultUnrollRuntimeCount = 4;
3328 
3329     UP.UnrollAndJam = true;
3330     UP.UnrollAndJamInnerLoopThreshold = 60;
3331   }
3332 }
3333 
3334 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
3335                                            TTI::PeelingPreferences &PP) {
3336   BaseT::getPeelingPreferences(L, SE, PP);
3337 }
3338 
3339 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
3340                                                          Type *ExpectedType) {
3341   switch (Inst->getIntrinsicID()) {
3342   default:
3343     return nullptr;
3344   case Intrinsic::aarch64_neon_st2:
3345   case Intrinsic::aarch64_neon_st3:
3346   case Intrinsic::aarch64_neon_st4: {
3347     // Create a struct type
3348     StructType *ST = dyn_cast<StructType>(ExpectedType);
3349     if (!ST)
3350       return nullptr;
3351     unsigned NumElts = Inst->arg_size() - 1;
3352     if (ST->getNumElements() != NumElts)
3353       return nullptr;
3354     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3355       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
3356         return nullptr;
3357     }
3358     Value *Res = PoisonValue::get(ExpectedType);
3359     IRBuilder<> Builder(Inst);
3360     for (unsigned i = 0, e = NumElts; i != e; ++i) {
3361       Value *L = Inst->getArgOperand(i);
3362       Res = Builder.CreateInsertValue(Res, L, i);
3363     }
3364     return Res;
3365   }
3366   case Intrinsic::aarch64_neon_ld2:
3367   case Intrinsic::aarch64_neon_ld3:
3368   case Intrinsic::aarch64_neon_ld4:
3369     if (Inst->getType() == ExpectedType)
3370       return Inst;
3371     return nullptr;
3372   }
3373 }
3374 
3375 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
3376                                         MemIntrinsicInfo &Info) {
3377   switch (Inst->getIntrinsicID()) {
3378   default:
3379     break;
3380   case Intrinsic::aarch64_neon_ld2:
3381   case Intrinsic::aarch64_neon_ld3:
3382   case Intrinsic::aarch64_neon_ld4:
3383     Info.ReadMem = true;
3384     Info.WriteMem = false;
3385     Info.PtrVal = Inst->getArgOperand(0);
3386     break;
3387   case Intrinsic::aarch64_neon_st2:
3388   case Intrinsic::aarch64_neon_st3:
3389   case Intrinsic::aarch64_neon_st4:
3390     Info.ReadMem = false;
3391     Info.WriteMem = true;
3392     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
3393     break;
3394   }
3395 
3396   switch (Inst->getIntrinsicID()) {
3397   default:
3398     return false;
3399   case Intrinsic::aarch64_neon_ld2:
3400   case Intrinsic::aarch64_neon_st2:
3401     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
3402     break;
3403   case Intrinsic::aarch64_neon_ld3:
3404   case Intrinsic::aarch64_neon_st3:
3405     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
3406     break;
3407   case Intrinsic::aarch64_neon_ld4:
3408   case Intrinsic::aarch64_neon_st4:
3409     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
3410     break;
3411   }
3412   return true;
3413 }
3414 
3415 /// See if \p I should be considered for address type promotion. We check if
3416 /// \p I is a sext with the right type that is used in memory accesses. If it
3417 /// is used in a "complex" getelementptr, we allow it to be promoted without
3418 /// finding other sext instructions that sign extended the same initial value.
3419 /// A getelementptr is considered "complex" if it has more than 2 operands.
3420 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
3421     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
3422   bool Considerable = false;
3423   AllowPromotionWithoutCommonHeader = false;
3424   if (!isa<SExtInst>(&I))
3425     return false;
3426   Type *ConsideredSExtType =
3427       Type::getInt64Ty(I.getParent()->getParent()->getContext());
3428   if (I.getType() != ConsideredSExtType)
3429     return false;
3430   // See if the sext is the one with the right type and used in at least one
3431   // GetElementPtrInst.
3432   for (const User *U : I.users()) {
3433     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
3434       Considerable = true;
3435       // A getelementptr is considered as "complex" if it has more than 2
3436       // operands. We will promote a SExt used in such a complex GEP, as we
3437       // expect some computation to be merged if it is done on 64 bits.
3438       if (GEPInst->getNumOperands() > 2) {
3439         AllowPromotionWithoutCommonHeader = true;
3440         break;
3441       }
3442     }
3443   }
3444   return Considerable;
3445 }
3446 
3447 bool AArch64TTIImpl::isLegalToVectorizeReduction(
3448     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
3449   if (!VF.isScalable())
3450     return true;
3451 
3452   Type *Ty = RdxDesc.getRecurrenceType();
3453   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
3454     return false;
3455 
3456   switch (RdxDesc.getRecurrenceKind()) {
3457   case RecurKind::Add:
3458   case RecurKind::FAdd:
3459   case RecurKind::And:
3460   case RecurKind::Or:
3461   case RecurKind::Xor:
3462   case RecurKind::SMin:
3463   case RecurKind::SMax:
3464   case RecurKind::UMin:
3465   case RecurKind::UMax:
3466   case RecurKind::FMin:
3467   case RecurKind::FMax:
3468   case RecurKind::FMulAdd:
3469   case RecurKind::IAnyOf:
3470   case RecurKind::FAnyOf:
3471     return true;
3472   default:
3473     return false;
3474   }
3475 }
3476 
3477 InstructionCost
3478 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
3479                                        FastMathFlags FMF,
3480                                        TTI::TargetCostKind CostKind) {
3481   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
3482 
3483   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
3484     return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);
3485 
3486   InstructionCost LegalizationCost = 0;
3487   if (LT.first > 1) {
3488     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
3489     IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
3490     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
3491   }
3492 
3493   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
3494 }
3495 
3496 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
3497     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) {
3498   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3499   InstructionCost LegalizationCost = 0;
3500   if (LT.first > 1) {
3501     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
3502     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
3503     LegalizationCost *= LT.first - 1;
3504   }
3505 
3506   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3507   assert(ISD && "Invalid opcode");
3508   // Add the final reduction cost for the legal horizontal reduction
3509   switch (ISD) {
3510   case ISD::ADD:
3511   case ISD::AND:
3512   case ISD::OR:
3513   case ISD::XOR:
3514   case ISD::FADD:
3515     return LegalizationCost + 2;
3516   default:
3517     return InstructionCost::getInvalid();
3518   }
3519 }
3520 
3521 InstructionCost
3522 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
3523                                            std::optional<FastMathFlags> FMF,
3524                                            TTI::TargetCostKind CostKind) {
3525   if (TTI::requiresOrderedReduction(FMF)) {
3526     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
3527       InstructionCost BaseCost =
3528           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3529       // Add on extra cost to reflect the extra overhead on some CPUs. We still
3530       // end up vectorizing for more computationally intensive loops.
3531       return BaseCost + FixedVTy->getNumElements();
3532     }
3533 
3534     if (Opcode != Instruction::FAdd)
3535       return InstructionCost::getInvalid();
3536 
3537     auto *VTy = cast<ScalableVectorType>(ValTy);
3538     InstructionCost Cost =
3539         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
3540     Cost *= getMaxNumElements(VTy->getElementCount());
3541     return Cost;
3542   }
3543 
3544   if (isa<ScalableVectorType>(ValTy))
3545     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
3546 
3547   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
3548   MVT MTy = LT.second;
3549   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3550   assert(ISD && "Invalid opcode");
3551 
3552   // Horizontal adds can use the 'addv' instruction. We model the cost of these
3553   // instructions as twice a normal vector add, plus 1 for each legalization
3554   // step (LT.first). This is the only arithmetic vector reduction operation for
3555   // which we have an instruction.
3556   // OR, XOR and AND costs should match the codegen from:
3557   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
3558   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
3559   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
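  // For example, an add reduction of v8i16 costs 2, while v16i16 is first
  // split into two v8i16 halves (LT.first == 2), costing (LT.first - 1) + 2 = 3.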
3560   static const CostTblEntry CostTblNoPairwise[]{
3561       {ISD::ADD, MVT::v8i8,   2},
3562       {ISD::ADD, MVT::v16i8,  2},
3563       {ISD::ADD, MVT::v4i16,  2},
3564       {ISD::ADD, MVT::v8i16,  2},
3565       {ISD::ADD, MVT::v4i32,  2},
3566       {ISD::ADD, MVT::v2i64,  2},
3567       {ISD::OR,  MVT::v8i8,  15},
3568       {ISD::OR,  MVT::v16i8, 17},
3569       {ISD::OR,  MVT::v4i16,  7},
3570       {ISD::OR,  MVT::v8i16,  9},
3571       {ISD::OR,  MVT::v2i32,  3},
3572       {ISD::OR,  MVT::v4i32,  5},
3573       {ISD::OR,  MVT::v2i64,  3},
3574       {ISD::XOR, MVT::v8i8,  15},
3575       {ISD::XOR, MVT::v16i8, 17},
3576       {ISD::XOR, MVT::v4i16,  7},
3577       {ISD::XOR, MVT::v8i16,  9},
3578       {ISD::XOR, MVT::v2i32,  3},
3579       {ISD::XOR, MVT::v4i32,  5},
3580       {ISD::XOR, MVT::v2i64,  3},
3581       {ISD::AND, MVT::v8i8,  15},
3582       {ISD::AND, MVT::v16i8, 17},
3583       {ISD::AND, MVT::v4i16,  7},
3584       {ISD::AND, MVT::v8i16,  9},
3585       {ISD::AND, MVT::v2i32,  3},
3586       {ISD::AND, MVT::v4i32,  5},
3587       {ISD::AND, MVT::v2i64,  3},
3588   };
3589   switch (ISD) {
3590   default:
3591     break;
3592   case ISD::ADD:
3593     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
3594       return (LT.first - 1) + Entry->Cost;
3595     break;
3596   case ISD::XOR:
3597   case ISD::AND:
3598   case ISD::OR:
3599     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
3600     if (!Entry)
3601       break;
3602     auto *ValVTy = cast<FixedVectorType>(ValTy);
3603     if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
3604         isPowerOf2_32(ValVTy->getNumElements())) {
3605       InstructionCost ExtraCost = 0;
3606       if (LT.first != 1) {
3607         // Type needs to be split, so there is an extra cost of LT.first - 1
3608         // arithmetic ops.
3609         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
3610                                         MTy.getVectorNumElements());
3611         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
3612         ExtraCost *= LT.first - 1;
3613       }
3614       // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
3615       auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
3616       return Cost + ExtraCost;
3617     }
3618     break;
3619   }
3620   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
3621 }
3622 
3623 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) {
3624   static const CostTblEntry ShuffleTbl[] = {
3625       { TTI::SK_Splice, MVT::nxv16i8,  1 },
3626       { TTI::SK_Splice, MVT::nxv8i16,  1 },
3627       { TTI::SK_Splice, MVT::nxv4i32,  1 },
3628       { TTI::SK_Splice, MVT::nxv2i64,  1 },
3629       { TTI::SK_Splice, MVT::nxv2f16,  1 },
3630       { TTI::SK_Splice, MVT::nxv4f16,  1 },
3631       { TTI::SK_Splice, MVT::nxv8f16,  1 },
3632       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
3633       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
3634       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
3635       { TTI::SK_Splice, MVT::nxv2f32,  1 },
3636       { TTI::SK_Splice, MVT::nxv4f32,  1 },
3637       { TTI::SK_Splice, MVT::nxv2f64,  1 },
3638   };
3639 
3640   // The code-generator is currently not able to handle scalable vectors
3641   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3642   // it. This change will be removed when code-generation for these types is
3643   // sufficiently reliable.
3644   if (Tp->getElementCount() == ElementCount::getScalable(1))
3645     return InstructionCost::getInvalid();
3646 
3647   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3648   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
3649   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
3650   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
3651                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
3652                        : LT.second;
3653   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
3654   InstructionCost LegalizationCost = 0;
3655   if (Index < 0) {
3656     LegalizationCost =
3657         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
3658                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
3659         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
3660                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
3661   }
3662 
3663   // Predicated splices are promoted when lowering (see AArch64ISelLowering.cpp),
3664   // so the cost is computed on the promoted type.
3665   if (LT.second.getScalarType() == MVT::i1) {
3666     LegalizationCost +=
3667         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
3668                          TTI::CastContextHint::None, CostKind) +
3669         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
3670                          TTI::CastContextHint::None, CostKind);
3671   }
3672   const auto *Entry =
3673       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
3674   assert(Entry && "Illegal Type for Splice");
3675   LegalizationCost += Entry->Cost;
3676   return LegalizationCost * LT.first;
3677 }
3678 
3679 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
3680                                                VectorType *Tp,
3681                                                ArrayRef<int> Mask,
3682                                                TTI::TargetCostKind CostKind,
3683                                                int Index, VectorType *SubTp,
3684                                                ArrayRef<const Value *> Args) {
3685   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
3686   // If we have a Mask, and the LT is being legalized somehow, split the Mask
3687   // into smaller vectors and sum the cost of each shuffle.
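  // For example (illustrative): an 8-element mask over a type that legalizes
  // to two 4-element vectors is split into two 4-element sub-masks, each of
  // which is costed below as its own single- or two-source shuffle.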
3688   if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() &&
3689       Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
3690       Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
3691     unsigned TpNumElts = Mask.size();
3692     unsigned LTNumElts = LT.second.getVectorNumElements();
3693     unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
3694     VectorType *NTp =
3695         VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount());
3696     InstructionCost Cost;
3697     for (unsigned N = 0; N < NumVecs; N++) {
3698       SmallVector<int> NMask;
3699       // Split the existing mask into chunks of size LTNumElts. Track the source
3700       // sub-vectors to ensure the result has at most 2 inputs.
3701       unsigned Source1, Source2;
3702       unsigned NumSources = 0;
3703       for (unsigned E = 0; E < LTNumElts; E++) {
3704         int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
3705                                                       : PoisonMaskElem;
3706         if (MaskElt < 0) {
3707           NMask.push_back(PoisonMaskElem);
3708           continue;
3709         }
3710 
3711         // Calculate which source from the input this comes from and whether it
3712         // is new to us.
3713         unsigned Source = MaskElt / LTNumElts;
3714         if (NumSources == 0) {
3715           Source1 = Source;
3716           NumSources = 1;
3717         } else if (NumSources == 1 && Source != Source1) {
3718           Source2 = Source;
3719           NumSources = 2;
3720         } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
3721           NumSources++;
3722         }
3723 
3724         // Add to the new mask. For the NumSources>2 case these are not correct,
3725         // but are only used for the modular lane number.
3726         if (Source == Source1)
3727           NMask.push_back(MaskElt % LTNumElts);
3728         else if (Source == Source2)
3729           NMask.push_back(MaskElt % LTNumElts + LTNumElts);
3730         else
3731           NMask.push_back(MaskElt % LTNumElts);
3732       }
3733       // If the sub-mask has at most 2 input sub-vectors then re-cost it using
3734       // getShuffleCost. If not then cost it using the worst case.
3735       if (NumSources <= 2)
3736         Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
3737                                                : TTI::SK_PermuteTwoSrc,
3738                                NTp, NMask, CostKind, 0, nullptr, Args);
3739       else if (any_of(enumerate(NMask), [&](const auto &ME) {
3740                  return ME.value() % LTNumElts == ME.index();
3741                }))
3742         Cost += LTNumElts - 1;
3743       else
3744         Cost += LTNumElts;
3745     }
3746     return Cost;
3747   }
3748 
3749   Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp);
3750 
3751   // Check for broadcast loads, which are supported by the LD1R instruction.
3752   // In terms of code-size, the shuffle vector is free when a load + dup get
3753   // folded into a LD1R. That's what we check and return here. For performance
3754   // and reciprocal throughput, a LD1R is not completely free. In this case, we
3755   // return the cost for the broadcast below (i.e. 1 for most/all types), so
3756   // that we model the load + dup sequence slightly higher because LD1R is a
3757   // high latency instruction.
3758   if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) {
3759     bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
3760     if (IsLoad && LT.second.isVector() &&
3761         isLegalBroadcastLoad(Tp->getElementType(),
3762                              LT.second.getVectorElementCount()))
3763       return 0;
3764   }
3765 
3766   // If we have 4 elements for the shuffle and a Mask, get the cost straight
3767   // from the perfect shuffle tables.
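       // (All mask indices must be < 8, i.e. select from the two 4-element
       // inputs, or be poison.)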
3768   if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) &&
3769       (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) &&
3770       all_of(Mask, [](int E) { return E < 8; }))
3771     return getPerfectShuffleCost(Mask);
3772 
3773   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
3774       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
3775       Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
3776     static const CostTblEntry ShuffleTbl[] = {
3777         // Broadcast shuffle kinds can be performed with 'dup'.
3778         {TTI::SK_Broadcast, MVT::v8i8, 1},
3779         {TTI::SK_Broadcast, MVT::v16i8, 1},
3780         {TTI::SK_Broadcast, MVT::v4i16, 1},
3781         {TTI::SK_Broadcast, MVT::v8i16, 1},
3782         {TTI::SK_Broadcast, MVT::v2i32, 1},
3783         {TTI::SK_Broadcast, MVT::v4i32, 1},
3784         {TTI::SK_Broadcast, MVT::v2i64, 1},
3785         {TTI::SK_Broadcast, MVT::v4f16, 1},
3786         {TTI::SK_Broadcast, MVT::v8f16, 1},
3787         {TTI::SK_Broadcast, MVT::v2f32, 1},
3788         {TTI::SK_Broadcast, MVT::v4f32, 1},
3789         {TTI::SK_Broadcast, MVT::v2f64, 1},
3790         // Transpose shuffle kinds can be performed with 'trn1/trn2' and
3791         // 'zip1/zip2' instructions.
3792         {TTI::SK_Transpose, MVT::v8i8, 1},
3793         {TTI::SK_Transpose, MVT::v16i8, 1},
3794         {TTI::SK_Transpose, MVT::v4i16, 1},
3795         {TTI::SK_Transpose, MVT::v8i16, 1},
3796         {TTI::SK_Transpose, MVT::v2i32, 1},
3797         {TTI::SK_Transpose, MVT::v4i32, 1},
3798         {TTI::SK_Transpose, MVT::v2i64, 1},
3799         {TTI::SK_Transpose, MVT::v4f16, 1},
3800         {TTI::SK_Transpose, MVT::v8f16, 1},
3801         {TTI::SK_Transpose, MVT::v2f32, 1},
3802         {TTI::SK_Transpose, MVT::v4f32, 1},
3803         {TTI::SK_Transpose, MVT::v2f64, 1},
3804         // Select shuffle kinds.
3805         // TODO: handle vXi8/vXi16.
3806         {TTI::SK_Select, MVT::v2i32, 1}, // mov.
3807         {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
3808         {TTI::SK_Select, MVT::v2i64, 1}, // mov.
3809         {TTI::SK_Select, MVT::v2f32, 1}, // mov.
3810         {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
3811         {TTI::SK_Select, MVT::v2f64, 1}, // mov.
3812         // PermuteSingleSrc shuffle kinds.
3813         {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
3814         {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
3815         {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
3816         {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
3817         {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
3818         {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
3819         {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
3820         {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
3821         {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same
3822         {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8},  // constpool + load + tbl
3823         {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8},  // constpool + load + tbl
3824         {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
3825         {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8},   // constpool + load + tbl
3826         {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8},  // constpool + load + tbl
3827         // Reverse can be lowered with `rev`.
3828         {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
3829         {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
3830         {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
3831         {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
3832         {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
3833         {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
3834         {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
3835         {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
3836         {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
3837         {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
3838         {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
3839         {TTI::SK_Reverse, MVT::v8i8, 1},  // REV64
3840         // Splice shuffles can all be lowered with `ext`.
3841         {TTI::SK_Splice, MVT::v2i32, 1},
3842         {TTI::SK_Splice, MVT::v4i32, 1},
3843         {TTI::SK_Splice, MVT::v2i64, 1},
3844         {TTI::SK_Splice, MVT::v2f32, 1},
3845         {TTI::SK_Splice, MVT::v4f32, 1},
3846         {TTI::SK_Splice, MVT::v2f64, 1},
3847         {TTI::SK_Splice, MVT::v8f16, 1},
3848         {TTI::SK_Splice, MVT::v8bf16, 1},
3849         {TTI::SK_Splice, MVT::v8i16, 1},
3850         {TTI::SK_Splice, MVT::v16i8, 1},
3851         {TTI::SK_Splice, MVT::v4bf16, 1},
3852         {TTI::SK_Splice, MVT::v4f16, 1},
3853         {TTI::SK_Splice, MVT::v4i16, 1},
3854         {TTI::SK_Splice, MVT::v8i8, 1},
3855         // Broadcast shuffle kinds for scalable vectors
3856         {TTI::SK_Broadcast, MVT::nxv16i8, 1},
3857         {TTI::SK_Broadcast, MVT::nxv8i16, 1},
3858         {TTI::SK_Broadcast, MVT::nxv4i32, 1},
3859         {TTI::SK_Broadcast, MVT::nxv2i64, 1},
3860         {TTI::SK_Broadcast, MVT::nxv2f16, 1},
3861         {TTI::SK_Broadcast, MVT::nxv4f16, 1},
3862         {TTI::SK_Broadcast, MVT::nxv8f16, 1},
3863         {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
3864         {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
3865         {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
3866         {TTI::SK_Broadcast, MVT::nxv2f32, 1},
3867         {TTI::SK_Broadcast, MVT::nxv4f32, 1},
3868         {TTI::SK_Broadcast, MVT::nxv2f64, 1},
3869         {TTI::SK_Broadcast, MVT::nxv16i1, 1},
3870         {TTI::SK_Broadcast, MVT::nxv8i1, 1},
3871         {TTI::SK_Broadcast, MVT::nxv4i1, 1},
3872         {TTI::SK_Broadcast, MVT::nxv2i1, 1},
3873         // Handle the cases for vector.reverse with scalable vectors
3874         {TTI::SK_Reverse, MVT::nxv16i8, 1},
3875         {TTI::SK_Reverse, MVT::nxv8i16, 1},
3876         {TTI::SK_Reverse, MVT::nxv4i32, 1},
3877         {TTI::SK_Reverse, MVT::nxv2i64, 1},
3878         {TTI::SK_Reverse, MVT::nxv2f16, 1},
3879         {TTI::SK_Reverse, MVT::nxv4f16, 1},
3880         {TTI::SK_Reverse, MVT::nxv8f16, 1},
3881         {TTI::SK_Reverse, MVT::nxv2bf16, 1},
3882         {TTI::SK_Reverse, MVT::nxv4bf16, 1},
3883         {TTI::SK_Reverse, MVT::nxv8bf16, 1},
3884         {TTI::SK_Reverse, MVT::nxv2f32, 1},
3885         {TTI::SK_Reverse, MVT::nxv4f32, 1},
3886         {TTI::SK_Reverse, MVT::nxv2f64, 1},
3887         {TTI::SK_Reverse, MVT::nxv16i1, 1},
3888         {TTI::SK_Reverse, MVT::nxv8i1, 1},
3889         {TTI::SK_Reverse, MVT::nxv4i1, 1},
3890         {TTI::SK_Reverse, MVT::nxv2i1, 1},
3891     };
3892     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
3893       return LT.first * Entry->Cost;
3894   }
3895 
3896   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp))
3897     return getSpliceCost(Tp, Index);
3898 
3899   // Inserting a subvector can often be done with either a D, S or H register
3900   // move, so long as the inserted vector is "aligned".
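       // For example (illustrative): inserting a <2 x float> subvector into a
       // <4 x float> vector at element 0 or 2 is a single subregister move.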
3901   if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
3902       LT.second.getSizeInBits() <= 128 && SubTp) {
3903     std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
3904     if (SubLT.second.isVector()) {
3905       int NumElts = LT.second.getVectorNumElements();
3906       int NumSubElts = SubLT.second.getVectorNumElements();
3907       if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
3908         return SubLT.first;
3909     }
3910   }
3911 
3912   return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp);
3913 }
3914 
3915 static bool containsDecreasingPointers(Loop *TheLoop,
3916                                        PredicatedScalarEvolution *PSE) {
3917   DenseMap<Value *, const SCEV *> Strides;
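       // The empty map means no symbolic strides are assumed below.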
3918   for (BasicBlock *BB : TheLoop->blocks()) {
3919     // Scan the instructions in the block and look for addresses that are
3920     // consecutive and decreasing.
3921     for (Instruction &I : *BB) {
3922       if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
3923         Value *Ptr = getLoadStorePointerOperand(&I);
3924         Type *AccessTy = getLoadStoreType(&I);
3925         if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
3926                          /*ShouldCheckWrap=*/false)
3927                 .value_or(0) < 0)
3928           return true;
3929       }
3930     }
3931   }
3932   return false;
3933 }
3934 
3935 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) {
3936   if (!ST->hasSVE())
3937     return false;
3938 
3939   // We don't currently support vectorisation with interleaving for SVE - with
3940   // such loops we're better off not using tail-folding. This gives us a chance
3941   // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
3942   if (TFI->IAI->hasGroups())
3943     return false;
3944 
3945   TailFoldingOpts Required = TailFoldingOpts::Disabled;
3946   if (TFI->LVL->getReductionVars().size())
3947     Required |= TailFoldingOpts::Reductions;
3948   if (TFI->LVL->getFixedOrderRecurrences().size())
3949     Required |= TailFoldingOpts::Recurrences;
3950 
3951   // We call this to discover whether any load/store pointers in the loop have
3952   // negative strides. If so, tail-folding will require extra work to reverse
3953   // the loop predicate, which may be expensive.
3954   if (containsDecreasingPointers(TFI->LVL->getLoop(),
3955                                  TFI->LVL->getPredicatedScalarEvolution()))
3956     Required |= TailFoldingOpts::Reverse;
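       // If none of the features above are needed, plain (simple) tail-folding
       // support is sufficient.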
3957   if (Required == TailFoldingOpts::Disabled)
3958     Required |= TailFoldingOpts::Simple;
3959 
3960   if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(),
3961                                       Required))
3962     return false;
3963 
3964   // Don't tail-fold for tight loops where we would be better off interleaving
3965   // with an unpredicated loop.
3966   unsigned NumInsns = 0;
3967   for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) {
3968     NumInsns += BB->sizeWithoutDebug();
3969   }
3970 
3971   // We expect 4 of these to be the IV PHI, IV add, IV compare and branch.
3972   return NumInsns >= SVETailFoldInsnThreshold;
3973 }
3974 
3975 InstructionCost
3976 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
3977                                      int64_t BaseOffset, bool HasBaseReg,
3978                                      int64_t Scale, unsigned AddrSpace) const {
3979   // Scaling factors are not free at all.
3980   // Operands                     | Rt Latency
3981   // -------------------------------------------
3982   // Rt, [Xn, Xm]                 | 4
3983   // -------------------------------------------
3984   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
3985   // Rt, [Xn, Wm, <extend> #imm]  |
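       // Build the addressing mode described by the operands and query the target:
       // a legal unscaled mode is free, a legal scaled mode costs 1, and an illegal
       // mode is reported as -1.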
3986   TargetLoweringBase::AddrMode AM;
3987   AM.BaseGV = BaseGV;
3988   AM.BaseOffs = BaseOffset;
3989   AM.HasBaseReg = HasBaseReg;
3990   AM.Scale = Scale;
3991   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
3992     // Scale represents reg2 * scale, so charge a cost of 1 whenever the scale
3993     // is anything other than 0 or 1.
3994     return AM.Scale != 0 && AM.Scale != 1;
3995   return -1;
3996 }
3997