1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "AArch64PerfectShuffle.h"
12 #include "MCTargetDesc/AArch64AddressingModes.h"
13 #include "Utils/AArch64SMEAttributes.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/CodeGen/BasicTTIImpl.h"
18 #include "llvm/CodeGen/CostTable.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsAArch64.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/TargetParser/AArch64TargetParser.h"
27 #include "llvm/Transforms/InstCombine/InstCombiner.h"
28 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
29 #include <algorithm>
30 #include <optional>
31 using namespace llvm;
32 using namespace llvm::PatternMatch;
33 
34 #define DEBUG_TYPE "aarch64tti"
35 
36 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
37                                                cl::init(true), cl::Hidden);
38 
39 static cl::opt<bool> SVEPreferFixedOverScalableIfEqualCost(
40     "sve-prefer-fixed-over-scalable-if-equal", cl::Hidden);
41 
42 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
43                                            cl::Hidden);
44 
45 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
46                                             cl::init(10), cl::Hidden);
47 
48 static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
49                                                   cl::init(15), cl::Hidden);
50 
51 static cl::opt<unsigned>
52     NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
53                                cl::Hidden);
54 
55 static cl::opt<unsigned> CallPenaltyChangeSM(
56     "call-penalty-sm-change", cl::init(5), cl::Hidden,
57     cl::desc(
58         "Penalty of calling a function that requires a change to PSTATE.SM"));
59 
60 static cl::opt<unsigned> InlineCallPenaltyChangeSM(
61     "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
62     cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));
63 
64 static cl::opt<bool> EnableOrLikeSelectOpt("enable-aarch64-or-like-select",
65                                            cl::init(true), cl::Hidden);
66 
67 static cl::opt<bool> EnableLSRCostOpt("enable-aarch64-lsr-cost-opt",
68                                       cl::init(true), cl::Hidden);
69 
70 // A complete guess as to a reasonable cost.
71 static cl::opt<unsigned>
72     BaseHistCntCost("aarch64-base-histcnt-cost", cl::init(8), cl::Hidden,
73                     cl::desc("The cost of a histcnt instruction"));
74 
75 static cl::opt<unsigned> DMBLookaheadThreshold(
76     "dmb-lookahead-threshold", cl::init(10), cl::Hidden,
77     cl::desc("The number of instructions to search for a redundant dmb"));
78 
79 namespace {
80 class TailFoldingOption {
81   // These bitfields will only ever be set to something non-zero in operator=,
82   // when setting the -sve-tail-folding option. This option should always be of
83   // the form (default|simple|all|disabled)[+(Flag1|Flag2|etc)], where
84   // InitialBits is one of (disabled|all|simple). EnableBits represents
85   // additional flags we're enabling, and DisableBits for those flags we're
86   // disabling. The default flag is tracked in the variable NeedsDefault, since
87   // at the time of setting the option we may not know what the default value
88   // for the CPU is.
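  // For example, "-sve-tail-folding=default+noreverse" leaves NeedsDefault
  // set and records Reverse in DisableBits, so getBits() returns the CPU's
  // default mask with the Reverse bit cleared.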
89   TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
90   TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
91   TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
92 
93   // This value needs to be initialised to true in case the user does not
94   // explicitly set the -sve-tail-folding option.
95   bool NeedsDefault = true;
96 
97   void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
98 
99   void setNeedsDefault(bool V) { NeedsDefault = V; }
100 
101   void setEnableBit(TailFoldingOpts Bit) {
102     EnableBits |= Bit;
103     DisableBits &= ~Bit;
104   }
105 
106   void setDisableBit(TailFoldingOpts Bit) {
107     EnableBits &= ~Bit;
108     DisableBits |= Bit;
109   }
110 
111   TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
112     TailFoldingOpts Bits = TailFoldingOpts::Disabled;
113 
114     assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
115            "Initial bits should only include one of "
116            "(disabled|all|simple|default)");
117     Bits = NeedsDefault ? DefaultBits : InitialBits;
118     Bits |= EnableBits;
119     Bits &= ~DisableBits;
120 
121     return Bits;
122   }
123 
124   void reportError(std::string Opt) {
125     errs() << "invalid argument '" << Opt
126            << "' to -sve-tail-folding=; the option should be of the form\n"
127               "  (disabled|all|default|simple)[+(reductions|recurrences"
128               "|reverse|noreductions|norecurrences|noreverse)]\n";
129     report_fatal_error("Unrecognised tail-folding option");
130   }
131 
132 public:
133 
134   void operator=(const std::string &Val) {
135     // If the user explicitly sets -sve-tail-folding= then treat as an error.
136     if (Val.empty()) {
137       reportError("");
138       return;
139     }
140 
141     // Since the user is explicitly setting the option we don't automatically
142     // need the default unless they require it.
143     setNeedsDefault(false);
144 
145     SmallVector<StringRef, 4> TailFoldTypes;
146     StringRef(Val).split(TailFoldTypes, '+', -1, false);
147 
148     unsigned StartIdx = 1;
149     if (TailFoldTypes[0] == "disabled")
150       setInitialBits(TailFoldingOpts::Disabled);
151     else if (TailFoldTypes[0] == "all")
152       setInitialBits(TailFoldingOpts::All);
153     else if (TailFoldTypes[0] == "default")
154       setNeedsDefault(true);
155     else if (TailFoldTypes[0] == "simple")
156       setInitialBits(TailFoldingOpts::Simple);
157     else {
158       StartIdx = 0;
159       setInitialBits(TailFoldingOpts::Disabled);
160     }
161 
162     for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
163       if (TailFoldTypes[I] == "reductions")
164         setEnableBit(TailFoldingOpts::Reductions);
165       else if (TailFoldTypes[I] == "recurrences")
166         setEnableBit(TailFoldingOpts::Recurrences);
167       else if (TailFoldTypes[I] == "reverse")
168         setEnableBit(TailFoldingOpts::Reverse);
169       else if (TailFoldTypes[I] == "noreductions")
170         setDisableBit(TailFoldingOpts::Reductions);
171       else if (TailFoldTypes[I] == "norecurrences")
172         setDisableBit(TailFoldingOpts::Recurrences);
173       else if (TailFoldTypes[I] == "noreverse")
174         setDisableBit(TailFoldingOpts::Reverse);
175       else
176         reportError(Val);
177     }
178   }
179 
180   bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
181     return (getBits(DefaultBits) & Required) == Required;
182   }
183 };
184 } // namespace
185 
186 TailFoldingOption TailFoldingOptionLoc;
187 
188 static cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
189     "sve-tail-folding",
190     cl::desc(
191         "Control the use of vectorisation using tail-folding for SVE where the"
192         " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
193         "\ndisabled      (Initial) No loop types will vectorize using "
194         "tail-folding"
195         "\ndefault       (Initial) Uses the default tail-folding settings for "
196         "the target CPU"
197         "\nall           (Initial) All legal loop types will vectorize using "
198         "tail-folding"
199         "\nsimple        (Initial) Use tail-folding for simple loops (not "
200         "reductions or recurrences)"
201         "\nreductions    Use tail-folding for loops containing reductions"
202         "\nnoreductions  Inverse of above"
203         "\nrecurrences   Use tail-folding for loops containing fixed order "
204         "recurrences"
205         "\nnorecurrences Inverse of above"
206         "\nreverse       Use tail-folding for loops requiring reversed "
207         "predicates"
208         "\nnoreverse     Inverse of above"),
209     cl::location(TailFoldingOptionLoc));
210 
211 // Experimental option that will only be fully functional when the
212 // code-generator is changed to use SVE instead of NEON for all fixed-width
213 // operations.
214 static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
215     "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
216 
217 // Experimental option that will only be fully functional when the cost-model
218 // and code-generator have been changed to avoid using scalable vector
219 // instructions that are not legal in streaming SVE mode.
220 static cl::opt<bool> EnableScalableAutovecInStreamingMode(
221     "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
222 
223 static bool isSMEABIRoutineCall(const CallInst &CI) {
224   const auto *F = CI.getCalledFunction();
225   return F && StringSwitch<bool>(F->getName())
226                   .Case("__arm_sme_state", true)
227                   .Case("__arm_tpidr2_save", true)
228                   .Case("__arm_tpidr2_restore", true)
229                   .Case("__arm_za_disable", true)
230                   .Default(false);
231 }
232 
233 /// Returns true if the function has explicit operations that can only be
234 /// lowered using incompatible instructions for the selected mode. This also
235 /// returns true if the function F may use or modify ZA state.
236 static bool hasPossibleIncompatibleOps(const Function *F) {
237   for (const BasicBlock &BB : *F) {
238     for (const Instruction &I : BB) {
239       // Be conservative for now and assume that any call to inline asm or to
240       // intrinsics could result in non-streaming ops (e.g. calls to
241       // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
242       // all native LLVM instructions can be lowered to compatible instructions.
243       if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
244           (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
245            isSMEABIRoutineCall(cast<CallInst>(I))))
246         return true;
247     }
248   }
249   return false;
250 }
251 
252 uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
253   StringRef AttributeStr =
254       isMultiversionedFunction(F) ? "fmv-features" : "target-features";
255   StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
256   SmallVector<StringRef, 8> Features;
257   FeatureStr.split(Features, ",");
258   return AArch64::getFMVPriority(Features);
259 }
260 
261 bool AArch64TTIImpl::isMultiversionedFunction(const Function &F) const {
262   return F.hasFnAttribute("fmv-features");
263 }
264 
265 const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
266     AArch64::FeatureExecuteOnly,
267 };
268 
269 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270                                          const Function *Callee) const {
271   SMECallAttrs CallAttrs(*Caller, *Callee);
272 
273   // When inlining, we should consider the body of the function, not the
274   // interface.
275   if (CallAttrs.callee().hasStreamingBody()) {
276     CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277     CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278   }
279 
280   if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281     return false;
282 
283   if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284       CallAttrs.requiresPreservingZT0() ||
285       CallAttrs.requiresPreservingAllZAState()) {
286     if (hasPossibleIncompatibleOps(Callee))
287       return false;
288   }
289 
290   const TargetMachine &TM = getTLI()->getTargetMachine();
291   const FeatureBitset &CallerBits =
292       TM.getSubtargetImpl(*Caller)->getFeatureBits();
293   const FeatureBitset &CalleeBits =
294       TM.getSubtargetImpl(*Callee)->getFeatureBits();
295   // Adjust the feature bitsets by inverting some of the bits. This is needed
296   // for target features that represent restrictions rather than capabilities,
297   // for example a "+execute-only" callee can be inlined into a caller without
298   // "+execute-only", but not vice versa.
299   FeatureBitset EffectiveCallerBits = CallerBits ^ InlineInverseFeatures;
300   FeatureBitset EffectiveCalleeBits = CalleeBits ^ InlineInverseFeatures;
301 
302   return (EffectiveCallerBits & EffectiveCalleeBits) == EffectiveCalleeBits;
303 }
304 
305 bool AArch64TTIImpl::areTypesABICompatible(
306     const Function *Caller, const Function *Callee,
307     const ArrayRef<Type *> &Types) const {
308   if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
309     return false;
310 
311   // We need to ensure that argument promotion does not attempt to promote
312   // pointers to fixed-length vector types larger than 128 bits like
313   // <8 x float> (and pointers to aggregate types which have such fixed-length
314   // vector type members) into the values of the pointees. Such vector types
315   // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
316   // backend cannot lower such value arguments. The 128-bit fixed-length SVE
317   // types can be safely treated as 128-bit NEON types and they cannot be
318   // distinguished in IR.
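  // For example, promoting a pointer argument whose pointee is <8 x float>
  // (256 bits) is rejected by the check below, while <4 x float> (128 bits)
  // is allowed since it can be treated as a 128-bit NEON type.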
319   if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
320         auto FVTy = dyn_cast<FixedVectorType>(Ty);
321         return FVTy &&
322                FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
323       }))
324     return false;
325 
326   return true;
327 }
328 
329 unsigned
330 AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
331                                      unsigned DefaultCallPenalty) const {
332   // This function calculates a penalty for executing Call in F.
333   //
334   // There are two ways this function can be called:
335   // (1)  F:
336   //       call from F -> G (the call here is Call)
337   //
338   // For (1), Call.getCaller() == F, so it will always return a high cost if
339   // a streaming-mode change is required (thus promoting the need to inline the
340   // function)
341   //
342   // (2)  F:
343   //       call from F -> G (the call here is not Call)
344   //      G:
345   //       call from G -> H (the call here is Call)
346   //
347   // For (2), if after inlining the body of G into F the call to H requires a
348   // streaming-mode change, and the call to G from F would also require a
349   // streaming-mode change, then there is benefit to do the streaming-mode
350   // change only once and avoid inlining of G into F.
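  // With the default option values above (-call-penalty-sm-change=5 and
  // -inline-call-penalty-sm-change=10), case (1) evaluates to
  // 5 * DefaultCallPenalty and case (2) to 10 * DefaultCallPenalty.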
351 
352   SMEAttrs FAttrs(*F);
353   SMECallAttrs CallAttrs(Call);
354 
355   if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
356     if (F == Call.getCaller()) // (1)
357       return CallPenaltyChangeSM * DefaultCallPenalty;
358     if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
359       return InlineCallPenaltyChangeSM * DefaultCallPenalty;
360   }
361 
362   return DefaultCallPenalty;
363 }
364 
365 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
366     TargetTransformInfo::RegisterKind K) const {
367   assert(K != TargetTransformInfo::RGK_Scalar);
368   return (K == TargetTransformInfo::RGK_FixedWidthVector &&
369           ST->isNeonAvailable());
370 }
371 
372 /// Calculate the cost of materializing a 64-bit value. This helper
373 /// method might only calculate a fraction of a larger immediate. Therefore it
374 /// is valid to return a cost of ZERO.
375 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) const {
376   // Check if the immediate can be encoded within an instruction.
377   if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
378     return 0;
379 
380   if (Val < 0)
381     Val = ~Val;
382 
383   // Calculate how many moves we will need to materialize this constant.
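  // For example, a constant such as 0x0000123456789ABC has no
  // logical-immediate encoding and would typically expand to one MOVZ plus
  // two MOVKs, giving a cost of 3 here (an illustrative sketch of the usual
  // expansion, not an exhaustive rule).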
384   SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
385   AArch64_IMM::expandMOVImm(Val, 64, Insn);
386   return Insn.size();
387 }
388 
389 /// Calculate the cost of materializing the given constant.
390 InstructionCost
391 AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
392                               TTI::TargetCostKind CostKind) const {
393   assert(Ty->isIntegerTy());
394 
395   unsigned BitSize = Ty->getPrimitiveSizeInBits();
396   if (BitSize == 0)
397     return ~0U;
398 
399   // Sign-extend all constants to a multiple of 64-bit.
400   APInt ImmVal = Imm;
401   if (BitSize & 0x3f)
402     ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
403 
404   // Split the constant into 64-bit chunks and calculate the cost for each
405   // chunk.
406   InstructionCost Cost = 0;
407   for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
408     APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
409     int64_t Val = Tmp.getSExtValue();
410     Cost += getIntImmCost(Val);
411   }
412   // We need at least one instruction to materialize the constant.
413   return std::max<InstructionCost>(1, Cost);
414 }
415 
416 InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
417                                                   const APInt &Imm, Type *Ty,
418                                                   TTI::TargetCostKind CostKind,
419                                                   Instruction *Inst) const {
420   assert(Ty->isIntegerTy());
421 
422   unsigned BitSize = Ty->getPrimitiveSizeInBits();
423   // There is no cost model for constants with a bit size of 0. Return TCC_Free
424   // here, so that constant hoisting will ignore this constant.
425   if (BitSize == 0)
426     return TTI::TCC_Free;
427 
428   unsigned ImmIdx = ~0U;
429   switch (Opcode) {
430   default:
431     return TTI::TCC_Free;
432   case Instruction::GetElementPtr:
433     // Always hoist the base address of a GetElementPtr.
434     if (Idx == 0)
435       return 2 * TTI::TCC_Basic;
436     return TTI::TCC_Free;
437   case Instruction::Store:
438     ImmIdx = 0;
439     break;
440   case Instruction::Add:
441   case Instruction::Sub:
442   case Instruction::Mul:
443   case Instruction::UDiv:
444   case Instruction::SDiv:
445   case Instruction::URem:
446   case Instruction::SRem:
447   case Instruction::And:
448   case Instruction::Or:
449   case Instruction::Xor:
450   case Instruction::ICmp:
451     ImmIdx = 1;
452     break;
453   // Always return TCC_Free for the shift value of a shift instruction.
454   case Instruction::Shl:
455   case Instruction::LShr:
456   case Instruction::AShr:
457     if (Idx == 1)
458       return TTI::TCC_Free;
459     break;
460   case Instruction::Trunc:
461   case Instruction::ZExt:
462   case Instruction::SExt:
463   case Instruction::IntToPtr:
464   case Instruction::PtrToInt:
465   case Instruction::BitCast:
466   case Instruction::PHI:
467   case Instruction::Call:
468   case Instruction::Select:
469   case Instruction::Ret:
470   case Instruction::Load:
471     break;
472   }
473 
474   if (Idx == ImmIdx) {
475     int NumConstants = (BitSize + 63) / 64;
476     InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
477     return (Cost <= NumConstants * TTI::TCC_Basic)
478                ? static_cast<int>(TTI::TCC_Free)
479                : Cost;
480   }
481   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
482 }
483 
484 InstructionCost
485 AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
486                                     const APInt &Imm, Type *Ty,
487                                     TTI::TargetCostKind CostKind) const {
488   assert(Ty->isIntegerTy());
489 
490   unsigned BitSize = Ty->getPrimitiveSizeInBits();
491   // There is no cost model for constants with a bit size of 0. Return TCC_Free
492   // here, so that constant hoisting will ignore this constant.
493   if (BitSize == 0)
494     return TTI::TCC_Free;
495 
496   // Most (all?) AArch64 intrinsics do not support folding immediates into the
497   // selected instruction, so we compute the materialization cost for the
498   // immediate directly.
499   if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
500     return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
501 
502   switch (IID) {
503   default:
504     return TTI::TCC_Free;
505   case Intrinsic::sadd_with_overflow:
506   case Intrinsic::uadd_with_overflow:
507   case Intrinsic::ssub_with_overflow:
508   case Intrinsic::usub_with_overflow:
509   case Intrinsic::smul_with_overflow:
510   case Intrinsic::umul_with_overflow:
511     if (Idx == 1) {
512       int NumConstants = (BitSize + 63) / 64;
513       InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
514       return (Cost <= NumConstants * TTI::TCC_Basic)
515                  ? static_cast<int>(TTI::TCC_Free)
516                  : Cost;
517     }
518     break;
519   case Intrinsic::experimental_stackmap:
520     if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
521       return TTI::TCC_Free;
522     break;
523   case Intrinsic::experimental_patchpoint_void:
524   case Intrinsic::experimental_patchpoint:
525     if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
526       return TTI::TCC_Free;
527     break;
528   case Intrinsic::experimental_gc_statepoint:
529     if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
530       return TTI::TCC_Free;
531     break;
532   }
533   return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
534 }
535 
536 TargetTransformInfo::PopcntSupportKind
537 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) const {
538   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
539   if (TyWidth == 32 || TyWidth == 64)
540     return TTI::PSK_FastHardware;
541   // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
542   return TTI::PSK_Software;
543 }
544 
545 static bool isUnpackedVectorVT(EVT VecVT) {
546   return VecVT.isScalableVector() &&
547          VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock;
548 }
549 
550 static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) {
551   Type *BucketPtrsTy = ICA.getArgTypes()[0]; // Type of vector of pointers
552   Type *EltTy = ICA.getArgTypes()[1];        // Type of bucket elements
553   unsigned TotalHistCnts = 1;
554 
555   unsigned EltSize = EltTy->getScalarSizeInBits();
556   // Only allow (up to 64b) integers or pointers
557   if ((!EltTy->isIntegerTy() && !EltTy->isPointerTy()) || EltSize > 64)
558     return InstructionCost::getInvalid();
559 
560   // FIXME: We should be able to generate histcnt for fixed-length vectors
561   //        using ptrue with a specific VL.
562   if (VectorType *VTy = dyn_cast<VectorType>(BucketPtrsTy)) {
563     unsigned EC = VTy->getElementCount().getKnownMinValue();
564     if (!isPowerOf2_64(EC) || !VTy->isScalableTy())
565       return InstructionCost::getInvalid();
566 
567     // HistCnt only supports 32b and 64b element types
568     unsigned LegalEltSize = EltSize <= 32 ? 32 : 64;
569 
570     if (EC == 2 || (LegalEltSize == 32 && EC == 4))
571       return InstructionCost(BaseHistCntCost);
572 
573     unsigned NaturalVectorWidth = AArch64::SVEBitsPerBlock / LegalEltSize;
574     TotalHistCnts = EC / NaturalVectorWidth;
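    // For example (assuming the default base cost), <vscale x 8 x ptr>
    // buckets of i32 data give EC == 8 with a natural width of 4, so two
    // histcnt operations are costed below.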
575   }
576 
577   return InstructionCost(BaseHistCntCost * TotalHistCnts);
578 }
579 
580 InstructionCost
581 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
582                                       TTI::TargetCostKind CostKind) const {
583   // The code-generator is currently not able to handle scalable vectors
584   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
585   // it. This change will be removed when code-generation for these types is
586   // sufficiently reliable.
587   auto *RetTy = ICA.getReturnType();
588   if (auto *VTy = dyn_cast<ScalableVectorType>(RetTy))
589     if (VTy->getElementCount() == ElementCount::getScalable(1))
590       return InstructionCost::getInvalid();
591 
592   switch (ICA.getID()) {
593   case Intrinsic::experimental_vector_histogram_add:
594     if (!ST->hasSVE2())
595       return InstructionCost::getInvalid();
596     return getHistogramCost(ICA);
597   case Intrinsic::umin:
598   case Intrinsic::umax:
599   case Intrinsic::smin:
600   case Intrinsic::smax: {
601     static const auto ValidMinMaxTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
602                                         MVT::v8i16, MVT::v2i32, MVT::v4i32,
603                                         MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
604                                         MVT::nxv2i64};
605     auto LT = getTypeLegalizationCost(RetTy);
606     // v2i64 types get converted to cmp+bif hence the cost of 2
607     if (LT.second == MVT::v2i64)
608       return LT.first * 2;
609     if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
610       return LT.first;
611     break;
612   }
613   case Intrinsic::sadd_sat:
614   case Intrinsic::ssub_sat:
615   case Intrinsic::uadd_sat:
616   case Intrinsic::usub_sat: {
617     static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
618                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
619                                      MVT::v2i64};
620     auto LT = getTypeLegalizationCost(RetTy);
621     // This is a base cost of 1 for the vadd, plus 3 extra shifts if we
622     // need to extend the type, as it uses shr(qadd(shl, shl)).
623     unsigned Instrs =
624         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
625     if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
626       return LT.first * Instrs;
627     break;
628   }
629   case Intrinsic::abs: {
630     static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
631                                      MVT::v8i16, MVT::v2i32, MVT::v4i32,
632                                      MVT::v2i64};
633     auto LT = getTypeLegalizationCost(RetTy);
634     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
635       return LT.first;
636     break;
637   }
638   case Intrinsic::bswap: {
639     static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
640                                      MVT::v4i32, MVT::v2i64};
641     auto LT = getTypeLegalizationCost(RetTy);
642     if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
643         LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
644       return LT.first;
645     break;
646   }
647   case Intrinsic::stepvector: {
648     InstructionCost Cost = 1; // Cost of the `index' instruction
649     auto LT = getTypeLegalizationCost(RetTy);
650     // Legalisation of illegal vectors involves an `index' instruction plus
651     // (LT.first - 1) vector adds.
652     if (LT.first > 1) {
653       Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
654       InstructionCost AddCost =
655           getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
656       Cost += AddCost * (LT.first - 1);
657     }
658     return Cost;
659   }
660   case Intrinsic::vector_extract:
661   case Intrinsic::vector_insert: {
662     // If both the vector and subvector types are legal types and the index
663     // is 0, then this should be a no-op or simple operation; return a
664     // relatively low cost.
665 
666     // If arguments aren't actually supplied, then we cannot determine the
667     // value of the index. We also want to skip predicate types.
668     if (ICA.getArgs().size() != ICA.getArgTypes().size() ||
669         ICA.getReturnType()->getScalarType()->isIntegerTy(1))
670       break;
671 
672     LLVMContext &C = RetTy->getContext();
673     EVT VecVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
674     bool IsExtract = ICA.getID() == Intrinsic::vector_extract;
675     EVT SubVecVT = IsExtract ? getTLI()->getValueType(DL, RetTy)
676                              : getTLI()->getValueType(DL, ICA.getArgTypes()[1]);
677     // Skip this if either the vector or subvector types are unpacked
678     // SVE types; they may get lowered to stack stores and loads.
679     if (isUnpackedVectorVT(VecVT) || isUnpackedVectorVT(SubVecVT))
680       break;
681 
682     TargetLoweringBase::LegalizeKind SubVecLK =
683         getTLI()->getTypeConversion(C, SubVecVT);
684     TargetLoweringBase::LegalizeKind VecLK =
685         getTLI()->getTypeConversion(C, VecVT);
686     const Value *Idx = IsExtract ? ICA.getArgs()[1] : ICA.getArgs()[2];
687     const ConstantInt *CIdx = cast<ConstantInt>(Idx);
688     if (SubVecLK.first == TargetLoweringBase::TypeLegal &&
689         VecLK.first == TargetLoweringBase::TypeLegal && CIdx->isZero())
690       return TTI::TCC_Free;
691     break;
692   }
693   case Intrinsic::bitreverse: {
694     static const CostTblEntry BitreverseTbl[] = {
695         {Intrinsic::bitreverse, MVT::i32, 1},
696         {Intrinsic::bitreverse, MVT::i64, 1},
697         {Intrinsic::bitreverse, MVT::v8i8, 1},
698         {Intrinsic::bitreverse, MVT::v16i8, 1},
699         {Intrinsic::bitreverse, MVT::v4i16, 2},
700         {Intrinsic::bitreverse, MVT::v8i16, 2},
701         {Intrinsic::bitreverse, MVT::v2i32, 2},
702         {Intrinsic::bitreverse, MVT::v4i32, 2},
703         {Intrinsic::bitreverse, MVT::v1i64, 2},
704         {Intrinsic::bitreverse, MVT::v2i64, 2},
705     };
706     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
707     const auto *Entry =
708         CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
709     if (Entry) {
710       // The cost model uses the legal type (i32) that i8 and i16 will be
711       // converted to, plus 1 so that we match the actual lowering cost.
712       if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
713           TLI->getValueType(DL, RetTy, true) == MVT::i16)
714         return LegalisationCost.first * Entry->Cost + 1;
715 
716       return LegalisationCost.first * Entry->Cost;
717     }
718     break;
719   }
720   case Intrinsic::ctpop: {
721     if (!ST->hasNEON()) {
722       // 32-bit or 64-bit ctpop without NEON is 12 instructions.
723       return getTypeLegalizationCost(RetTy).first * 12;
724     }
725     static const CostTblEntry CtpopCostTbl[] = {
726         {ISD::CTPOP, MVT::v2i64, 4},
727         {ISD::CTPOP, MVT::v4i32, 3},
728         {ISD::CTPOP, MVT::v8i16, 2},
729         {ISD::CTPOP, MVT::v16i8, 1},
730         {ISD::CTPOP, MVT::i64,   4},
731         {ISD::CTPOP, MVT::v2i32, 3},
732         {ISD::CTPOP, MVT::v4i16, 2},
733         {ISD::CTPOP, MVT::v8i8,  1},
734         {ISD::CTPOP, MVT::i32,   5},
735     };
736     auto LT = getTypeLegalizationCost(RetTy);
737     MVT MTy = LT.second;
738     if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
739       // Extra cost of +1 when illegal vector types are legalized by promoting
740       // the integer type.
741       int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
742                                             RetTy->getScalarSizeInBits()
743                           ? 1
744                           : 0;
745       return LT.first * Entry->Cost + ExtraCost;
746     }
747     break;
748   }
749   case Intrinsic::sadd_with_overflow:
750   case Intrinsic::uadd_with_overflow:
751   case Intrinsic::ssub_with_overflow:
752   case Intrinsic::usub_with_overflow:
753   case Intrinsic::smul_with_overflow:
754   case Intrinsic::umul_with_overflow: {
755     static const CostTblEntry WithOverflowCostTbl[] = {
756         {Intrinsic::sadd_with_overflow, MVT::i8, 3},
757         {Intrinsic::uadd_with_overflow, MVT::i8, 3},
758         {Intrinsic::sadd_with_overflow, MVT::i16, 3},
759         {Intrinsic::uadd_with_overflow, MVT::i16, 3},
760         {Intrinsic::sadd_with_overflow, MVT::i32, 1},
761         {Intrinsic::uadd_with_overflow, MVT::i32, 1},
762         {Intrinsic::sadd_with_overflow, MVT::i64, 1},
763         {Intrinsic::uadd_with_overflow, MVT::i64, 1},
764         {Intrinsic::ssub_with_overflow, MVT::i8, 3},
765         {Intrinsic::usub_with_overflow, MVT::i8, 3},
766         {Intrinsic::ssub_with_overflow, MVT::i16, 3},
767         {Intrinsic::usub_with_overflow, MVT::i16, 3},
768         {Intrinsic::ssub_with_overflow, MVT::i32, 1},
769         {Intrinsic::usub_with_overflow, MVT::i32, 1},
770         {Intrinsic::ssub_with_overflow, MVT::i64, 1},
771         {Intrinsic::usub_with_overflow, MVT::i64, 1},
772         {Intrinsic::smul_with_overflow, MVT::i8, 5},
773         {Intrinsic::umul_with_overflow, MVT::i8, 4},
774         {Intrinsic::smul_with_overflow, MVT::i16, 5},
775         {Intrinsic::umul_with_overflow, MVT::i16, 4},
776         {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
777         {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
778         {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
779         {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
780     };
781     EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
782     if (MTy.isSimple())
783       if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
784                                               MTy.getSimpleVT()))
785         return Entry->Cost;
786     break;
787   }
788   case Intrinsic::fptosi_sat:
789   case Intrinsic::fptoui_sat: {
790     if (ICA.getArgTypes().empty())
791       break;
792     bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
793     auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
794     EVT MTy = TLI->getValueType(DL, RetTy);
795     // Check for the legal types, which are where the size of the input and the
796     // output are the same, or we are using cvt f64->i32 or f32->i64.
797     if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
798          LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
799          LT.second == MVT::v2f64)) {
800       if ((LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
801            (LT.second == MVT::f64 && MTy == MVT::i32) ||
802            (LT.second == MVT::f32 && MTy == MVT::i64)))
803         return LT.first;
804       // Extending vector types v2f32->v2i64, fcvtl*2 + fcvt*2
805       if (LT.second.getScalarType() == MVT::f32 && MTy.isFixedLengthVector() &&
806           MTy.getScalarSizeInBits() == 64)
807         return LT.first * (MTy.getVectorNumElements() > 2 ? 4 : 2);
808     }
809     // Similarly for fp16 sizes. Without FullFP16 we generally need to fcvt to
810     // f32.
811     if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
812       return LT.first + getIntrinsicInstrCost(
813                             {ICA.getID(),
814                              RetTy,
815                              {ICA.getArgTypes()[0]->getWithNewType(
816                                  Type::getFloatTy(RetTy->getContext()))}},
817                             CostKind);
818     if ((LT.second == MVT::f16 && MTy == MVT::i32) ||
819         (LT.second == MVT::f16 && MTy == MVT::i64) ||
820         ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
821          (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))
822       return LT.first;
823     // Extending vector types v8f16->v8i32, fcvtl*2 + fcvt*2
824     if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
825         MTy.getScalarSizeInBits() == 32)
826       return LT.first * (MTy.getVectorNumElements() > 4 ? 4 : 2);
827     // Extending vector types v8f16->v8i32. These currently scalarize but the
828     // codegen could be better.
829     if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
830         MTy.getScalarSizeInBits() == 64)
831       return MTy.getVectorNumElements() * 3;
832 
833     // If we can, we use a legal convert followed by a min+max
834     if ((LT.second.getScalarType() == MVT::f32 ||
835          LT.second.getScalarType() == MVT::f64 ||
836          LT.second.getScalarType() == MVT::f16) &&
837         LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
838       Type *LegalTy =
839           Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
840       if (LT.second.isVector())
841         LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
842       InstructionCost Cost = 1;
843       IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
844                                     LegalTy, {LegalTy, LegalTy});
845       Cost += getIntrinsicInstrCost(Attrs1, CostKind);
846       IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
847                                     LegalTy, {LegalTy, LegalTy});
848       Cost += getIntrinsicInstrCost(Attrs2, CostKind);
849       return LT.first * Cost +
850              ((LT.second.getScalarType() != MVT::f16 || ST->hasFullFP16()) ? 0
851                                                                            : 1);
852     }
853     // Otherwise we need to follow the default expansion that clamps the value
854     // using a float min/max with a fcmp+sel for nan handling when signed.
855     Type *FPTy = ICA.getArgTypes()[0]->getScalarType();
856     RetTy = RetTy->getScalarType();
857     if (LT.second.isVector()) {
858       FPTy = VectorType::get(FPTy, LT.second.getVectorElementCount());
859       RetTy = VectorType::get(RetTy, LT.second.getVectorElementCount());
860     }
861     IntrinsicCostAttributes Attrs1(Intrinsic::minnum, FPTy, {FPTy, FPTy});
862     InstructionCost Cost = getIntrinsicInstrCost(Attrs1, CostKind);
863     IntrinsicCostAttributes Attrs2(Intrinsic::maxnum, FPTy, {FPTy, FPTy});
864     Cost += getIntrinsicInstrCost(Attrs2, CostKind);
865     Cost +=
866         getCastInstrCost(IsSigned ? Instruction::FPToSI : Instruction::FPToUI,
867                          RetTy, FPTy, TTI::CastContextHint::None, CostKind);
868     if (IsSigned) {
869       Type *CondTy = RetTy->getWithNewBitWidth(1);
870       Cost += getCmpSelInstrCost(BinaryOperator::FCmp, FPTy, CondTy,
871                                  CmpInst::FCMP_UNO, CostKind);
872       Cost += getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy,
873                                  CmpInst::FCMP_UNO, CostKind);
874     }
875     return LT.first * Cost;
876   }
877   case Intrinsic::fshl:
878   case Intrinsic::fshr: {
879     if (ICA.getArgs().empty())
880       break;
881 
882     // TODO: Add handling for fshl where third argument is not a constant.
883     const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
884     if (!OpInfoZ.isConstant())
885       break;
886 
887     const auto LegalisationCost = getTypeLegalizationCost(RetTy);
888     if (OpInfoZ.isUniform()) {
889       static const CostTblEntry FshlTbl[] = {
890           {Intrinsic::fshl, MVT::v4i32, 2}, // shl + usra
891           {Intrinsic::fshl, MVT::v2i64, 2}, {Intrinsic::fshl, MVT::v16i8, 2},
892           {Intrinsic::fshl, MVT::v8i16, 2}, {Intrinsic::fshl, MVT::v2i32, 2},
893           {Intrinsic::fshl, MVT::v8i8, 2},  {Intrinsic::fshl, MVT::v4i16, 2}};
894       // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
895       // to avoid having to duplicate the costs.
896       const auto *Entry =
897           CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
898       if (Entry)
899         return LegalisationCost.first * Entry->Cost;
900     }
901 
902     auto TyL = getTypeLegalizationCost(RetTy);
903     if (!RetTy->isIntegerTy())
904       break;
905 
906     // Estimate cost manually, as types like i8 and i16 will get promoted to
907     // i32 and CostTableLookup will ignore the extra conversion cost.
908     bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
909                        RetTy->getScalarSizeInBits() < 64) ||
910                       (RetTy->getScalarSizeInBits() % 64 != 0);
911     unsigned ExtraCost = HigherCost ? 1 : 0;
912     if (RetTy->getScalarSizeInBits() == 32 ||
913         RetTy->getScalarSizeInBits() == 64)
914     ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
915                      // extr instruction.
916     else if (HigherCost)
917       ExtraCost = 1;
918     else
919       break;
920     return TyL.first + ExtraCost;
921   }
922   case Intrinsic::get_active_lane_mask: {
923     auto *RetTy = dyn_cast<FixedVectorType>(ICA.getReturnType());
924     if (RetTy) {
925       EVT RetVT = getTLI()->getValueType(DL, RetTy);
926       EVT OpVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
927       if (!getTLI()->shouldExpandGetActiveLaneMask(RetVT, OpVT) &&
928           !getTLI()->isTypeLegal(RetVT)) {
929         // We don't have enough context at this point to determine if the mask
930         // is going to be kept live after the block, which will force the vXi1
931         // type to be expanded to legal vectors of integers, e.g. v4i1->v4i32.
932         // For now, we just assume the vectorizer created this intrinsic and
933         // the result will be the input for a PHI. In this case the cost will
934         // be extremely high for fixed-width vectors.
935         // NOTE: getScalarizationOverhead returns a cost that's far too
936         // pessimistic for the actual generated codegen. In reality there are
937         // two instructions generated per lane.
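        // For example, an <8 x i1> result that must stay live into a PHI is
        // costed here as 8 lanes * 2 = 16.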
938         return RetTy->getNumElements() * 2;
939       }
940     }
941     break;
942   }
943   case Intrinsic::experimental_vector_match: {
944     auto *NeedleTy = cast<FixedVectorType>(ICA.getArgTypes()[1]);
945     EVT SearchVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
946     unsigned SearchSize = NeedleTy->getNumElements();
947     if (!getTLI()->shouldExpandVectorMatch(SearchVT, SearchSize)) {
948       // Base cost for MATCH instructions. At least on the Neoverse V2 and
949       // Neoverse V3, these are cheap operations with the same latency as a
950       // vector ADD. In most cases, however, we also need to do an extra DUP.
951       // For fixed-length vectors we currently need an extra five to six
952       // instructions besides the MATCH.
953       InstructionCost Cost = 4;
954       if (isa<FixedVectorType>(RetTy))
955         Cost += 10;
956       return Cost;
957     }
958     break;
959   }
960   case Intrinsic::experimental_cttz_elts: {
961     EVT ArgVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
962     if (!getTLI()->shouldExpandCttzElements(ArgVT)) {
963       // This will consist of an SVE brkb and a cntp instruction. These
964       // typically have the same latency and half the throughput as a vector
965       // add instruction.
966       return 4;
967     }
968     break;
969   }
970   default:
971     break;
972   }
973   return BaseT::getIntrinsicInstrCost(ICA, CostKind);
974 }
975 
976 /// The function will remove redundant reinterpret casts in the presence
977 /// of control flow.
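// A sketch of the rewrite (types and value names are illustrative):
//   %r1  = @llvm.aarch64.sve.convert.to.svbool(<vscale x 4 x i1> %a)
//   %r2  = @llvm.aarch64.sve.convert.to.svbool(<vscale x 4 x i1> %b)
//   %phi = phi <vscale x 16 x i1> [ %r1, %bb1 ], [ %r2, %bb2 ]
//   %res = @llvm.aarch64.sve.convert.from.svbool(%phi)   ; II
// becomes a phi over %a and %b directly, leaving the to/from svbool pair dead.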
978 static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
979                                                    IntrinsicInst &II) {
980   SmallVector<Instruction *, 32> Worklist;
981   auto RequiredType = II.getType();
982 
983   auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
984   assert(PN && "Expected Phi Node!");
985 
986   // Don't create a new Phi unless we can remove the old one.
987   if (!PN->hasOneUse())
988     return std::nullopt;
989 
990   for (Value *IncValPhi : PN->incoming_values()) {
991     auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
992     if (!Reinterpret ||
993         Reinterpret->getIntrinsicID() !=
994             Intrinsic::aarch64_sve_convert_to_svbool ||
995         RequiredType != Reinterpret->getArgOperand(0)->getType())
996       return std::nullopt;
997   }
998 
999   // Create the new Phi
1000   IC.Builder.SetInsertPoint(PN);
1001   PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
1002   Worklist.push_back(PN);
1003 
1004   for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
1005     auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
1006     NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
1007     Worklist.push_back(Reinterpret);
1008   }
1009 
1010   // Cleanup Phi Node and reinterprets
1011   return IC.replaceInstUsesWith(II, NPN);
1012 }
1013 
1014 // A collection of properties common to SVE intrinsics that allow for combines
1015 // to be written without needing to know the specific intrinsic.
1016 struct SVEIntrinsicInfo {
1017   //
1018   // Helper routines for common intrinsic definitions.
1019   //
1020 
1021   // e.g. llvm.aarch64.sve.add pg, op1, op2
1022   //        with IID ==> llvm.aarch64.sve.add_u
1023   static SVEIntrinsicInfo
1024   defaultMergingOp(Intrinsic::ID IID = Intrinsic::not_intrinsic) {
1025     return SVEIntrinsicInfo()
1026         .setGoverningPredicateOperandIdx(0)
1027         .setOperandIdxInactiveLanesTakenFrom(1)
1028         .setMatchingUndefIntrinsic(IID);
1029   }
1030 
1031   // e.g. llvm.aarch64.sve.neg inactive, pg, op
1032   static SVEIntrinsicInfo defaultMergingUnaryOp() {
1033     return SVEIntrinsicInfo()
1034         .setGoverningPredicateOperandIdx(1)
1035         .setOperandIdxInactiveLanesTakenFrom(0)
1036         .setOperandIdxWithNoActiveLanes(0);
1037   }
1038 
1039   // e.g. llvm.aarch64.sve.fcvtnt inactive, pg, op
1040   static SVEIntrinsicInfo defaultMergingUnaryNarrowingTopOp() {
1041     return SVEIntrinsicInfo()
1042         .setGoverningPredicateOperandIdx(1)
1043         .setOperandIdxInactiveLanesTakenFrom(0);
1044   }
1045 
1046   // e.g. llvm.aarch64.sve.add_u pg, op1, op2
1047   static SVEIntrinsicInfo defaultUndefOp() {
1048     return SVEIntrinsicInfo()
1049         .setGoverningPredicateOperandIdx(0)
1050         .setInactiveLanesAreNotDefined();
1051   }
1052 
1053   // e.g. llvm.aarch64.sve.prf pg, ptr        (GPIndex = 0)
1054   //      llvm.aarch64.sve.st1 data, pg, ptr  (GPIndex = 1)
1055   static SVEIntrinsicInfo defaultVoidOp(unsigned GPIndex) {
1056     return SVEIntrinsicInfo()
1057         .setGoverningPredicateOperandIdx(GPIndex)
1058         .setInactiveLanesAreUnused();
1059   }
1060 
1061   // e.g. llvm.aarch64.sve.cmpeq pg, op1, op2
1062   //      llvm.aarch64.sve.ld1 pg, ptr
1063   static SVEIntrinsicInfo defaultZeroingOp() {
1064     return SVEIntrinsicInfo()
1065         .setGoverningPredicateOperandIdx(0)
1066         .setInactiveLanesAreUnused()
1067         .setResultIsZeroInitialized();
1068   }
1069 
1070   // All properties relate to predication and thus having a general predicate
1071   // is the minimum requirement to say there is intrinsic info to act on.
1072   explicit operator bool() const { return hasGoverningPredicate(); }
1073 
1074   //
1075   // Properties relating to the governing predicate.
1076   //
1077 
1078   bool hasGoverningPredicate() const {
1079     return GoverningPredicateIdx != std::numeric_limits<unsigned>::max();
1080   }
1081 
1082   unsigned getGoverningPredicateOperandIdx() const {
1083     assert(hasGoverningPredicate() && "Property not set!");
1084     return GoverningPredicateIdx;
1085   }
1086 
1087   SVEIntrinsicInfo &setGoverningPredicateOperandIdx(unsigned Index) {
1088     assert(!hasGoverningPredicate() && "Cannot set property twice!");
1089     GoverningPredicateIdx = Index;
1090     return *this;
1091   }
1092 
1093   //
1094   // Properties relating to operations the intrinsic could be transformed into.
1095   // NOTE: This does not mean such a transformation is always possible, but the
1096   // knowledge makes it possible to reuse existing optimisations without needing
1097   // to embed specific handling for each intrinsic. For example, instruction
1098   // simplification can be used to optimise an intrinsic's active lanes.
1099   //
1100 
1101   bool hasMatchingUndefIntrinsic() const {
1102     return UndefIntrinsic != Intrinsic::not_intrinsic;
1103   }
1104 
1105   Intrinsic::ID getMatchingUndefIntrinsic() const {
1106     assert(hasMatchingUndefIntrinsic() && "Property not set!");
1107     return UndefIntrinsic;
1108   }
1109 
1110   SVEIntrinsicInfo &setMatchingUndefIntrinsic(Intrinsic::ID IID) {
1111     assert(!hasMatchingUndefIntrinsic() && "Cannot set property twice!");
1112     UndefIntrinsic = IID;
1113     return *this;
1114   }
1115 
1116   bool hasMatchingIROpode() const { return IROpcode != 0; }
1117 
1118   unsigned getMatchingIROpode() const {
1119     assert(hasMatchingIROpode() && "Property not set!");
1120     return IROpcode;
1121   }
1122 
1123   SVEIntrinsicInfo &setMatchingIROpcode(unsigned Opcode) {
1124     assert(!hasMatchingIROpode() && "Cannot set property twice!");
1125     IROpcode = Opcode;
1126     return *this;
1127   }
1128 
1129   //
1130   // Properties relating to the result of inactive lanes.
1131   //
1132 
1133   bool inactiveLanesTakenFromOperand() const {
1134     return ResultLanes == InactiveLanesTakenFromOperand;
1135   }
1136 
1137   unsigned getOperandIdxInactiveLanesTakenFrom() const {
1138     assert(inactiveLanesTakenFromOperand() && "Property not set!");
1139     return OperandIdxForInactiveLanes;
1140   }
1141 
1142   SVEIntrinsicInfo &setOperandIdxInactiveLanesTakenFrom(unsigned Index) {
1143     assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1144     ResultLanes = InactiveLanesTakenFromOperand;
1145     OperandIdxForInactiveLanes = Index;
1146     return *this;
1147   }
1148 
1149   bool inactiveLanesAreNotDefined() const {
1150     return ResultLanes == InactiveLanesAreNotDefined;
1151   }
1152 
1153   SVEIntrinsicInfo &setInactiveLanesAreNotDefined() {
1154     assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1155     ResultLanes = InactiveLanesAreNotDefined;
1156     return *this;
1157   }
1158 
1159   bool inactiveLanesAreUnused() const {
1160     return ResultLanes == InactiveLanesAreUnused;
1161   }
1162 
1163   SVEIntrinsicInfo &setInactiveLanesAreUnused() {
1164     assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1165     ResultLanes = InactiveLanesAreUnused;
1166     return *this;
1167   }
1168 
1169   // NOTE: Whilst not limited to only inactive lanes, the common use case is:
1170   // inactiveLanesAreZeroed =
1171   //     resultIsZeroInitialized() && inactiveLanesAreUnused()
1172   bool resultIsZeroInitialized() const { return ResultIsZeroInitialized; }
1173 
1174   SVEIntrinsicInfo &setResultIsZeroInitialized() {
1175     ResultIsZeroInitialized = true;
1176     return *this;
1177   }
1178 
1179   //
1180   // The first operand of unary merging operations is typically only used to
1181   // set the result for inactive lanes. Knowing this allows us to deadcode the
1182   // operand when we can prove there are no inactive lanes.
1183   //
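  // For example, in "llvm.aarch64.sve.neg inactive, pg, op" (see
  // defaultMergingUnaryOp above), when pg can be shown to be all active the
  // "inactive" operand is never read and can be dropped.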
1184 
1185   bool hasOperandWithNoActiveLanes() const {
1186     return OperandIdxWithNoActiveLanes != std::numeric_limits<unsigned>::max();
1187   }
1188 
1189   unsigned getOperandIdxWithNoActiveLanes() const {
1190     assert(hasOperandWithNoActiveLanes() && "Property not set!");
1191     return OperandIdxWithNoActiveLanes;
1192   }
1193 
1194   SVEIntrinsicInfo &setOperandIdxWithNoActiveLanes(unsigned Index) {
1195     assert(!hasOperandWithNoActiveLanes() && "Cannot set property twice!");
1196     OperandIdxWithNoActiveLanes = Index;
1197     return *this;
1198   }
1199 
1200 private:
1201   unsigned GoverningPredicateIdx = std::numeric_limits<unsigned>::max();
1202 
1203   Intrinsic::ID UndefIntrinsic = Intrinsic::not_intrinsic;
1204   unsigned IROpcode = 0;
1205 
1206   enum PredicationStyle {
1207     Uninitialized,
1208     InactiveLanesTakenFromOperand,
1209     InactiveLanesAreNotDefined,
1210     InactiveLanesAreUnused
1211   } ResultLanes = Uninitialized;
1212 
1213   bool ResultIsZeroInitialized = false;
1214   unsigned OperandIdxForInactiveLanes = std::numeric_limits<unsigned>::max();
1215   unsigned OperandIdxWithNoActiveLanes = std::numeric_limits<unsigned>::max();
1216 };
1217 
1218 static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) {
1219   // Some SVE intrinsics do not use scalable vector types, but since they are
1220   // not relevant from an SVEIntrinsicInfo perspective, they are also ignored.
1221   if (!isa<ScalableVectorType>(II.getType()) &&
1222       all_of(II.args(), [&](const Value *V) {
1223         return !isa<ScalableVectorType>(V->getType());
1224       }))
1225     return SVEIntrinsicInfo();
1226 
1227   Intrinsic::ID IID = II.getIntrinsicID();
1228   switch (IID) {
1229   default:
1230     break;
1231   case Intrinsic::aarch64_sve_fcvt_bf16f32_v2:
1232   case Intrinsic::aarch64_sve_fcvt_f16f32:
1233   case Intrinsic::aarch64_sve_fcvt_f16f64:
1234   case Intrinsic::aarch64_sve_fcvt_f32f16:
1235   case Intrinsic::aarch64_sve_fcvt_f32f64:
1236   case Intrinsic::aarch64_sve_fcvt_f64f16:
1237   case Intrinsic::aarch64_sve_fcvt_f64f32:
1238   case Intrinsic::aarch64_sve_fcvtlt_f32f16:
1239   case Intrinsic::aarch64_sve_fcvtlt_f64f32:
1240   case Intrinsic::aarch64_sve_fcvtx_f32f64:
1241   case Intrinsic::aarch64_sve_fcvtzs:
1242   case Intrinsic::aarch64_sve_fcvtzs_i32f16:
1243   case Intrinsic::aarch64_sve_fcvtzs_i32f64:
1244   case Intrinsic::aarch64_sve_fcvtzs_i64f16:
1245   case Intrinsic::aarch64_sve_fcvtzs_i64f32:
1246   case Intrinsic::aarch64_sve_fcvtzu:
1247   case Intrinsic::aarch64_sve_fcvtzu_i32f16:
1248   case Intrinsic::aarch64_sve_fcvtzu_i32f64:
1249   case Intrinsic::aarch64_sve_fcvtzu_i64f16:
1250   case Intrinsic::aarch64_sve_fcvtzu_i64f32:
1251   case Intrinsic::aarch64_sve_scvtf:
1252   case Intrinsic::aarch64_sve_scvtf_f16i32:
1253   case Intrinsic::aarch64_sve_scvtf_f16i64:
1254   case Intrinsic::aarch64_sve_scvtf_f32i64:
1255   case Intrinsic::aarch64_sve_scvtf_f64i32:
1256   case Intrinsic::aarch64_sve_ucvtf:
1257   case Intrinsic::aarch64_sve_ucvtf_f16i32:
1258   case Intrinsic::aarch64_sve_ucvtf_f16i64:
1259   case Intrinsic::aarch64_sve_ucvtf_f32i64:
1260   case Intrinsic::aarch64_sve_ucvtf_f64i32:
1261     return SVEIntrinsicInfo::defaultMergingUnaryOp();
1262 
1263   case Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2:
1264   case Intrinsic::aarch64_sve_fcvtnt_f16f32:
1265   case Intrinsic::aarch64_sve_fcvtnt_f32f64:
1266   case Intrinsic::aarch64_sve_fcvtxnt_f32f64:
1267     return SVEIntrinsicInfo::defaultMergingUnaryNarrowingTopOp();
1268 
1269   case Intrinsic::aarch64_sve_fabd:
1270     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fabd_u);
1271   case Intrinsic::aarch64_sve_fadd:
1272     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fadd_u)
1273         .setMatchingIROpcode(Instruction::FAdd);
1274   case Intrinsic::aarch64_sve_fdiv:
1275     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fdiv_u)
1276         .setMatchingIROpcode(Instruction::FDiv);
1277   case Intrinsic::aarch64_sve_fmax:
1278     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmax_u);
1279   case Intrinsic::aarch64_sve_fmaxnm:
1280     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmaxnm_u);
1281   case Intrinsic::aarch64_sve_fmin:
1282     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmin_u);
1283   case Intrinsic::aarch64_sve_fminnm:
1284     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fminnm_u);
1285   case Intrinsic::aarch64_sve_fmla:
1286     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmla_u);
1287   case Intrinsic::aarch64_sve_fmls:
1288     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmls_u);
1289   case Intrinsic::aarch64_sve_fmul:
1290     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmul_u)
1291         .setMatchingIROpcode(Instruction::FMul);
1292   case Intrinsic::aarch64_sve_fmulx:
1293     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmulx_u);
1294   case Intrinsic::aarch64_sve_fnmla:
1295     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fnmla_u);
1296   case Intrinsic::aarch64_sve_fnmls:
1297     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fnmls_u);
1298   case Intrinsic::aarch64_sve_fsub:
1299     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fsub_u)
1300         .setMatchingIROpcode(Instruction::FSub);
1301   case Intrinsic::aarch64_sve_add:
1302     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_add_u)
1303         .setMatchingIROpcode(Instruction::Add);
1304   case Intrinsic::aarch64_sve_mla:
1305     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mla_u);
1306   case Intrinsic::aarch64_sve_mls:
1307     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mls_u);
1308   case Intrinsic::aarch64_sve_mul:
1309     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mul_u)
1310         .setMatchingIROpcode(Instruction::Mul);
1311   case Intrinsic::aarch64_sve_sabd:
1312     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sabd_u);
1313   case Intrinsic::aarch64_sve_sdiv:
1314     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sdiv_u)
1315         .setMatchingIROpcode(Instruction::SDiv);
1316   case Intrinsic::aarch64_sve_smax:
1317     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smax_u);
1318   case Intrinsic::aarch64_sve_smin:
1319     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smin_u);
1320   case Intrinsic::aarch64_sve_smulh:
1321     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smulh_u);
1322   case Intrinsic::aarch64_sve_sub:
1323     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sub_u)
1324         .setMatchingIROpcode(Instruction::Sub);
1325   case Intrinsic::aarch64_sve_uabd:
1326     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uabd_u);
1327   case Intrinsic::aarch64_sve_udiv:
1328     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_udiv_u)
1329         .setMatchingIROpcode(Instruction::UDiv);
1330   case Intrinsic::aarch64_sve_umax:
1331     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umax_u);
1332   case Intrinsic::aarch64_sve_umin:
1333     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umin_u);
1334   case Intrinsic::aarch64_sve_umulh:
1335     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umulh_u);
1336   case Intrinsic::aarch64_sve_asr:
1337     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_asr_u)
1338         .setMatchingIROpcode(Instruction::AShr);
1339   case Intrinsic::aarch64_sve_lsl:
1340     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_lsl_u)
1341         .setMatchingIROpcode(Instruction::Shl);
1342   case Intrinsic::aarch64_sve_lsr:
1343     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_lsr_u)
1344         .setMatchingIROpcode(Instruction::LShr);
1345   case Intrinsic::aarch64_sve_and:
1346     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_and_u)
1347         .setMatchingIROpcode(Instruction::And);
1348   case Intrinsic::aarch64_sve_bic:
1349     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_bic_u);
1350   case Intrinsic::aarch64_sve_eor:
1351     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_eor_u)
1352         .setMatchingIROpcode(Instruction::Xor);
1353   case Intrinsic::aarch64_sve_orr:
1354     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_orr_u)
1355         .setMatchingIROpcode(Instruction::Or);
1356   case Intrinsic::aarch64_sve_sqsub:
1357     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sqsub_u);
1358   case Intrinsic::aarch64_sve_uqsub:
1359     return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uqsub_u);
1360 
1361   case Intrinsic::aarch64_sve_add_u:
1362     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1363         Instruction::Add);
1364   case Intrinsic::aarch64_sve_and_u:
1365     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1366         Instruction::And);
1367   case Intrinsic::aarch64_sve_asr_u:
1368     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1369         Instruction::AShr);
1370   case Intrinsic::aarch64_sve_eor_u:
1371     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1372         Instruction::Xor);
1373   case Intrinsic::aarch64_sve_fadd_u:
1374     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1375         Instruction::FAdd);
1376   case Intrinsic::aarch64_sve_fdiv_u:
1377     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1378         Instruction::FDiv);
1379   case Intrinsic::aarch64_sve_fmul_u:
1380     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1381         Instruction::FMul);
1382   case Intrinsic::aarch64_sve_fsub_u:
1383     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1384         Instruction::FSub);
1385   case Intrinsic::aarch64_sve_lsl_u:
1386     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1387         Instruction::Shl);
1388   case Intrinsic::aarch64_sve_lsr_u:
1389     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1390         Instruction::LShr);
1391   case Intrinsic::aarch64_sve_mul_u:
1392     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1393         Instruction::Mul);
1394   case Intrinsic::aarch64_sve_orr_u:
1395     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1396         Instruction::Or);
1397   case Intrinsic::aarch64_sve_sdiv_u:
1398     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1399         Instruction::SDiv);
1400   case Intrinsic::aarch64_sve_sub_u:
1401     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1402         Instruction::Sub);
1403   case Intrinsic::aarch64_sve_udiv_u:
1404     return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
1405         Instruction::UDiv);
1406 
1407   case Intrinsic::aarch64_sve_addqv:
1408   case Intrinsic::aarch64_sve_and_z:
1409   case Intrinsic::aarch64_sve_bic_z:
1410   case Intrinsic::aarch64_sve_brka_z:
1411   case Intrinsic::aarch64_sve_brkb_z:
1412   case Intrinsic::aarch64_sve_brkn_z:
1413   case Intrinsic::aarch64_sve_brkpa_z:
1414   case Intrinsic::aarch64_sve_brkpb_z:
1415   case Intrinsic::aarch64_sve_cntp:
1416   case Intrinsic::aarch64_sve_compact:
1417   case Intrinsic::aarch64_sve_eor_z:
1418   case Intrinsic::aarch64_sve_eorv:
1419   case Intrinsic::aarch64_sve_eorqv:
1420   case Intrinsic::aarch64_sve_nand_z:
1421   case Intrinsic::aarch64_sve_nor_z:
1422   case Intrinsic::aarch64_sve_orn_z:
1423   case Intrinsic::aarch64_sve_orr_z:
1424   case Intrinsic::aarch64_sve_orv:
1425   case Intrinsic::aarch64_sve_orqv:
1426   case Intrinsic::aarch64_sve_pnext:
1427   case Intrinsic::aarch64_sve_rdffr_z:
1428   case Intrinsic::aarch64_sve_saddv:
1429   case Intrinsic::aarch64_sve_uaddv:
1430   case Intrinsic::aarch64_sve_umaxv:
1431   case Intrinsic::aarch64_sve_umaxqv:
1432   case Intrinsic::aarch64_sve_cmpeq:
1433   case Intrinsic::aarch64_sve_cmpeq_wide:
1434   case Intrinsic::aarch64_sve_cmpge:
1435   case Intrinsic::aarch64_sve_cmpge_wide:
1436   case Intrinsic::aarch64_sve_cmpgt:
1437   case Intrinsic::aarch64_sve_cmpgt_wide:
1438   case Intrinsic::aarch64_sve_cmphi:
1439   case Intrinsic::aarch64_sve_cmphi_wide:
1440   case Intrinsic::aarch64_sve_cmphs:
1441   case Intrinsic::aarch64_sve_cmphs_wide:
1442   case Intrinsic::aarch64_sve_cmple_wide:
1443   case Intrinsic::aarch64_sve_cmplo_wide:
1444   case Intrinsic::aarch64_sve_cmpls_wide:
1445   case Intrinsic::aarch64_sve_cmplt_wide:
1446   case Intrinsic::aarch64_sve_cmpne:
1447   case Intrinsic::aarch64_sve_cmpne_wide:
1448   case Intrinsic::aarch64_sve_facge:
1449   case Intrinsic::aarch64_sve_facgt:
1450   case Intrinsic::aarch64_sve_fcmpeq:
1451   case Intrinsic::aarch64_sve_fcmpge:
1452   case Intrinsic::aarch64_sve_fcmpgt:
1453   case Intrinsic::aarch64_sve_fcmpne:
1454   case Intrinsic::aarch64_sve_fcmpuo:
1455   case Intrinsic::aarch64_sve_ld1:
1456   case Intrinsic::aarch64_sve_ld1_gather:
1457   case Intrinsic::aarch64_sve_ld1_gather_index:
1458   case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
1459   case Intrinsic::aarch64_sve_ld1_gather_sxtw:
1460   case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
1461   case Intrinsic::aarch64_sve_ld1_gather_uxtw:
1462   case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
1463   case Intrinsic::aarch64_sve_ld1q_gather_index:
1464   case Intrinsic::aarch64_sve_ld1q_gather_scalar_offset:
1465   case Intrinsic::aarch64_sve_ld1q_gather_vector_offset:
1466   case Intrinsic::aarch64_sve_ld1ro:
1467   case Intrinsic::aarch64_sve_ld1rq:
1468   case Intrinsic::aarch64_sve_ld1udq:
1469   case Intrinsic::aarch64_sve_ld1uwq:
1470   case Intrinsic::aarch64_sve_ld2_sret:
1471   case Intrinsic::aarch64_sve_ld2q_sret:
1472   case Intrinsic::aarch64_sve_ld3_sret:
1473   case Intrinsic::aarch64_sve_ld3q_sret:
1474   case Intrinsic::aarch64_sve_ld4_sret:
1475   case Intrinsic::aarch64_sve_ld4q_sret:
1476   case Intrinsic::aarch64_sve_ldff1:
1477   case Intrinsic::aarch64_sve_ldff1_gather:
1478   case Intrinsic::aarch64_sve_ldff1_gather_index:
1479   case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
1480   case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
1481   case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
1482   case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
1483   case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
1484   case Intrinsic::aarch64_sve_ldnf1:
1485   case Intrinsic::aarch64_sve_ldnt1:
1486   case Intrinsic::aarch64_sve_ldnt1_gather:
1487   case Intrinsic::aarch64_sve_ldnt1_gather_index:
1488   case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
1489   case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
1490     return SVEIntrinsicInfo::defaultZeroingOp();
1491 
1492   case Intrinsic::aarch64_sve_prf:
1493   case Intrinsic::aarch64_sve_prfb_gather_index:
1494   case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
1495   case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
1496   case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
1497   case Intrinsic::aarch64_sve_prfd_gather_index:
1498   case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
1499   case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
1500   case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
1501   case Intrinsic::aarch64_sve_prfh_gather_index:
1502   case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
1503   case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
1504   case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
1505   case Intrinsic::aarch64_sve_prfw_gather_index:
1506   case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
1507   case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
1508   case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
1509     return SVEIntrinsicInfo::defaultVoidOp(0);
1510 
1511   case Intrinsic::aarch64_sve_st1_scatter:
1512   case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
1513   case Intrinsic::aarch64_sve_st1_scatter_sxtw:
1514   case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
1515   case Intrinsic::aarch64_sve_st1_scatter_uxtw:
1516   case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
1517   case Intrinsic::aarch64_sve_st1dq:
1518   case Intrinsic::aarch64_sve_st1q_scatter_index:
1519   case Intrinsic::aarch64_sve_st1q_scatter_scalar_offset:
1520   case Intrinsic::aarch64_sve_st1q_scatter_vector_offset:
1521   case Intrinsic::aarch64_sve_st1wq:
1522   case Intrinsic::aarch64_sve_stnt1:
1523   case Intrinsic::aarch64_sve_stnt1_scatter:
1524   case Intrinsic::aarch64_sve_stnt1_scatter_index:
1525   case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
1526   case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
1527     return SVEIntrinsicInfo::defaultVoidOp(1);
1528   case Intrinsic::aarch64_sve_st2:
1529   case Intrinsic::aarch64_sve_st2q:
1530     return SVEIntrinsicInfo::defaultVoidOp(2);
1531   case Intrinsic::aarch64_sve_st3:
1532   case Intrinsic::aarch64_sve_st3q:
1533     return SVEIntrinsicInfo::defaultVoidOp(3);
1534   case Intrinsic::aarch64_sve_st4:
1535   case Intrinsic::aarch64_sve_st4q:
1536     return SVEIntrinsicInfo::defaultVoidOp(4);
1537   }
1538 
1539   return SVEIntrinsicInfo();
1540 }
1541 
1542 static bool isAllActivePredicate(Value *Pred) {
1543   // Look through a convert.from.svbool(convert.to.svbool(...)) chain.
1544   Value *UncastedPred;
1545   if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1546                       m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1547                           m_Value(UncastedPred)))))
1548     // If the predicate has the same number of lanes as, or fewer than, the
1549     // uncasted predicate then we know the casting has no effect.
1550     if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1551         cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1552       Pred = UncastedPred;
1553   auto *C = dyn_cast<Constant>(Pred);
1554   return (C && C->isAllOnesValue());
1555 }
1556 
1557 // Simplify `V` by only considering the operations that affect active lanes.
1558 // This function should only return existing Values or newly created Constants.
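// For example (illustrative), a predicated constant splat such as
//   %v = sve.dup(%passthru, %pg, i32 7)
// can be treated as splat(7) here, because only the active lanes of %v are
// relevant to the caller's simplification.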
1559 static Value *stripInactiveLanes(Value *V, const Value *Pg) {
1560   auto *Dup = dyn_cast<IntrinsicInst>(V);
1561   if (Dup && Dup->getIntrinsicID() == Intrinsic::aarch64_sve_dup &&
1562       Dup->getOperand(1) == Pg && isa<Constant>(Dup->getOperand(2)))
1563     return ConstantVector::getSplat(
1564         cast<VectorType>(V->getType())->getElementCount(),
1565         cast<Constant>(Dup->getOperand(2)));
1566 
1567   return V;
1568 }
1569 
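// Try to simplify a predicated SVE binary operation through its matching IR
// opcode (e.g. sve.mul(pg, x, splat(1)) simplifies to x), reinstating the
// inactive-lane values with a select when the intrinsic requires them.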
1570 static std::optional<Instruction *>
1571 simplifySVEIntrinsicBinOp(InstCombiner &IC, IntrinsicInst &II,
1572                           const SVEIntrinsicInfo &IInfo) {
1573   const unsigned Opc = IInfo.getMatchingIROpode();
1574   assert(Instruction::isBinaryOp(Opc) && "Expected a binary operation!");
1575 
1576   Value *Pg = II.getOperand(0);
1577   Value *Op1 = II.getOperand(1);
1578   Value *Op2 = II.getOperand(2);
1579   const DataLayout &DL = II.getDataLayout();
1580 
1581   // Canonicalise constants to the RHS.
1582   if (Instruction::isCommutative(Opc) && IInfo.inactiveLanesAreNotDefined() &&
1583       isa<Constant>(Op1) && !isa<Constant>(Op2)) {
1584     IC.replaceOperand(II, 1, Op2);
1585     IC.replaceOperand(II, 2, Op1);
1586     return &II;
1587   }
1588 
1589   // Only active lanes matter when simplifying the operation.
1590   Op1 = stripInactiveLanes(Op1, Pg);
1591   Op2 = stripInactiveLanes(Op2, Pg);
1592 
1593   Value *SimpleII;
1594   if (auto FII = dyn_cast<FPMathOperator>(&II))
1595     SimpleII = simplifyBinOp(Opc, Op1, Op2, FII->getFastMathFlags(), DL);
1596   else
1597     SimpleII = simplifyBinOp(Opc, Op1, Op2, DL);
1598 
1599   // An SVE intrinsic's result is always defined. However, this is not the case
1600   // for its equivalent IR instruction (e.g. when shifting by an amount more
1601   // than the data's bitwidth). Simplifications to an undefined result must be
1602   // ignored to preserve the intrinsic's expected behaviour.
1603   if (!SimpleII || isa<UndefValue>(SimpleII))
1604     return std::nullopt;
1605 
1606   if (IInfo.inactiveLanesAreNotDefined())
1607     return IC.replaceInstUsesWith(II, SimpleII);
1608 
1609   Value *Inactive = II.getOperand(IInfo.getOperandIdxInactiveLanesTakenFrom());
1610 
1611   // The intrinsic does nothing (e.g. sve.mul(pg, A, 1.0)).
1612   if (SimpleII == Inactive)
1613     return IC.replaceInstUsesWith(II, SimpleII);
1614 
1615   // Inactive lanes must be preserved.
1616   SimpleII = IC.Builder.CreateSelect(Pg, SimpleII, Inactive);
1617   return IC.replaceInstUsesWith(II, SimpleII);
1618 }
1619 
1620 // Use SVE intrinsic info to eliminate redundant operands and/or canonicalise
1621 // to operations with less strict inactive lane requirements.
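// For example (illustrative), given an all-active predicate:
//   sve.add(pg, a, b)   can be rewritten as   sve.add_u(pg, a, b)
// because with no inactive lanes the merging behaviour is irrelevant.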
1622 static std::optional<Instruction *>
1623 simplifySVEIntrinsic(InstCombiner &IC, IntrinsicInst &II,
1624                      const SVEIntrinsicInfo &IInfo) {
1625   if (!IInfo.hasGoverningPredicate())
1626     return std::nullopt;
1627 
1628   auto *OpPredicate = II.getOperand(IInfo.getGoverningPredicateOperandIdx());
1629 
1630   // If there are no active lanes.
1631   if (match(OpPredicate, m_ZeroInt())) {
1632     if (IInfo.inactiveLanesTakenFromOperand())
1633       return IC.replaceInstUsesWith(
1634           II, II.getOperand(IInfo.getOperandIdxInactiveLanesTakenFrom()));
1635 
1636     if (IInfo.inactiveLanesAreUnused()) {
1637       if (IInfo.resultIsZeroInitialized())
1638         IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1639 
1640       return IC.eraseInstFromFunction(II);
1641     }
1642   }
1643 
1644   // If there are no inactive lanes.
1645   if (isAllActivePredicate(OpPredicate)) {
1646     if (IInfo.hasOperandWithNoActiveLanes()) {
1647       unsigned OpIdx = IInfo.getOperandIdxWithNoActiveLanes();
1648       if (!isa<UndefValue>(II.getOperand(OpIdx)))
1649         return IC.replaceOperand(II, OpIdx, UndefValue::get(II.getType()));
1650     }
1651 
1652     if (IInfo.hasMatchingUndefIntrinsic()) {
1653       auto *NewDecl = Intrinsic::getOrInsertDeclaration(
1654           II.getModule(), IInfo.getMatchingUndefIntrinsic(), {II.getType()});
1655       II.setCalledFunction(NewDecl);
1656       return &II;
1657     }
1658   }
1659 
1660   // Operation specific simplifications.
1661   if (IInfo.hasMatchingIROpode() &&
1662       Instruction::isBinaryOp(IInfo.getMatchingIROpode()))
1663     return simplifySVEIntrinsicBinOp(IC, II, IInfo);
1664 
1665   return std::nullopt;
1666 }
1667 
1668 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
1669 // => (binop (pred) (from_svbool _) (from_svbool _))
1670 //
1671 // The above transformation eliminates a `to_svbool` in the predicate
1672 // operand of bitwise operation `binop` by narrowing the vector width of
1673 // the operation. For example, it would convert a `<vscale x 16 x i1>
1674 // and` into a `<vscale x 4 x i1> and`. This is profitable because
1675 // to_svbool must zero the new lanes during widening, whereas
1676 // from_svbool is free.
1677 static std::optional<Instruction *>
1678 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
1679   auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
1680   if (!BinOp)
1681     return std::nullopt;
1682 
1683   auto IntrinsicID = BinOp->getIntrinsicID();
1684   switch (IntrinsicID) {
1685   case Intrinsic::aarch64_sve_and_z:
1686   case Intrinsic::aarch64_sve_bic_z:
1687   case Intrinsic::aarch64_sve_eor_z:
1688   case Intrinsic::aarch64_sve_nand_z:
1689   case Intrinsic::aarch64_sve_nor_z:
1690   case Intrinsic::aarch64_sve_orn_z:
1691   case Intrinsic::aarch64_sve_orr_z:
1692     break;
1693   default:
1694     return std::nullopt;
1695   }
1696 
1697   auto BinOpPred = BinOp->getOperand(0);
1698   auto BinOpOp1 = BinOp->getOperand(1);
1699   auto BinOpOp2 = BinOp->getOperand(2);
1700 
1701   auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
1702   if (!PredIntr ||
1703       PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
1704     return std::nullopt;
1705 
1706   auto PredOp = PredIntr->getOperand(0);
1707   auto PredOpTy = cast<VectorType>(PredOp->getType());
1708   if (PredOpTy != II.getType())
1709     return std::nullopt;
1710 
1711   SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
1712   auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
1713       Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
1714   NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1715   if (BinOpOp1 == BinOpOp2)
1716     NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1717   else
1718     NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
1719         Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
1720 
1721   auto NarrowedBinOp =
1722       IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
1723   return IC.replaceInstUsesWith(II, NarrowedBinOp);
1724 }
1725 
1726 static std::optional<Instruction *>
1727 instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
1728   // If the reinterpret instruction operand is a PHI Node
1729   if (isa<PHINode>(II.getArgOperand(0)))
1730     return processPhiNode(IC, II);
1731 
1732   if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
1733     return BinOpCombine;
1734 
1735   // Ignore converts to/from svcount_t.
1736   if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
1737       isa<TargetExtType>(II.getType()))
1738     return std::nullopt;
1739 
1740   SmallVector<Instruction *, 32> CandidatesForRemoval;
1741   Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;
1742 
1743   const auto *IVTy = cast<VectorType>(II.getType());
1744 
1745   // Walk the chain of conversions.
1746   while (Cursor) {
1747     // If the type of the cursor has fewer lanes than the final result, zeroing
1748     // must take place, which breaks the equivalence chain.
1749     const auto *CursorVTy = cast<VectorType>(Cursor->getType());
1750     if (CursorVTy->getElementCount().getKnownMinValue() <
1751         IVTy->getElementCount().getKnownMinValue())
1752       break;
1753 
1754     // If the cursor has the same type as I, it is a viable replacement.
1755     if (Cursor->getType() == IVTy)
1756       EarliestReplacement = Cursor;
1757 
1758     auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);
1759 
1760     // If this is not an SVE conversion intrinsic, this is the end of the chain.
1761     if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
1762                                   Intrinsic::aarch64_sve_convert_to_svbool ||
1763                               IntrinsicCursor->getIntrinsicID() ==
1764                                   Intrinsic::aarch64_sve_convert_from_svbool))
1765       break;
1766 
1767     CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
1768     Cursor = IntrinsicCursor->getOperand(0);
1769   }
1770 
1771   // If no viable replacement in the conversion chain was found, there is
1772   // nothing to do.
1773   if (!EarliestReplacement)
1774     return std::nullopt;
1775 
1776   return IC.replaceInstUsesWith(II, EarliestReplacement);
1777 }
1778 
1779 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
1780                                                       IntrinsicInst &II) {
1781   // svsel(ptrue, x, y) => x
1782   auto *OpPredicate = II.getOperand(0);
1783   if (isAllActivePredicate(OpPredicate))
1784     return IC.replaceInstUsesWith(II, II.getOperand(1));
1785 
1786   auto Select =
1787       IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
1788   return IC.replaceInstUsesWith(II, Select);
1789 }
1790 
1791 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
1792                                                       IntrinsicInst &II) {
1793   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
1794   if (!Pg)
1795     return std::nullopt;
1796 
1797   if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1798     return std::nullopt;
1799 
1800   const auto PTruePattern =
1801       cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
1802   if (PTruePattern != AArch64SVEPredPattern::vl1)
1803     return std::nullopt;
1804 
1805   // The intrinsic is inserting into lane zero so use an insert instead.
1806   auto *IdxTy = Type::getInt64Ty(II.getContext());
1807   auto *Insert = InsertElementInst::Create(
1808       II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
1809   Insert->insertBefore(II.getIterator());
1810   Insert->takeName(&II);
1811 
1812   return IC.replaceInstUsesWith(II, Insert);
1813 }
1814 
1815 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
1816                                                        IntrinsicInst &II) {
1817   // Replace DupX with a regular IR splat.
1818   auto *RetTy = cast<ScalableVectorType>(II.getType());
1819   Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
1820                                               II.getArgOperand(0));
1821   Splat->takeName(&II);
1822   return IC.replaceInstUsesWith(II, Splat);
1823 }
1824 
1825 static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
1826                                                         IntrinsicInst &II) {
1827   LLVMContext &Ctx = II.getContext();
1828 
1829   if (!isAllActivePredicate(II.getArgOperand(0)))
1830     return std::nullopt;
1831 
1832   // Check that we have a compare of zero..
1833   auto *SplatValue =
1834       dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
1835   if (!SplatValue || !SplatValue->isZero())
1836     return std::nullopt;
1837 
1838   // ..against a dupq
1839   auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
1840   if (!DupQLane ||
1841       DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
1842     return std::nullopt;
1843 
1844   // Where the dupq is a lane 0 replicate of a vector insert
1845   auto *DupQLaneIdx = dyn_cast<ConstantInt>(DupQLane->getArgOperand(1));
1846   if (!DupQLaneIdx || !DupQLaneIdx->isZero())
1847     return std::nullopt;
1848 
1849   auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
1850   if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
1851     return std::nullopt;
1852 
1853   // Where the vector insert is a fixed constant vector insert into undef at
1854   // index zero
1855   if (!isa<UndefValue>(VecIns->getArgOperand(0)))
1856     return std::nullopt;
1857 
1858   if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
1859     return std::nullopt;
1860 
1861   auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
1862   if (!ConstVec)
1863     return std::nullopt;
1864 
1865   auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
1866   auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
1867   if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
1868     return std::nullopt;
1869 
1870   unsigned NumElts = VecTy->getNumElements();
1871   unsigned PredicateBits = 0;
1872 
1873   // Expand intrinsic operands to a 16-bit byte level predicate
1874   for (unsigned I = 0; I < NumElts; ++I) {
1875     auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
1876     if (!Arg)
1877       return std::nullopt;
1878     if (!Arg->isZero())
1879       PredicateBits |= 1 << (I * (16 / NumElts));
1880   }
1881 
1882   // If all bits are zero bail early with an empty predicate
1883   if (PredicateBits == 0) {
1884     auto *PFalse = Constant::getNullValue(II.getType());
1885     PFalse->takeName(&II);
1886     return IC.replaceInstUsesWith(II, PFalse);
1887   }
1888 
1889   // Calculate largest predicate type used (where byte predicate is largest)
1890   unsigned Mask = 8;
1891   for (unsigned I = 0; I < 16; ++I)
1892     if ((PredicateBits & (1 << I)) != 0)
1893       Mask |= (I % 8);
1894 
1895   unsigned PredSize = Mask & -Mask;
1896   auto *PredType = ScalableVectorType::get(
1897       Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
1898 
1899   // Ensure all relevant bits are set
1900   for (unsigned I = 0; I < 16; I += PredSize)
1901     if ((PredicateBits & (1 << I)) == 0)
1902       return std::nullopt;
1903 
1904   auto *PTruePat =
1905       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
1906   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
1907                                            {PredType}, {PTruePat});
1908   auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
1909       Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
1910   auto *ConvertFromSVBool =
1911       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
1912                                  {II.getType()}, {ConvertToSVBool});
1913 
1914   ConvertFromSVBool->takeName(&II);
1915   return IC.replaceInstUsesWith(II, ConvertFromSVBool);
1916 }
1917 
1918 static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
1919                                                        IntrinsicInst &II) {
1920   Value *Pg = II.getArgOperand(0);
1921   Value *Vec = II.getArgOperand(1);
1922   auto IntrinsicID = II.getIntrinsicID();
1923   bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;
1924 
1925   // lastX(splat(X)) --> X
1926   if (auto *SplatVal = getSplatValue(Vec))
1927     return IC.replaceInstUsesWith(II, SplatVal);
1928 
1929   // If x and/or y is a splat value then:
1930   // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
1931   Value *LHS, *RHS;
1932   if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
1933     if (isSplatValue(LHS) || isSplatValue(RHS)) {
1934       auto *OldBinOp = cast<BinaryOperator>(Vec);
1935       auto OpC = OldBinOp->getOpcode();
1936       auto *NewLHS =
1937           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
1938       auto *NewRHS =
1939           IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
1940       auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
1941           OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), II.getIterator());
1942       return IC.replaceInstUsesWith(II, NewBinOp);
1943     }
1944   }
1945 
1946   auto *C = dyn_cast<Constant>(Pg);
1947   if (IsAfter && C && C->isNullValue()) {
1948     // The intrinsic is extracting lane 0 so use an extract instead.
1949     auto *IdxTy = Type::getInt64Ty(II.getContext());
1950     auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
1951     Extract->insertBefore(II.getIterator());
1952     Extract->takeName(&II);
1953     return IC.replaceInstUsesWith(II, Extract);
1954   }
1955 
1956   auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
1957   if (!IntrPG)
1958     return std::nullopt;
1959 
1960   if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1961     return std::nullopt;
1962 
1963   const auto PTruePattern =
1964       cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();
1965 
1966   // Can the intrinsic's predicate be converted to a known constant index?
1967   unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
1968   if (!MinNumElts)
1969     return std::nullopt;
1970 
1971   unsigned Idx = MinNumElts - 1;
1972   // Increment the index if extracting the element after the last active
1973   // predicate element.
1974   if (IsAfter)
1975     ++Idx;
1976 
1977   // Ignore extracts whose index is larger than the known minimum vector
1978   // length. NOTE: This is an artificial constraint where we prefer to
1979   // maintain what the user asked for until an alternative is proven faster.
1980   auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
1981   if (Idx >= PgVTy->getMinNumElements())
1982     return std::nullopt;
1983 
1984   // The intrinsic is extracting a fixed lane so use an extract instead.
1985   auto *IdxTy = Type::getInt64Ty(II.getContext());
1986   auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
1987   Extract->insertBefore(II.getIterator());
1988   Extract->takeName(&II);
1989   return IC.replaceInstUsesWith(II, Extract);
1990 }
1991 
1992 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1993                                                            IntrinsicInst &II) {
1994   // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1995   // integer variant across a variety of micro-architectures. Replace scalar
1996   // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1997   // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1998   // depending on the micro-architecture, but has been observed as generally
1999   // being faster, particularly when the CLAST[AB] op is a loop-carried
2000   // dependency.
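  // For example (a rough sketch of the rewrite below):
  //   %r = clasta.n(%pg, i32 %fallback, <vscale x 4 x i32> %v)
  // becomes:
  //   %fb = bitcast i32 %fallback to float
  //   %vf = bitcast <vscale x 4 x i32> %v to <vscale x 4 x float>
  //   %rf = clasta.n(%pg, float %fb, <vscale x 4 x float> %vf)
  //   %r  = bitcast float %rf to i32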
2001   Value *Pg = II.getArgOperand(0);
2002   Value *Fallback = II.getArgOperand(1);
2003   Value *Vec = II.getArgOperand(2);
2004   Type *Ty = II.getType();
2005 
2006   if (!Ty->isIntegerTy())
2007     return std::nullopt;
2008 
2009   Type *FPTy;
2010   switch (cast<IntegerType>(Ty)->getBitWidth()) {
2011   default:
2012     return std::nullopt;
2013   case 16:
2014     FPTy = IC.Builder.getHalfTy();
2015     break;
2016   case 32:
2017     FPTy = IC.Builder.getFloatTy();
2018     break;
2019   case 64:
2020     FPTy = IC.Builder.getDoubleTy();
2021     break;
2022   }
2023 
2024   Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
2025   auto *FPVTy = VectorType::get(
2026       FPTy, cast<VectorType>(Vec->getType())->getElementCount());
2027   Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
2028   auto *FPII = IC.Builder.CreateIntrinsic(
2029       II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
2030   Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
2031   return IC.replaceInstUsesWith(II, FPIItoInt);
2032 }
2033 
2034 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
2035                                                      IntrinsicInst &II) {
2036   LLVMContext &Ctx = II.getContext();
2037   // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
2038   // can work with RDFFR_PP for ptest elimination.
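  // That is (roughly): rdffr() --> rdffr.z(ptrue(all)).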
2039   auto *AllPat =
2040       ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
2041   auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
2042                                            {II.getType()}, {AllPat});
2043   auto *RDFFR =
2044       IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {PTrue});
2045   RDFFR->takeName(&II);
2046   return IC.replaceInstUsesWith(II, RDFFR);
2047 }
2048 
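// Fold the SVE element-count intrinsics (cntb/cnth/cntw/cntd): the 'all'
// pattern becomes a generic vscale-based element count, while other patterns
// fold to their fixed count when that count is guaranteed to be available
// (i.e. it does not exceed the minimum vector length).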
2049 static std::optional<Instruction *>
2050 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
2051   const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
2052 
2053   if (Pattern == AArch64SVEPredPattern::all) {
2054     Value *Cnt = IC.Builder.CreateElementCount(
2055         II.getType(), ElementCount::getScalable(NumElts));
2056     Cnt->takeName(&II);
2057     return IC.replaceInstUsesWith(II, Cnt);
2058   }
2059 
2060   unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
2061 
2062   return MinNumElts && NumElts >= MinNumElts
2063              ? std::optional<Instruction *>(IC.replaceInstUsesWith(
2064                    II, ConstantInt::get(II.getType(), MinNumElts)))
2065              : std::nullopt;
2066 }
2067 
2068 static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
2069                                                         IntrinsicInst &II) {
2070   Value *PgVal = II.getArgOperand(0);
2071   Value *OpVal = II.getArgOperand(1);
2072 
2073   // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
2074   // Later optimizations prefer this form.
2075   if (PgVal == OpVal &&
2076       (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
2077        II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
2078     Value *Ops[] = {PgVal, OpVal};
2079     Type *Tys[] = {PgVal->getType()};
2080 
2081     auto *PTest =
2082         IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
2083     PTest->takeName(&II);
2084 
2085     return IC.replaceInstUsesWith(II, PTest);
2086   }
2087 
2088   IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
2089   IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);
2090 
2091   if (!Pg || !Op)
2092     return std::nullopt;
2093 
2094   Intrinsic::ID OpIID = Op->getIntrinsicID();
2095 
2096   if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
2097       OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
2098       Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
2099     Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
2100     Type *Tys[] = {Pg->getArgOperand(0)->getType()};
2101 
2102     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
2103 
2104     PTest->takeName(&II);
2105     return IC.replaceInstUsesWith(II, PTest);
2106   }
2107 
2108   // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X)).
2109   // Later optimizations may rewrite sequence to use the flag-setting variant
2110   // of instruction X to remove PTEST.
2111   if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
2112       ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
2113        (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
2114        (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
2115        (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
2116        (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
2117        (OpIID == Intrinsic::aarch64_sve_and_z) ||
2118        (OpIID == Intrinsic::aarch64_sve_bic_z) ||
2119        (OpIID == Intrinsic::aarch64_sve_eor_z) ||
2120        (OpIID == Intrinsic::aarch64_sve_nand_z) ||
2121        (OpIID == Intrinsic::aarch64_sve_nor_z) ||
2122        (OpIID == Intrinsic::aarch64_sve_orn_z) ||
2123        (OpIID == Intrinsic::aarch64_sve_orr_z))) {
2124     Value *Ops[] = {Pg->getArgOperand(0), Pg};
2125     Type *Tys[] = {Pg->getType()};
2126 
2127     auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
2128     PTest->takeName(&II);
2129 
2130     return IC.replaceInstUsesWith(II, PTest);
2131   }
2132 
2133   return std::nullopt;
2134 }
2135 
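// Fuse a predicated multiply feeding a predicated add/sub into a single
// multiply-accumulate intrinsic, e.g. (roughly):
//   fadd(pg, a, fmul(pg, b, c)) --> fmla(pg, a, b, c)
// provided the multiply has no other uses and the fast-math flags agree.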
2136 template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
2137 static std::optional<Instruction *>
2138 instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
2139                                   bool MergeIntoAddendOp) {
2140   Value *P = II.getOperand(0);
2141   Value *MulOp0, *MulOp1, *AddendOp, *Mul;
2142   if (MergeIntoAddendOp) {
2143     AddendOp = II.getOperand(1);
2144     Mul = II.getOperand(2);
2145   } else {
2146     AddendOp = II.getOperand(2);
2147     Mul = II.getOperand(1);
2148   }
2149 
2150   if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
2151                                       m_Value(MulOp1))))
2152     return std::nullopt;
2153 
2154   if (!Mul->hasOneUse())
2155     return std::nullopt;
2156 
2157   Instruction *FMFSource = nullptr;
2158   if (II.getType()->isFPOrFPVectorTy()) {
2159     llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
2160     // Stop the combine when the flags on the inputs differ in case dropping
2161     // flags would lead to us missing out on more beneficial optimizations.
2162     if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
2163       return std::nullopt;
2164     if (!FAddFlags.allowContract())
2165       return std::nullopt;
2166     FMFSource = &II;
2167   }
2168 
2169   CallInst *Res;
2170   if (MergeIntoAddendOp)
2171     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
2172                                      {P, AddendOp, MulOp0, MulOp1}, FMFSource);
2173   else
2174     Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
2175                                      {P, MulOp0, MulOp1, AddendOp}, FMFSource);
2176 
2177   return IC.replaceInstUsesWith(II, Res);
2178 }
2179 
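// Lower sve.ld1: with an all-active predicate it becomes a plain load,
// otherwise a generic masked load with a zero passthru.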
2180 static std::optional<Instruction *>
2181 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
2182   Value *Pred = II.getOperand(0);
2183   Value *PtrOp = II.getOperand(1);
2184   Type *VecTy = II.getType();
2185 
2186   if (isAllActivePredicate(Pred)) {
2187     LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
2188     Load->copyMetadata(II);
2189     return IC.replaceInstUsesWith(II, Load);
2190   }
2191 
2192   CallInst *MaskedLoad =
2193       IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
2194                                   Pred, ConstantAggregateZero::get(VecTy));
2195   MaskedLoad->copyMetadata(II);
2196   return IC.replaceInstUsesWith(II, MaskedLoad);
2197 }
2198 
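// Lower sve.st1: with an all-active predicate it becomes a plain store,
// otherwise a generic masked store.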
2199 static std::optional<Instruction *>
2200 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
2201   Value *VecOp = II.getOperand(0);
2202   Value *Pred = II.getOperand(1);
2203   Value *PtrOp = II.getOperand(2);
2204 
2205   if (isAllActivePredicate(Pred)) {
2206     StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
2207     Store->copyMetadata(II);
2208     return IC.eraseInstFromFunction(II);
2209   }
2210 
2211   CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
2212       VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
2213   MaskedStore->copyMetadata(II);
2214   return IC.eraseInstFromFunction(II);
2215 }
2216 
2217 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
2218   switch (Intrinsic) {
2219   case Intrinsic::aarch64_sve_fmul_u:
2220     return Instruction::BinaryOps::FMul;
2221   case Intrinsic::aarch64_sve_fadd_u:
2222     return Instruction::BinaryOps::FAdd;
2223   case Intrinsic::aarch64_sve_fsub_u:
2224     return Instruction::BinaryOps::FSub;
2225   default:
2226     return Instruction::BinaryOpsEnd;
2227   }
2228 }
2229 
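// With an all-active predicate, lower the *_u floating-point intrinsics
// recognised by intrinsicIDToBinOpCode to plain IR binary operators,
// preserving the call's fast-math flags.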
2230 static std::optional<Instruction *>
2231 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
2232   // Bail due to missing support for ISD::STRICT_ scalable vector operations.
2233   if (II.isStrictFP())
2234     return std::nullopt;
2235 
2236   auto *OpPredicate = II.getOperand(0);
2237   auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
2238   if (BinOpCode == Instruction::BinaryOpsEnd ||
2239       !isAllActivePredicate(OpPredicate))
2240     return std::nullopt;
2241   auto BinOp = IC.Builder.CreateBinOpFMF(
2242       BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags());
2243   return IC.replaceInstUsesWith(II, BinOp);
2244 }
2245 
2246 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
2247                                                             IntrinsicInst &II) {
2248   if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2249                                                    Intrinsic::aarch64_sve_mla>(
2250           IC, II, true))
2251     return MLA;
2252   if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2253                                                    Intrinsic::aarch64_sve_mad>(
2254           IC, II, false))
2255     return MAD;
2256   return std::nullopt;
2257 }
2258 
2259 static std::optional<Instruction *>
2260 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
2261   if (auto FMLA =
2262           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2263                                             Intrinsic::aarch64_sve_fmla>(IC, II,
2264                                                                          true))
2265     return FMLA;
2266   if (auto FMAD =
2267           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2268                                             Intrinsic::aarch64_sve_fmad>(IC, II,
2269                                                                          false))
2270     return FMAD;
2271   if (auto FMLA =
2272           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2273                                             Intrinsic::aarch64_sve_fmla>(IC, II,
2274                                                                          true))
2275     return FMLA;
2276   return std::nullopt;
2277 }
2278 
2279 static std::optional<Instruction *>
2280 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
2281   if (auto FMLA =
2282           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2283                                             Intrinsic::aarch64_sve_fmla>(IC, II,
2284                                                                          true))
2285     return FMLA;
2286   if (auto FMAD =
2287           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2288                                             Intrinsic::aarch64_sve_fmad>(IC, II,
2289                                                                          false))
2290     return FMAD;
2291   if (auto FMLA_U =
2292           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2293                                             Intrinsic::aarch64_sve_fmla_u>(
2294               IC, II, true))
2295     return FMLA_U;
2296   return instCombineSVEVectorBinOp(IC, II);
2297 }
2298 
2299 static std::optional<Instruction *>
2300 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
2301   if (auto FMLS =
2302           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2303                                             Intrinsic::aarch64_sve_fmls>(IC, II,
2304                                                                          true))
2305     return FMLS;
2306   if (auto FMSB =
2307           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2308                                             Intrinsic::aarch64_sve_fnmsb>(
2309               IC, II, false))
2310     return FMSB;
2311   if (auto FMLS =
2312           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2313                                             Intrinsic::aarch64_sve_fmls>(IC, II,
2314                                                                          true))
2315     return FMLS;
2316   return std::nullopt;
2317 }
2318 
2319 static std::optional<Instruction *>
2320 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
2321   if (auto FMLS =
2322           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2323                                             Intrinsic::aarch64_sve_fmls>(IC, II,
2324                                                                          true))
2325     return FMLS;
2326   if (auto FMSB =
2327           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2328                                             Intrinsic::aarch64_sve_fnmsb>(
2329               IC, II, false))
2330     return FMSB;
2331   if (auto FMLS_U =
2332           instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2333                                             Intrinsic::aarch64_sve_fmls_u>(
2334               IC, II, true))
2335     return FMLS_U;
2336   return instCombineSVEVectorBinOp(IC, II);
2337 }
2338 
2339 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
2340                                                             IntrinsicInst &II) {
2341   if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2342                                                    Intrinsic::aarch64_sve_mls>(
2343           IC, II, true))
2344     return MLS;
2345   return std::nullopt;
2346 }
2347 
2348 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
2349                                                          IntrinsicInst &II) {
2350   Value *UnpackArg = II.getArgOperand(0);
2351   auto *RetTy = cast<ScalableVectorType>(II.getType());
2352   bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
2353                   II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
2354 
2355   // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
2356   // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
2357   if (auto *ScalarArg = getSplatValue(UnpackArg)) {
2358     ScalarArg =
2359         IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
2360     Value *NewVal =
2361         IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
2362     NewVal->takeName(&II);
2363     return IC.replaceInstUsesWith(II, NewVal);
2364   }
2365 
2366   return std::nullopt;
2367 }
2368 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
2369                                                       IntrinsicInst &II) {
2370   auto *OpVal = II.getOperand(0);
2371   auto *OpIndices = II.getOperand(1);
2372   VectorType *VTy = cast<VectorType>(II.getType());
2373 
2374   // Check whether OpIndices is a constant splat value < minimal element count
2375   // of result.
2376   auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
2377   if (!SplatValue ||
2378       SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
2379     return std::nullopt;
2380 
2381   // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
2382   // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
2383   auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
2384   auto *VectorSplat =
2385       IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
2386 
2387   VectorSplat->takeName(&II);
2388   return IC.replaceInstUsesWith(II, VectorSplat);
2389 }
2390 
2391 static std::optional<Instruction *> instCombineSVEUzp1(InstCombiner &IC,
2392                                                        IntrinsicInst &II) {
2393   Value *A, *B;
2394   Type *RetTy = II.getType();
2395   constexpr Intrinsic::ID FromSVB = Intrinsic::aarch64_sve_convert_from_svbool;
2396   constexpr Intrinsic::ID ToSVB = Intrinsic::aarch64_sve_convert_to_svbool;
2397 
2398   // uzp1(to_svbool(A), to_svbool(B)) --> <A, B>
2399   // uzp1(from_svbool(to_svbool(A)), from_svbool(to_svbool(B))) --> <A, B>
2400   if ((match(II.getArgOperand(0),
2401              m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(A)))) &&
2402        match(II.getArgOperand(1),
2403              m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(B))))) ||
2404       (match(II.getArgOperand(0), m_Intrinsic<ToSVB>(m_Value(A))) &&
2405        match(II.getArgOperand(1), m_Intrinsic<ToSVB>(m_Value(B))))) {
2406     auto *TyA = cast<ScalableVectorType>(A->getType());
2407     if (TyA == B->getType() &&
2408         RetTy == ScalableVectorType::getDoubleElementsVectorType(TyA)) {
2409       auto *SubVec = IC.Builder.CreateInsertVector(
2410           RetTy, PoisonValue::get(RetTy), A, uint64_t(0));
2411       auto *ConcatVec = IC.Builder.CreateInsertVector(RetTy, SubVec, B,
2412                                                       TyA->getMinNumElements());
2413       ConcatVec->takeName(&II);
2414       return IC.replaceInstUsesWith(II, ConcatVec);
2415     }
2416   }
2417 
2418   return std::nullopt;
2419 }
2420 
2421 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
2422                                                       IntrinsicInst &II) {
2423   // zip1(uzp1(A, B), uzp2(A, B)) --> A
2424   // zip2(uzp1(A, B), uzp2(A, B)) --> B
2425   Value *A, *B;
2426   if (match(II.getArgOperand(0),
2427             m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
2428       match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
2429                                      m_Specific(A), m_Specific(B))))
2430     return IC.replaceInstUsesWith(
2431         II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
2432 
2433   return std::nullopt;
2434 }
2435 
2436 static std::optional<Instruction *>
2437 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
2438   Value *Mask = II.getOperand(0);
2439   Value *BasePtr = II.getOperand(1);
2440   Value *Index = II.getOperand(2);
2441   Type *Ty = II.getType();
2442   Value *PassThru = ConstantAggregateZero::get(Ty);
2443 
2444   // Contiguous gather => masked load.
2445   // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
2446   // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
2447   Value *IndexBase;
2448   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
2449                        m_Value(IndexBase), m_SpecificInt(1)))) {
2450     Align Alignment =
2451         BasePtr->getPointerAlignment(II.getDataLayout());
2452 
2453     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
2454                                       BasePtr, IndexBase);
2455     CallInst *MaskedLoad =
2456         IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
2457     MaskedLoad->takeName(&II);
2458     return IC.replaceInstUsesWith(II, MaskedLoad);
2459   }
2460 
2461   return std::nullopt;
2462 }
2463 
2464 static std::optional<Instruction *>
2465 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
2466   Value *Val = II.getOperand(0);
2467   Value *Mask = II.getOperand(1);
2468   Value *BasePtr = II.getOperand(2);
2469   Value *Index = II.getOperand(3);
2470   Type *Ty = Val->getType();
2471 
2472   // Contiguous scatter => masked store.
2473   // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
2474   // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
2475   Value *IndexBase;
2476   if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
2477                        m_Value(IndexBase), m_SpecificInt(1)))) {
2478     Align Alignment =
2479         BasePtr->getPointerAlignment(II.getDataLayout());
2480 
2481     Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
2482                                       BasePtr, IndexBase);
2483     (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
2484 
2485     return IC.eraseInstFromFunction(II);
2486   }
2487 
2488   return std::nullopt;
2489 }
2490 
2491 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
2492                                                        IntrinsicInst &II) {
2493   Type *Int32Ty = IC.Builder.getInt32Ty();
2494   Value *Pred = II.getOperand(0);
2495   Value *Vec = II.getOperand(1);
2496   Value *DivVec = II.getOperand(2);
2497 
2498   Value *SplatValue = getSplatValue(DivVec);
2499   ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
2500   if (!SplatConstantInt)
2501     return std::nullopt;
2502 
2503   APInt Divisor = SplatConstantInt->getValue();
2504   const int64_t DivisorValue = Divisor.getSExtValue();
2505   if (DivisorValue == -1)
2506     return std::nullopt;
2507   if (DivisorValue == 1)
2508     IC.replaceInstUsesWith(II, Vec);
2509 
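  // Editorial example (derived from the code below): a positive power-of-two
  // splat divisor becomes a single predicated ASRD, e.g. sdiv x, splat(8)
  // -> asrd x, p, #3; a negative power of two emits the same ASRD followed by
  // a predicated NEG of the result.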
2510   if (Divisor.isPowerOf2()) {
2511     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2512     auto ASRD = IC.Builder.CreateIntrinsic(
2513         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2514     return IC.replaceInstUsesWith(II, ASRD);
2515   }
2516   if (Divisor.isNegatedPowerOf2()) {
2517     Divisor.negate();
2518     Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2519     auto ASRD = IC.Builder.CreateIntrinsic(
2520         Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2521     auto NEG = IC.Builder.CreateIntrinsic(
2522         Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
2523     return IC.replaceInstUsesWith(II, NEG);
2524   }
2525 
2526   return std::nullopt;
2527 }
2528 
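// Editorial note: returns true when the second half of Vec repeats the first
// half (lane by lane), and in that case halves Vec in place and recurses,
// e.g. (A, B, A, B) -> (A, B). Null entries stand for poison lanes; when
// AllowPoison is set, a null lane in the first half may be filled from the
// matching lane of the second half.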
2529 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
2530   size_t VecSize = Vec.size();
2531   if (VecSize == 1)
2532     return true;
2533   if (!isPowerOf2_64(VecSize))
2534     return false;
2535   size_t HalfVecSize = VecSize / 2;
2536 
2537   for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
2538        RHS != Vec.end(); LHS++, RHS++) {
2539     if (*LHS != nullptr && *RHS != nullptr) {
2540       if (*LHS == *RHS)
2541         continue;
2542       else
2543         return false;
2544     }
2545     if (!AllowPoison)
2546       return false;
2547     if (*LHS == nullptr && *RHS != nullptr)
2548       *LHS = *RHS;
2549   }
2550 
2551   Vec.resize(HalfVecSize);
2552   SimplifyValuePattern(Vec, AllowPoison);
2553   return true;
2554 }
2555 
2556 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
2557 // to dupqlane(f64(C)), where C is A concatenated with B.
2558 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
2559                                                            IntrinsicInst &II) {
2560   Value *CurrentInsertElt = nullptr, *Default = nullptr;
2561   if (!match(II.getOperand(0),
2562              m_Intrinsic<Intrinsic::vector_insert>(
2563                  m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
2564       !isa<FixedVectorType>(CurrentInsertElt->getType()))
2565     return std::nullopt;
2566   auto IIScalableTy = cast<ScalableVectorType>(II.getType());
2567 
2568   // Insert the scalars into a container ordered by InsertElement index
2569   SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
2570   while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
2571     auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
2572     Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
2573     CurrentInsertElt = InsertElt->getOperand(0);
2574   }
2575 
2576   bool AllowPoison =
2577       isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
2578   if (!SimplifyValuePattern(Elts, AllowPoison))
2579     return std::nullopt;
2580 
2581   // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
2582   Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
2583   for (size_t I = 0; I < Elts.size(); I++) {
2584     if (Elts[I] == nullptr)
2585       continue;
2586     InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
2587                                                     IC.Builder.getInt64(I));
2588   }
2589   if (InsertEltChain == nullptr)
2590     return std::nullopt;
2591 
2592   // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
2593   // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
2594   // be bitcast to a type wide enough to fit the sequence, be splatted, and then
2595   // be narrowed back to the original type.
2596   unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
2597   unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
2598                                  IIScalableTy->getMinNumElements() /
2599                                  PatternWidth;
2600 
2601   IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
2602   auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
2603   auto *WideShuffleMaskTy =
2604       ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);
2605 
2606   auto InsertSubvector = IC.Builder.CreateInsertVector(
2607       II.getType(), PoisonValue::get(II.getType()), InsertEltChain,
2608       uint64_t(0));
2609   auto WideBitcast =
2610       IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
2611   auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
2612   auto WideShuffle = IC.Builder.CreateShuffleVector(
2613       WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
2614   auto NarrowBitcast =
2615       IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());
2616 
2617   return IC.replaceInstUsesWith(II, NarrowBitcast);
2618 }
2619 
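// Editorial note: fmaxnm(A, A) and fminnm(A, A) both reduce to A, so when the
// two operands are identical the intrinsic is replaced with its first operand.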
2620 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
2621                                                         IntrinsicInst &II) {
2622   Value *A = II.getArgOperand(0);
2623   Value *B = II.getArgOperand(1);
2624   if (A == B)
2625     return IC.replaceInstUsesWith(II, A);
2626 
2627   return std::nullopt;
2628 }
2629 
2630 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
2631                                                         IntrinsicInst &II) {
2632   Value *Pred = II.getOperand(0);
2633   Value *Vec = II.getOperand(1);
2634   Value *Shift = II.getOperand(2);
2635 
2636   // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
2637   Value *AbsPred, *MergedValue;
2638   if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
2639                       m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
2640       !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
2641                       m_Value(MergedValue), m_Value(AbsPred), m_Value())))
2642 
2643     return std::nullopt;
2644 
2645   // Transform is valid if any of the following are true:
2646   // * The ABS merge value is an undef or non-negative
2647   // * The ABS predicate is all active
2648   // * The ABS predicate and the SRSHL predicates are the same
2649   if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
2650       AbsPred != Pred && !isAllActivePredicate(AbsPred))
2651     return std::nullopt;
2652 
2653   // Only valid when the shift amount is non-negative, otherwise the rounding
2654   // behaviour of SRSHL cannot be ignored.
2655   if (!match(Shift, m_NonNegative()))
2656     return std::nullopt;
2657 
2658   auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
2659                                         {II.getType()}, {Pred, Vec, Shift});
2660 
2661   return IC.replaceInstUsesWith(II, LSL);
2662 }
2663 
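// Editorial note: insr(V, X) inserts X as the lowest element of V; when V is
// already a splat of X the result equals V, so the intrinsic folds away.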
2664 static std::optional<Instruction *> instCombineSVEInsr(InstCombiner &IC,
2665                                                        IntrinsicInst &II) {
2666   Value *Vec = II.getOperand(0);
2667 
2668   if (getSplatValue(Vec) == II.getOperand(1))
2669     return IC.replaceInstUsesWith(II, Vec);
2670 
2671   return std::nullopt;
2672 }
2673 
2674 static std::optional<Instruction *> instCombineDMB(InstCombiner &IC,
2675                                                    IntrinsicInst &II) {
2676   // If this barrier is post-dominated by an identical one, we can remove it.
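  // For example, two dmb calls with the same argument, separated only by
  // instructions that neither access memory nor have other side effects, are
  // collapsed into one (the earlier barrier is erased). The scan is bounded
  // by DMBLookaheadThreshold and may follow a unique successor block.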
2677   auto *NI = II.getNextNonDebugInstruction();
2678   unsigned LookaheadThreshold = DMBLookaheadThreshold;
2679   auto CanSkipOver = [](Instruction *I) {
2680     return !I->mayReadOrWriteMemory() && !I->mayHaveSideEffects();
2681   };
2682   while (LookaheadThreshold-- && CanSkipOver(NI)) {
2683     auto *NIBB = NI->getParent();
2684     NI = NI->getNextNonDebugInstruction();
2685     if (!NI) {
2686       if (auto *SuccBB = NIBB->getUniqueSuccessor())
2687         NI = &*SuccBB->getFirstNonPHIOrDbgOrLifetime();
2688       else
2689         break;
2690     }
2691   }
2692   auto *NextII = dyn_cast_or_null<IntrinsicInst>(NI);
2693   if (NextII && II.isIdenticalTo(NextII))
2694     return IC.eraseInstFromFunction(II);
2695 
2696   return std::nullopt;
2697 }
2698 
2699 static std::optional<Instruction *> instCombinePTrue(InstCombiner &IC,
2700                                                      IntrinsicInst &II) {
2701   if (match(II.getOperand(0), m_ConstantInt<AArch64SVEPredPattern::all>()))
2702     return IC.replaceInstUsesWith(II, Constant::getAllOnesValue(II.getType()));
2703   return std::nullopt;
2704 }
2705 
2706 static std::optional<Instruction *> instCombineSVEUxt(InstCombiner &IC,
2707                                                       IntrinsicInst &II,
2708                                                       unsigned NumBits) {
2709   Value *Passthru = II.getOperand(0);
2710   Value *Pg = II.getOperand(1);
2711   Value *Op = II.getOperand(2);
2712 
2713   // Convert UXT[BHW] to AND.
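  // For example, with NumBits == 8 on a <vscale x 8 x i16> input this emits
  // sve.and.u with a splat of 0x00FF, which matches uxtb when the passthru is
  // undef or the predicate is all active.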
2714   if (isa<UndefValue>(Passthru) || isAllActivePredicate(Pg)) {
2715     auto *Ty = cast<VectorType>(II.getType());
2716     auto MaskValue = APInt::getLowBitsSet(Ty->getScalarSizeInBits(), NumBits);
2717     auto *Mask = ConstantInt::get(Ty, MaskValue);
2718     auto *And = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_and_u, {Ty},
2719                                            {Pg, Op, Mask});
2720     return IC.replaceInstUsesWith(II, And);
2721   }
2722 
2723   return std::nullopt;
2724 }
2725 
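// Editorial note: llvm.aarch64.sme.in.streaming.mode folds to a compile-time
// constant when the SME attributes decide the answer: true for functions with
// a streaming interface or body, false for functions that are not
// streaming-compatible. Streaming-compatible functions keep the intrinsic,
// since the mode is only known at run time.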
2726 static std::optional<Instruction *>
2727 instCombineInStreamingMode(InstCombiner &IC, IntrinsicInst &II) {
2728   SMEAttrs FnSMEAttrs(*II.getFunction());
2729   bool IsStreaming = FnSMEAttrs.hasStreamingInterfaceOrBody();
2730   if (IsStreaming || !FnSMEAttrs.hasStreamingCompatibleInterface())
2731     return IC.replaceInstUsesWith(
2732         II, ConstantInt::getBool(II.getType(), IsStreaming));
2733   return std::nullopt;
2734 }
2735 
2736 std::optional<Instruction *>
2737 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
2738                                      IntrinsicInst &II) const {
2739   const SVEIntrinsicInfo &IInfo = constructSVEIntrinsicInfo(II);
2740   if (std::optional<Instruction *> I = simplifySVEIntrinsic(IC, II, IInfo))
2741     return I;
2742 
2743   Intrinsic::ID IID = II.getIntrinsicID();
2744   switch (IID) {
2745   default:
2746     break;
2747   case Intrinsic::aarch64_dmb:
2748     return instCombineDMB(IC, II);
2749   case Intrinsic::aarch64_neon_fmaxnm:
2750   case Intrinsic::aarch64_neon_fminnm:
2751     return instCombineMaxMinNM(IC, II);
2752   case Intrinsic::aarch64_sve_convert_from_svbool:
2753     return instCombineConvertFromSVBool(IC, II);
2754   case Intrinsic::aarch64_sve_dup:
2755     return instCombineSVEDup(IC, II);
2756   case Intrinsic::aarch64_sve_dup_x:
2757     return instCombineSVEDupX(IC, II);
2758   case Intrinsic::aarch64_sve_cmpne:
2759   case Intrinsic::aarch64_sve_cmpne_wide:
2760     return instCombineSVECmpNE(IC, II);
2761   case Intrinsic::aarch64_sve_rdffr:
2762     return instCombineRDFFR(IC, II);
2763   case Intrinsic::aarch64_sve_lasta:
2764   case Intrinsic::aarch64_sve_lastb:
2765     return instCombineSVELast(IC, II);
2766   case Intrinsic::aarch64_sve_clasta_n:
2767   case Intrinsic::aarch64_sve_clastb_n:
2768     return instCombineSVECondLast(IC, II);
2769   case Intrinsic::aarch64_sve_cntd:
2770     return instCombineSVECntElts(IC, II, 2);
2771   case Intrinsic::aarch64_sve_cntw:
2772     return instCombineSVECntElts(IC, II, 4);
2773   case Intrinsic::aarch64_sve_cnth:
2774     return instCombineSVECntElts(IC, II, 8);
2775   case Intrinsic::aarch64_sve_cntb:
2776     return instCombineSVECntElts(IC, II, 16);
2777   case Intrinsic::aarch64_sve_ptest_any:
2778   case Intrinsic::aarch64_sve_ptest_first:
2779   case Intrinsic::aarch64_sve_ptest_last:
2780     return instCombineSVEPTest(IC, II);
2781   case Intrinsic::aarch64_sve_fadd:
2782     return instCombineSVEVectorFAdd(IC, II);
2783   case Intrinsic::aarch64_sve_fadd_u:
2784     return instCombineSVEVectorFAddU(IC, II);
2785   case Intrinsic::aarch64_sve_fmul_u:
2786     return instCombineSVEVectorBinOp(IC, II);
2787   case Intrinsic::aarch64_sve_fsub:
2788     return instCombineSVEVectorFSub(IC, II);
2789   case Intrinsic::aarch64_sve_fsub_u:
2790     return instCombineSVEVectorFSubU(IC, II);
2791   case Intrinsic::aarch64_sve_add:
2792     return instCombineSVEVectorAdd(IC, II);
2793   case Intrinsic::aarch64_sve_add_u:
2794     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
2795                                              Intrinsic::aarch64_sve_mla_u>(
2796         IC, II, true);
2797   case Intrinsic::aarch64_sve_sub:
2798     return instCombineSVEVectorSub(IC, II);
2799   case Intrinsic::aarch64_sve_sub_u:
2800     return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
2801                                              Intrinsic::aarch64_sve_mls_u>(
2802         IC, II, true);
2803   case Intrinsic::aarch64_sve_tbl:
2804     return instCombineSVETBL(IC, II);
2805   case Intrinsic::aarch64_sve_uunpkhi:
2806   case Intrinsic::aarch64_sve_uunpklo:
2807   case Intrinsic::aarch64_sve_sunpkhi:
2808   case Intrinsic::aarch64_sve_sunpklo:
2809     return instCombineSVEUnpack(IC, II);
2810   case Intrinsic::aarch64_sve_uzp1:
2811     return instCombineSVEUzp1(IC, II);
2812   case Intrinsic::aarch64_sve_zip1:
2813   case Intrinsic::aarch64_sve_zip2:
2814     return instCombineSVEZip(IC, II);
2815   case Intrinsic::aarch64_sve_ld1_gather_index:
2816     return instCombineLD1GatherIndex(IC, II);
2817   case Intrinsic::aarch64_sve_st1_scatter_index:
2818     return instCombineST1ScatterIndex(IC, II);
2819   case Intrinsic::aarch64_sve_ld1:
2820     return instCombineSVELD1(IC, II, DL);
2821   case Intrinsic::aarch64_sve_st1:
2822     return instCombineSVEST1(IC, II, DL);
2823   case Intrinsic::aarch64_sve_sdiv:
2824     return instCombineSVESDIV(IC, II);
2825   case Intrinsic::aarch64_sve_sel:
2826     return instCombineSVESel(IC, II);
2827   case Intrinsic::aarch64_sve_srshl:
2828     return instCombineSVESrshl(IC, II);
2829   case Intrinsic::aarch64_sve_dupq_lane:
2830     return instCombineSVEDupqLane(IC, II);
2831   case Intrinsic::aarch64_sve_insr:
2832     return instCombineSVEInsr(IC, II);
2833   case Intrinsic::aarch64_sve_ptrue:
2834     return instCombinePTrue(IC, II);
2835   case Intrinsic::aarch64_sve_uxtb:
2836     return instCombineSVEUxt(IC, II, 8);
2837   case Intrinsic::aarch64_sve_uxth:
2838     return instCombineSVEUxt(IC, II, 16);
2839   case Intrinsic::aarch64_sve_uxtw:
2840     return instCombineSVEUxt(IC, II, 32);
2841   case Intrinsic::aarch64_sme_in_streaming_mode:
2842     return instCombineInStreamingMode(IC, II);
2843   }
2844 
2845   return std::nullopt;
2846 }
2847 
2848 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
2849     InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
2850     APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
2851     std::function<void(Instruction *, unsigned, APInt, APInt &)>
2852         SimplifyAndSetOp) const {
2853   switch (II.getIntrinsicID()) {
2854   default:
2855     break;
2856   case Intrinsic::aarch64_neon_fcvtxn:
2857   case Intrinsic::aarch64_neon_rshrn:
2858   case Intrinsic::aarch64_neon_sqrshrn:
2859   case Intrinsic::aarch64_neon_sqrshrun:
2860   case Intrinsic::aarch64_neon_sqshrn:
2861   case Intrinsic::aarch64_neon_sqshrun:
2862   case Intrinsic::aarch64_neon_sqxtn:
2863   case Intrinsic::aarch64_neon_sqxtun:
2864   case Intrinsic::aarch64_neon_uqrshrn:
2865   case Intrinsic::aarch64_neon_uqshrn:
2866   case Intrinsic::aarch64_neon_uqxtn:
2867     SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
2868     break;
2869   }
2870 
2871   return std::nullopt;
2872 }
2873 
2874 bool AArch64TTIImpl::enableScalableVectorization() const {
2875   return ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2876                                   EnableScalableAutovecInStreamingMode);
2877 }
2878 
2879 TypeSize
2880 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
2881   switch (K) {
2882   case TargetTransformInfo::RGK_Scalar:
2883     return TypeSize::getFixed(64);
2884   case TargetTransformInfo::RGK_FixedWidthVector:
2885     if (ST->useSVEForFixedLengthVectors() &&
2886         (ST->isSVEAvailable() || EnableFixedwidthAutovecInStreamingMode))
2887       return TypeSize::getFixed(
2888           std::max(ST->getMinSVEVectorSizeInBits(), 128u));
2889     else if (ST->isNeonAvailable())
2890       return TypeSize::getFixed(128);
2891     else
2892       return TypeSize::getFixed(0);
2893   case TargetTransformInfo::RGK_ScalableVector:
2894     if (ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2895                                  EnableScalableAutovecInStreamingMode))
2896       return TypeSize::getScalable(128);
2897     else
2898       return TypeSize::getScalable(0);
2899   }
2900   llvm_unreachable("Unsupported register kind");
2901 }
2902 
2903 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
2904                                            ArrayRef<const Value *> Args,
2905                                            Type *SrcOverrideTy) const {
2906   // A helper that returns a vector type from the given type. The number of
2907   // elements in type Ty determines the vector width.
2908   auto toVectorTy = [&](Type *ArgTy) {
2909     return VectorType::get(ArgTy->getScalarType(),
2910                            cast<VectorType>(DstTy)->getElementCount());
2911   };
2912 
2913   // Exit early if DstTy is not a vector type whose elements are one of [i16,
2914   // i32, i64]. SVE doesn't generally have the same set of instructions to
2915   // perform an extend with the add/sub/mul. There are SMULLB style
2916   // instructions, but they operate on top/bottom, requiring some sort of lane
2917   // interleaving to be used with zext/sext.
2918   unsigned DstEltSize = DstTy->getScalarSizeInBits();
2919   if (!useNeonVector(DstTy) || Args.size() != 2 ||
2920       (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
2921     return false;
2922 
2923   // Determine if the operation has a widening variant. We consider both the
2924   // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
2925   // instructions.
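  // For example, add(x, zext(<8 x i8> to <8 x i16>)) maps onto uaddw and
  // mul(sext(a), sext(b)) maps onto smull, making the relevant extends
  // effectively free.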
2926   //
2927   // TODO: Add additional widening operations (e.g., shl, etc.) once we
2928   //       verify that their extending operands are eliminated during code
2929   //       generation.
2930   Type *SrcTy = SrcOverrideTy;
2931   switch (Opcode) {
2932   case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
2933   case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
2934     // The second operand needs to be an extend
2935     if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
2936       if (!SrcTy)
2937         SrcTy =
2938             toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
2939     } else
2940       return false;
2941     break;
2942   case Instruction::Mul: { // SMULL(2), UMULL(2)
2943     // Both operands need to be extends of the same type.
2944     if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
2945         (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
2946       if (!SrcTy)
2947         SrcTy =
2948             toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
2949     } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
2950       // If one of the operands is a Zext and the other has enough zero bits to
2951       // be treated as unsigned, we can still generate a umull, meaning the zext
2952       // is free.
2953       KnownBits Known =
2954           computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
2955       if (Args[0]->getType()->getScalarSizeInBits() -
2956               Known.Zero.countLeadingOnes() >
2957           DstTy->getScalarSizeInBits() / 2)
2958         return false;
2959       if (!SrcTy)
2960         SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
2961                                            DstTy->getScalarSizeInBits() / 2));
2962     } else
2963       return false;
2964     break;
2965   }
2966   default:
2967     return false;
2968   }
2969 
2970   // Legalize the destination type and ensure it can be used in a widening
2971   // operation.
2972   auto DstTyL = getTypeLegalizationCost(DstTy);
2973   if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
2974     return false;
2975 
2976   // Legalize the source type and ensure it can be used in a widening
2977   // operation.
2978   assert(SrcTy && "Expected some SrcTy");
2979   auto SrcTyL = getTypeLegalizationCost(SrcTy);
2980   unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
2981   if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
2982     return false;
2983 
2984   // Get the total number of vector elements in the legalized types.
2985   InstructionCost NumDstEls =
2986       DstTyL.first * DstTyL.second.getVectorMinNumElements();
2987   InstructionCost NumSrcEls =
2988       SrcTyL.first * SrcTyL.second.getVectorMinNumElements();
2989 
2990   // Return true if the legalized types have the same number of vector elements
2991   // and the destination element type size is twice that of the source type.
2992   return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
2993 }
2994 
2995 // s/urhadd instructions implement the following pattern, making the
2996 // extends free:
2997 //   %x = add ((zext i8 -> i16), 1)
2998 //   %y = (zext i8 -> i16)
2999 //   trunc i16 (lshr (add %x, %y), 1) -> i8
3000 //
3001 bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
3002                                         Type *Src) const {
3003   // The source should be a legal vector type.
3004   if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) ||
3005       (Src->isScalableTy() && !ST->hasSVE2()))
3006     return false;
3007 
3008   if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse())
3009     return false;
3010 
3011   // Look for trunc/shl/add before trying to match the pattern.
3012   const Instruction *Add = ExtUser;
3013   auto *AddUser =
3014       dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
3015   if (AddUser && AddUser->getOpcode() == Instruction::Add)
3016     Add = AddUser;
3017 
3018   auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
3019   if (!Shr || Shr->getOpcode() != Instruction::LShr)
3020     return false;
3021 
3022   auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser());
3023   if (!Trunc || Trunc->getOpcode() != Instruction::Trunc ||
3024       Src->getScalarSizeInBits() !=
3025           cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits())
3026     return false;
3027 
3028   // Try to match the whole pattern. Ext could be either the first or second
3029   // m_ZExtOrSExt matched.
3030   Instruction *Ex1, *Ex2;
3031   if (!(match(Add, m_c_Add(m_Instruction(Ex1),
3032                            m_c_Add(m_Instruction(Ex2), m_SpecificInt(1))))))
3033     return false;
3034 
3035   // Ensure both extends are of the same type
3036   if (match(Ex1, m_ZExtOrSExt(m_Value())) &&
3037       Ex1->getOpcode() == Ex2->getOpcode())
3038     return true;
3039 
3040   return false;
3041 }
3042 
3043 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
3044                                                  Type *Src,
3045                                                  TTI::CastContextHint CCH,
3046                                                  TTI::TargetCostKind CostKind,
3047                                                  const Instruction *I) const {
3048   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3049   assert(ISD && "Invalid opcode");
3050   // If the cast is observable, and it is used by a widening instruction (e.g.,
3051   // uaddl, saddw, etc.), it may be free.
3052   if (I && I->hasOneUser()) {
3053     auto *SingleUser = cast<Instruction>(*I->user_begin());
3054     SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
3055     if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
3056       // For adds, only count the second operand as free if both operands are
3057       // extends but not the same operation (i.e. both operands are not free in
3058       // add(sext, zext)).
3059       if (SingleUser->getOpcode() == Instruction::Add) {
3060         if (I == SingleUser->getOperand(1) ||
3061             (isa<CastInst>(SingleUser->getOperand(1)) &&
3062              cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
3063           return 0;
3064       } else // Others are free so long as isWideningInstruction returned true.
3065         return 0;
3066     }
3067 
3068     // The cast will be free for the s/urhadd instructions
3069     if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
3070         isExtPartOfAvgExpr(SingleUser, Dst, Src))
3071       return 0;
3072   }
3073 
3074   // TODO: Allow non-throughput costs that aren't binary.
3075   auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
3076     if (CostKind != TTI::TCK_RecipThroughput)
3077       return Cost == 0 ? 0 : 1;
3078     return Cost;
3079   };
3080 
3081   EVT SrcTy = TLI->getValueType(DL, Src);
3082   EVT DstTy = TLI->getValueType(DL, Dst);
3083 
3084   if (!SrcTy.isSimple() || !DstTy.isSimple())
3085     return AdjustCost(
3086         BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
3087 
3088   static const TypeConversionCostTblEntry BF16Tbl[] = {
3089       {ISD::FP_ROUND, MVT::bf16, MVT::f32, 1},     // bfcvt
3090       {ISD::FP_ROUND, MVT::bf16, MVT::f64, 1},     // bfcvt
3091       {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 1}, // bfcvtn
3092       {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 2}, // bfcvtn+bfcvtn2
3093       {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn
3094       {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // fcvtn+fcvtl2+bfcvtn
3095       {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn
3096   };
3097 
3098   if (ST->hasBF16())
3099     if (const auto *Entry = ConvertCostTableLookup(
3100             BF16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3101       return AdjustCost(Entry->Cost);
3102 
3103   // Symbolic constants for the SVE sitofp/uitofp entries in the table below
3104   // The cost of unpacking twice is artificially increased for now in order
3105   // to avoid regressions against NEON, which will use tbl instructions directly
3106   // instead of multiple layers of [s|u]unpk[lo|hi].
3107   // We use the unpacks in cases where the destination type is illegal and
3108   // requires splitting of the input, even if the input type itself is legal.
3109   const unsigned int SVE_EXT_COST = 1;
3110   const unsigned int SVE_FCVT_COST = 1;
3111   const unsigned int SVE_UNPACK_ONCE = 4;
3112   const unsigned int SVE_UNPACK_TWICE = 16;
3113 
3114   static const TypeConversionCostTblEntry ConversionTbl[] = {
3115       {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1},    // xtn
3116       {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1},   // xtn
3117       {ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1},   // xtn
3118       {ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1},    // xtn
3119       {ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3},    // 2 xtn + 1 uzp1
3120       {ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1},   // xtn
3121       {ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2},   // 1 uzp1 + 1 xtn
3122       {ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1},   // 1 uzp1
3123       {ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1},    // 1 xtn
3124       {ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2},    // 1 uzp1 + 1 xtn
3125       {ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4},    // 3 x uzp1 + xtn
3126       {ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1},   // 1 uzp1
3127       {ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3},   // 3 x uzp1
3128       {ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2},   // 2 x uzp1
3129       {ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1},  // uzp1
3130       {ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3},  // (2 + 1) x uzp1
3131       {ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7},  // (4 + 2 + 1) x uzp1
3132       {ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1
3133       {ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1
3134       {ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1
3135 
3136       // Truncations on nxvmiN
3137       {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i8, 2},
3138       {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 2},
3139       {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 2},
3140       {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 2},
3141       {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i8, 2},
3142       {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 2},
3143       {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 2},
3144       {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 5},
3145       {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i8, 2},
3146       {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 2},
3147       {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 5},
3148       {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 11},
3149       {ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 2},
3150       {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i16, 0},
3151       {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i32, 0},
3152       {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i64, 0},
3153       {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 0},
3154       {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i64, 0},
3155       {ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 0},
3156       {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i16, 0},
3157       {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i32, 0},
3158       {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i64, 1},
3159       {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 0},
3160       {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i64, 1},
3161       {ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 1},
3162       {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i16, 0},
3163       {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i32, 1},
3164       {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i64, 3},
3165       {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 1},
3166       {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i64, 3},
3167       {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i16, 1},
3168       {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i32, 3},
3169       {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i64, 7},
3170 
3171       // The number of shll instructions for the extension.
3172       {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3},
3173       {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3},
3174       {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2},
3175       {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2},
3176       {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3},
3177       {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3},
3178       {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2},
3179       {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2},
3180       {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7},
3181       {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7},
3182       {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6},
3183       {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6},
3184       {ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2},
3185       {ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2},
3186       {ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6},
3187       {ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6},
3188 
3189       // FP Ext and trunc
3190       {ISD::FP_EXTEND, MVT::f64, MVT::f32, 1},     // fcvt
3191       {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f32, 1}, // fcvtl
3192       {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f32, 2}, // fcvtl+fcvtl2
3193       //   FP16
3194       {ISD::FP_EXTEND, MVT::f32, MVT::f16, 1},     // fcvt
3195       {ISD::FP_EXTEND, MVT::f64, MVT::f16, 1},     // fcvt
3196       {ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, 1}, // fcvtl
3197       {ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, 2}, // fcvtl+fcvtl2
3198       {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2}, // fcvtl+fcvtl
3199       {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3}, // fcvtl+fcvtl2+fcvtl
3200       {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6}, // 2 * fcvtl+fcvtl2+fcvtl
3201       //   BF16 (uses shift)
3202       {ISD::FP_EXTEND, MVT::f32, MVT::bf16, 1},     // shl
3203       {ISD::FP_EXTEND, MVT::f64, MVT::bf16, 2},     // shl+fcvt
3204       {ISD::FP_EXTEND, MVT::v4f32, MVT::v4bf16, 1}, // shll
3205       {ISD::FP_EXTEND, MVT::v8f32, MVT::v8bf16, 2}, // shll+shll2
3206       {ISD::FP_EXTEND, MVT::v2f64, MVT::v2bf16, 2}, // shll+fcvtl
3207       {ISD::FP_EXTEND, MVT::v4f64, MVT::v4bf16, 3}, // shll+fcvtl+fcvtl2
3208       {ISD::FP_EXTEND, MVT::v8f64, MVT::v8bf16, 6}, // 2 * shll+fcvtl+fcvtl2
3209       // FP Ext and trunc
3210       {ISD::FP_ROUND, MVT::f32, MVT::f64, 1},     // fcvt
3211       {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1}, // fcvtn
3212       {ISD::FP_ROUND, MVT::v4f32, MVT::v4f64, 2}, // fcvtn+fcvtn2
3213       //   FP16
3214       {ISD::FP_ROUND, MVT::f16, MVT::f32, 1},     // fcvt
3215       {ISD::FP_ROUND, MVT::f16, MVT::f64, 1},     // fcvt
3216       {ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, 1}, // fcvtn
3217       {ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, 2}, // fcvtn+fcvtn2
3218       {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2}, // fcvtn+fcvtn
3219       {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3}, // fcvtn+fcvtn2+fcvtn
3220       {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+fcvtn
3221       //   BF16 (more complex; the +bf16 case is handled above)
3222       {ISD::FP_ROUND, MVT::bf16, MVT::f32, 8}, // Expansion is ~8 insns
3223       {ISD::FP_ROUND, MVT::bf16, MVT::f64, 9}, // fcvtn + above
3224       {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f32, 8},
3225       {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 8},
3226       {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 15},
3227       {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 9},
3228       {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 10},
3229       {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 19},
3230 
3231       // LowerVectorINT_TO_FP:
3232       {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
3233       {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
3234       {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
3235       {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
3236       {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
3237       {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
3238 
3239       // SVE: to nxv2f16
3240       {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i8,
3241        SVE_EXT_COST + SVE_FCVT_COST},
3242       {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i16, SVE_FCVT_COST},
3243       {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i32, SVE_FCVT_COST},
3244       {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i64, SVE_FCVT_COST},
3245       {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i8,
3246        SVE_EXT_COST + SVE_FCVT_COST},
3247       {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i16, SVE_FCVT_COST},
3248       {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i32, SVE_FCVT_COST},
3249       {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i64, SVE_FCVT_COST},
3250 
3251       // SVE: to nxv4f16
3252       {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i8,
3253        SVE_EXT_COST + SVE_FCVT_COST},
3254       {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i16, SVE_FCVT_COST},
3255       {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i32, SVE_FCVT_COST},
3256       {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i8,
3257        SVE_EXT_COST + SVE_FCVT_COST},
3258       {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i16, SVE_FCVT_COST},
3259       {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i32, SVE_FCVT_COST},
3260 
3261       // SVE: to nxv8f16
3262       {ISD::SINT_TO_FP, MVT::nxv8f16, MVT::nxv8i8,
3263        SVE_EXT_COST + SVE_FCVT_COST},
3264       {ISD::SINT_TO_FP, MVT::nxv8f16, MVT::nxv8i16, SVE_FCVT_COST},
3265       {ISD::UINT_TO_FP, MVT::nxv8f16, MVT::nxv8i8,
3266        SVE_EXT_COST + SVE_FCVT_COST},
3267       {ISD::UINT_TO_FP, MVT::nxv8f16, MVT::nxv8i16, SVE_FCVT_COST},
3268 
3269       // SVE: to nxv16f16
3270       {ISD::SINT_TO_FP, MVT::nxv16f16, MVT::nxv16i8,
3271        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3272       {ISD::UINT_TO_FP, MVT::nxv16f16, MVT::nxv16i8,
3273        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3274 
3275       // Complex: to v2f32
3276       {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
3277       {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
3278       {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
3279       {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
3280 
3281       // SVE: to nxv2f32
3282       {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i8,
3283        SVE_EXT_COST + SVE_FCVT_COST},
3284       {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i16, SVE_FCVT_COST},
3285       {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, SVE_FCVT_COST},
3286       {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i64, SVE_FCVT_COST},
3287       {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i8,
3288        SVE_EXT_COST + SVE_FCVT_COST},
3289       {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i16, SVE_FCVT_COST},
3290       {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, SVE_FCVT_COST},
3291       {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i64, SVE_FCVT_COST},
3292 
3293       // Complex: to v4f32
3294       {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4},
3295       {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
3296       {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3},
3297       {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
3298 
3299       // SVE: to nxv4f32
3300       {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i8,
3301        SVE_EXT_COST + SVE_FCVT_COST},
3302       {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i16, SVE_FCVT_COST},
3303       {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, SVE_FCVT_COST},
3304       {ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i8,
3305        SVE_EXT_COST + SVE_FCVT_COST},
3306       {ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i16, SVE_FCVT_COST},
3307       {ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, SVE_FCVT_COST},
3308 
3309       // Complex: to v8f32
3310       {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
3311       {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
3312       {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
3313       {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
3314 
3315       // SVE: to nxv8f32
3316       {ISD::SINT_TO_FP, MVT::nxv8f32, MVT::nxv8i8,
3317        SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3318       {ISD::SINT_TO_FP, MVT::nxv8f32, MVT::nxv8i16,
3319        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3320       {ISD::UINT_TO_FP, MVT::nxv8f32, MVT::nxv8i8,
3321        SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3322       {ISD::UINT_TO_FP, MVT::nxv8f32, MVT::nxv8i16,
3323        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3324 
3325       // SVE: to nxv16f32
3326       {ISD::SINT_TO_FP, MVT::nxv16f32, MVT::nxv16i8,
3327        SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3328       {ISD::UINT_TO_FP, MVT::nxv16f32, MVT::nxv16i8,
3329        SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3330 
3331       // Complex: to v16f32
3332       {ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
3333       {ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
3334 
3335       // Complex: to v2f64
3336       {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
3337       {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
3338       {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
3339       {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
3340       {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
3341       {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
3342 
3343       // SVE: to nxv2f64
3344       {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i8,
3345        SVE_EXT_COST + SVE_FCVT_COST},
3346       {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i16, SVE_FCVT_COST},
3347       {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i32, SVE_FCVT_COST},
3348       {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, SVE_FCVT_COST},
3349       {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i8,
3350        SVE_EXT_COST + SVE_FCVT_COST},
3351       {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i16, SVE_FCVT_COST},
3352       {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i32, SVE_FCVT_COST},
3353       {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, SVE_FCVT_COST},
3354 
3355       // Complex: to v4f64
3356       {ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
3357       {ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
3358 
3359       // SVE: to nxv4f64
3360       {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i8,
3361        SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3362       {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i16,
3363        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3364       {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i32,
3365        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3366       {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i8,
3367        SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3368       {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i16,
3369        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3370       {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i32,
3371        SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3372 
3373       // SVE: to nxv8f64
3374       {ISD::SINT_TO_FP, MVT::nxv8f64, MVT::nxv8i8,
3375        SVE_EXT_COST + SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3376       {ISD::SINT_TO_FP, MVT::nxv8f64, MVT::nxv8i16,
3377        SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3378       {ISD::UINT_TO_FP, MVT::nxv8f64, MVT::nxv8i8,
3379        SVE_EXT_COST + SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3380       {ISD::UINT_TO_FP, MVT::nxv8f64, MVT::nxv8i16,
3381        SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3382 
3383       // LowerVectorFP_TO_INT
3384       {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1},
3385       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1},
3386       {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1},
3387       {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1},
3388       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1},
3389       {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1},
3390 
3391       // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
3392       {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2},
3393       {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1},
3394       {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1},
3395       {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2},
3396       {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1},
3397       {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1},
3398 
3399       // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
3400       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2},
3401       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2},
3402       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2},
3403       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2},
3404 
3405       // Complex, from nxv2f32.
3406       {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1},
3407       {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1},
3408       {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1},
3409       {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1},
3410       {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1},
3411       {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1},
3412       {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1},
3413       {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1},
3414 
3415       // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
3416       {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2},
3417       {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2},
3418       {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2},
3419       {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2},
3420       {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2},
3421       {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2},
3422 
3423       // Complex, from nxv2f64.
3424       {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1},
3425       {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1},
3426       {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1},
3427       {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1},
3428       {ISD::FP_TO_SINT, MVT::nxv2i1, MVT::nxv2f64, 1},
3429       {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1},
3430       {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1},
3431       {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1},
3432       {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1},
3433       {ISD::FP_TO_UINT, MVT::nxv2i1, MVT::nxv2f64, 1},
3434 
3435       // Complex, from nxv4f32.
3436       {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4},
3437       {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1},
3438       {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1},
3439       {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1},
3440       {ISD::FP_TO_SINT, MVT::nxv4i1, MVT::nxv4f32, 1},
3441       {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4},
3442       {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1},
3443       {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1},
3444       {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1},
3445       {ISD::FP_TO_UINT, MVT::nxv4i1, MVT::nxv4f32, 1},
3446 
3447       // Complex, from nxv8f64. Illegal -> illegal conversions not required.
3448       {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7},
3449       {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7},
3450       {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7},
3451       {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7},
3452 
3453       // Complex, from nxv4f64. Illegal -> illegal conversions not required.
3454       {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3},
3455       {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3},
3456       {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3},
3457       {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3},
3458       {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3},
3459       {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3},
3460 
3461       // Complex, from nxv8f32. Illegal -> illegal conversions not required.
3462       {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3},
3463       {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3},
3464       {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3},
3465       {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3},
3466 
3467       // Complex, from nxv8f16.
3468       {ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10},
3469       {ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4},
3470       {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1},
3471       {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1},
3472       {ISD::FP_TO_SINT, MVT::nxv8i1, MVT::nxv8f16, 1},
3473       {ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10},
3474       {ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4},
3475       {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1},
3476       {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1},
3477       {ISD::FP_TO_UINT, MVT::nxv8i1, MVT::nxv8f16, 1},
3478 
3479       // Complex, from nxv4f16.
3480       {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4},
3481       {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3482       {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3483       {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3484       {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4},
3485       {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3486       {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3487       {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3488 
3489       // Complex, from nxv2f16.
3490       {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3491       {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3492       {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3493       {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3494       {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3495       {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3496       {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3497       {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3498 
3499       // Truncate from nxvmf32 to nxvmf16.
3500       {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1},
3501       {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1},
3502       {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3},
3503 
3504       // Truncate from nxvmf64 to nxvmf16.
3505       {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1},
3506       {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3},
3507       {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7},
3508 
3509       // Truncate from nxvmf64 to nxvmf32.
3510       {ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1},
3511       {ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3},
3512       {ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6},
3513 
3514       // Extend from nxvmf16 to nxvmf32.
3515       {ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
3516       {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
3517       {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
3518 
3519       // Extend from nxvmf16 to nxvmf64.
3520       {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
3521       {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
3522       {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
3523 
3524       // Extend from nxvmf32 to nxvmf64.
3525       {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
3526       {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
3527       {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
3528 
3529       // Bitcasts from float to integer
3530       {ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0},
3531       {ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0},
3532       {ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0},
3533 
3534       // Bitcasts from integer to float
3535       {ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0},
3536       {ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0},
3537       {ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0},
3538 
3539       // Add cost for extending to illegal (too wide) scalable vectors.
3540       // zero/sign extend are implemented by multiple unpack operations,
3541       // where each operation has a cost of 1.
3542       {ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3543       {ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3544       {ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3545       {ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3546       {ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3547       {ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3548 
3549       {ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3550       {ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3551       {ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3552       {ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3553       {ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3554       {ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3555   };
3556 
3557   // We have to estimate the cost of a fixed-length operation on SVE
3558   // registers by the number of SVE registers required to represent the
3559   // fixed-length type.
3560   EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
3561   if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
3562       SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
3563       ST->useSVEForFixedLengthVectors(WiderTy)) {
3564     std::pair<InstructionCost, MVT> LT =
3565         getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
3566     unsigned NumElements =
3567         AArch64::SVEBitsPerBlock / LT.second.getScalarSizeInBits();
3568     return AdjustCost(
3569         LT.first *
3570         getCastInstrCost(
3571             Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
3572             ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
3573             CostKind, I));
3574   }
3575 
3576   if (const auto *Entry = ConvertCostTableLookup(
3577           ConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3578     return AdjustCost(Entry->Cost);
3579 
3580   static const TypeConversionCostTblEntry FP16Tbl[] = {
3581       {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
3582       {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
3583       {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
3584       {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
3585       {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
3586       {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
3587       {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
3588       {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
3589       {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
3590       {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
3591       {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
3592       {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
3593       {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
3594       {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
3595       {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
3596       {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
3597       {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
3598       {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
3599       {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // ushll + ucvtf
3600       {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2},   // sshll + scvtf
3601       {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
3602       {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
3603   };
3604 
3605   if (ST->hasFullFP16())
3606     if (const auto *Entry = ConvertCostTableLookup(
3607             FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3608       return AdjustCost(Entry->Cost);
3609 
3610   // INT_TO_FP of i64->f32 will scalarize, which is required to avoid
3611   // double-rounding issues.
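  // For example, converting <2 x i64> to <2 x float> by first converting to
  // <2 x double> and then rounding to f32 could round twice; instead each
  // element is converted directly with a scalar scvtf/ucvtf.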
3612   if ((ISD == ISD::SINT_TO_FP || ISD == ISD::UINT_TO_FP) &&
3613       DstTy.getScalarType() == MVT::f32 && SrcTy.getScalarSizeInBits() > 32 &&
3614       isa<FixedVectorType>(Dst) && isa<FixedVectorType>(Src))
3615     return AdjustCost(
3616         cast<FixedVectorType>(Dst)->getNumElements() *
3617             getCastInstrCost(Opcode, Dst->getScalarType(), Src->getScalarType(),
3618                              CCH, CostKind) +
3619         BaseT::getScalarizationOverhead(cast<FixedVectorType>(Src), false, true,
3620                                         CostKind) +
3621         BaseT::getScalarizationOverhead(cast<FixedVectorType>(Dst), true, false,
3622                                         CostKind));
3623 
3624   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3625       CCH == TTI::CastContextHint::Masked &&
3626       ST->isSVEorStreamingSVEAvailable() &&
3627       TLI->getTypeAction(Src->getContext(), SrcTy) ==
3628           TargetLowering::TypePromoteInteger &&
3629       TLI->getTypeAction(Dst->getContext(), DstTy) ==
3630           TargetLowering::TypeSplitVector) {
3631     // The standard behaviour in the backend for these cases is to split the
3632     // extend up into two parts:
3633     //  1. Perform an extending load or masked load up to the legal type.
3634     //  2. Extend the loaded data to the final type.
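    // For instance, a zero-extending masked load from nxv4i8 to nxv4i64 would
    // first be costed as an extending load to the promoted legal type
    // (nxv4i32) and then as an extend from nxv4i32 to nxv4i64.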
3635     std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
3636     Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext());
3637     InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost(
3638         Opcode, LegalTy, Src, CCH, CostKind, I);
3639     InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost(
3640         Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I);
3641     return Part1 + Part2;
3642   }
3643 
3644   // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
3645   // but we also want to include the TTI::CastContextHint::Masked case too.
3646   if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3647       CCH == TTI::CastContextHint::Masked &&
3648       ST->isSVEorStreamingSVEAvailable() && TLI->isTypeLegal(DstTy))
3649     CCH = TTI::CastContextHint::Normal;
3650 
3651   return AdjustCost(
3652       BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
3653 }
3654 
3655 InstructionCost
3656 AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
3657                                          VectorType *VecTy, unsigned Index,
3658                                          TTI::TargetCostKind CostKind) const {
3659 
3660   // Make sure we were given a valid extend opcode.
3661   assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
3662          "Invalid opcode");
3663 
3664   // We are extending an element we extract from a vector, so the source type
3665   // of the extend is the element type of the vector.
3666   auto *Src = VecTy->getElementType();
3667 
3668   // Sign- and zero-extends are for integer types only.
3669   assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
3670 
3671   // Get the cost for the extract. We compute the cost (if any) for the extend
3672   // below.
3673   InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
3674                                             CostKind, Index, nullptr, nullptr);
3675 
3676   // Legalize the types.
3677   auto VecLT = getTypeLegalizationCost(VecTy);
3678   auto DstVT = TLI->getValueType(DL, Dst);
3679   auto SrcVT = TLI->getValueType(DL, Src);
3680 
3681   // If the resulting type is still a vector and the destination type is legal,
3682   // we may get the extension for free. If not, get the default cost for the
3683   // extend.
3684   if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
3685     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3686                                    CostKind);
3687 
3688   // The destination type should be larger than the element type. If not, get
3689   // the default cost for the extend.
3690   if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
3691     return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3692                                    CostKind);
3693 
3694   switch (Opcode) {
3695   default:
3696     llvm_unreachable("Opcode should be either SExt or ZExt");
3697 
3698   // For sign-extends, we only need a smov, which performs the extension
3699   // automatically.
3700   case Instruction::SExt:
3701     return Cost;
3702 
3703   // For zero-extends, the extend is performed automatically by a umov unless
3704   // the destination type is i64 and the element type is i8 or i16.
3705   case Instruction::ZExt:
3706     if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
3707       return Cost;
3708   }
3709 
3710   // If we are unable to perform the extend for free, get the default cost.
3711   return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3712                                  CostKind);
3713 }
3714 
3715 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
3716                                                TTI::TargetCostKind CostKind,
3717                                                const Instruction *I) const {
3718   if (CostKind != TTI::TCK_RecipThroughput)
3719     return Opcode == Instruction::PHI ? 0 : 1;
3720   assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
3721   // Branches are assumed to be predicted.
3722   return 0;
3723 }
3724 
3725 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3726     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3727     const Instruction *I, Value *Scalar,
3728     ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
3729   assert(Val->isVectorTy() && "This must be a vector type");
3730 
3731   if (Index != -1U) {
3732     // Legalize the type.
3733     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
3734 
3735     // This type is legalized to a scalar type.
3736     if (!LT.second.isVector())
3737       return 0;
3738 
3739     // The type may be split. For fixed-width vectors we can normalize the
3740     // index to the new type.
3741     if (LT.second.isFixedLengthVector()) {
3742       unsigned Width = LT.second.getVectorNumElements();
3743       Index = Index % Width;
3744     }
3745 
3746     // The element at index zero is already inside the vector.
3747     // - For an insert-element or extract-element
3748     //   instruction that operates on integers, an explicit move between FPR
3749     //   and GPR is needed, so it has a non-zero cost.
3750     if (Index == 0 && !Val->getScalarType()->isIntegerTy())
3751       return 0;
3752 
3753     // This recognises an LD1 (single-element structure to one lane of one
3754     // register) instruction. I.e., if this is an `insertelement` instruction
3755     // and its second operand is a load, then we will generate an LD1, which
3756     // is an expensive instruction.
3757     if (I && dyn_cast<LoadInst>(I->getOperand(1)))
3758       return CostKind == TTI::TCK_CodeSize
3759                  ? 0
3760                  : ST->getVectorInsertExtractBaseCost() + 1;
3761 
3762     // i1 inserts and extracts will include an extra cset or cmp of the vector
3763     // value. Increase the cost by 1 to account for this.
3764     if (Val->getScalarSizeInBits() == 1)
3765       return CostKind == TTI::TCK_CodeSize
3766                  ? 2
3767                  : ST->getVectorInsertExtractBaseCost() + 1;
3768 
3769     // FIXME:
3770     // If the extract-element and insert-element instructions could be
3771     // simplified away (e.g., could be combined into users by looking at use-def
3772     // context), they have no cost. This is not done in the first place for
3773     // compile-time considerations.
3774   }
3775 
3776   // In the case of Neon, if there exists an extractelement from lane != 0 such that
3777   // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3778   // 2. extractelement result feeds into fmul.
3779   // 3. Other operand of fmul is an extractelement from lane 0 or lane
3780   // equivalent to 0.
3781   // then the extractelement can be merged with fmul in the backend and it
3782   // incurs no cost.
3783   // e.g.
3784   // define double @foo(<2 x double> %a) {
3785   //   %1 = extractelement <2 x double> %a, i32 0
3786   //   %2 = extractelement <2 x double> %a, i32 1
3787   //   %res = fmul double %1, %2
3788   //   ret double %res
3789   // }
3790   // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3791   auto ExtractCanFuseWithFmul = [&]() {
3792     // We bail out if the extract is from lane 0.
3793     if (Index == 0)
3794       return false;
3795 
3796     // Check if the scalar element type of the vector operand of ExtractElement
3797     // instruction is one of the allowed types.
3798     auto IsAllowedScalarTy = [&](const Type *T) {
3799       return T->isFloatTy() || T->isDoubleTy() ||
3800              (T->isHalfTy() && ST->hasFullFP16());
3801     };
3802 
3803     // Check if the extractelement user is scalar fmul.
3804     auto IsUserFMulScalarTy = [](const Value *EEUser) {
3805       // Check if the user is scalar fmul.
3806       const auto *BO = dyn_cast<BinaryOperator>(EEUser);
3807       return BO && BO->getOpcode() == BinaryOperator::FMul &&
3808              !BO->getType()->isVectorTy();
3809     };
3810 
3811     // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3812     // certain scalar type and a certain vector register width.
3813     auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
3814       auto RegWidth =
3815           getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3816               .getFixedValue();
3817       return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0);
3818     };
3819 
3820     // Check if the type constraints on input vector type and result scalar type
3821     // of extractelement instruction are satisfied.
3822     if (!isa<FixedVectorType>(Val) || !IsAllowedScalarTy(Val->getScalarType()))
3823       return false;
3824 
3825     if (Scalar) {
3826       DenseMap<User *, unsigned> UserToExtractIdx;
3827       for (auto *U : Scalar->users()) {
3828         if (!IsUserFMulScalarTy(U))
3829           return false;
3830         // Recording an entry for the user is important; the index value is
3831         // not important.
3832         UserToExtractIdx[U];
3833       }
3834       if (UserToExtractIdx.empty())
3835         return false;
3836       for (auto &[S, U, L] : ScalarUserAndIdx) {
3837         for (auto *U : S->users()) {
3838           if (UserToExtractIdx.contains(U)) {
3839             auto *FMul = cast<BinaryOperator>(U);
3840             auto *Op0 = FMul->getOperand(0);
3841             auto *Op1 = FMul->getOperand(1);
3842             if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
3843               UserToExtractIdx[U] = L;
3844               break;
3845             }
3846           }
3847         }
3848       }
3849       for (auto &[U, L] : UserToExtractIdx) {
3850         if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
3851             !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
3852           return false;
3853       }
3854     } else {
3855       const auto *EE = cast<ExtractElementInst>(I);
3856 
3857       const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3858       if (!IdxOp)
3859         return false;
3860 
3861       return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3862         if (!IsUserFMulScalarTy(U))
3863           return false;
3864 
3865         // Check if the other operand of the fmul is also an extractelement
3866         // from a lane equivalent to 0.
3867         const auto *BO = cast<BinaryOperator>(U);
3868         const auto *OtherEE = dyn_cast<ExtractElementInst>(
3869             BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3870         if (OtherEE) {
3871           const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3872           if (!IdxOp)
3873             return false;
3874           return IsExtractLaneEquivalentToZero(
3875               cast<ConstantInt>(OtherEE->getIndexOperand())
3876                   ->getValue()
3877                   .getZExtValue(),
3878               OtherEE->getType()->getScalarSizeInBits());
3879         }
3880         return true;
3881       });
3882     }
3883     return true;
3884   };
3885 
3886   if (Opcode == Instruction::ExtractElement && (I || Scalar) &&
3887       ExtractCanFuseWithFmul())
3888     return 0;
3889 
3890   // All other insert/extracts cost this much.
3891   return CostKind == TTI::TCK_CodeSize ? 1
3892                                        : ST->getVectorInsertExtractBaseCost();
3893 }
3894 
3895 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
3896                                                    TTI::TargetCostKind CostKind,
3897                                                    unsigned Index,
3898                                                    const Value *Op0,
3899                                                    const Value *Op1) const {
3900   // Treat an insert at lane 0 into a poison vector as having zero cost. This
3901   // ensures vector broadcasts via an insert + shuffle (which will be lowered to
3902   // a single dup) are treated as cheap.
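  // e.g. %ins   = insertelement <4 x float> poison, float %s, i64 0
  //      %splat = shufflevector <4 x float> %ins, <4 x float> poison,
  //                             <4 x i32> zeroinitializer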
3903   if (Opcode == Instruction::InsertElement && Index == 0 && Op0 &&
3904       isa<PoisonValue>(Op0))
3905     return 0;
3906   return getVectorInstrCostHelper(Opcode, Val, CostKind, Index);
3907 }
3908 
3909 InstructionCost AArch64TTIImpl::getVectorInstrCost(
3910     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3911     Value *Scalar,
3912     ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
3913   return getVectorInstrCostHelper(Opcode, Val, CostKind, Index, nullptr, Scalar,
3914                                   ScalarUserAndIdx);
3915 }
3916 
3917 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
3918                                                    Type *Val,
3919                                                    TTI::TargetCostKind CostKind,
3920                                                    unsigned Index) const {
3921   return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I);
3922 }
3923 
3924 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
3925     VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
3926     TTI::TargetCostKind CostKind, bool ForPoisonSrc,
3927     ArrayRef<Value *> VL) const {
3928   if (isa<ScalableVectorType>(Ty))
3929     return InstructionCost::getInvalid();
3930   if (Ty->getElementType()->isFloatingPointTy())
3931     return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
3932                                            CostKind);
3933   unsigned VecInstCost =
3934       CostKind == TTI::TCK_CodeSize ? 1 : ST->getVectorInsertExtractBaseCost();
3935   return DemandedElts.popcount() * (Insert + Extract) * VecInstCost;
3936 }
3937 
3938 InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
3939     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
3940     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
3941     ArrayRef<const Value *> Args, const Instruction *CxtI) const {
3942 
3943   // The code-generator is currently not able to handle scalable vectors
3944   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3945   // it. This change will be removed when code-generation for these types is
3946   // sufficiently reliable.
3947   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3948     if (VTy->getElementCount() == ElementCount::getScalable(1))
3949       return InstructionCost::getInvalid();
3950 
3951   // TODO: Handle more cost kinds.
3952   if (CostKind != TTI::TCK_RecipThroughput)
3953     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3954                                          Op2Info, Args, CxtI);
3955 
3956   // Legalize the type.
3957   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
3958   int ISD = TLI->InstructionOpcodeToISD(Opcode);
3959 
3960   switch (ISD) {
3961   default:
3962     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
3963                                          Op2Info);
3964   case ISD::SREM:
3965   case ISD::SDIV:
3966     /*
3967     Notes for sdiv/srem specific costs:
3968     1. This only considers the cases where the divisor is constant, uniform and
3969     (pow-of-2/non-pow-of-2). Other cases are not important since they either
3970     result in some form of (ldr + adrp), corresponding to constant vectors, or
3971     scalarization of the division operation.
3972     2. Constant divisors, either negative in whole or partially, don't result in
3973     significantly different codegen as compared to positive constant divisors.
3974     So, we don't consider negative divisors separately.
3975     3. If the codegen is significantly different with SVE, it has been indicated
3976     using comments at appropriate places.
3977 
3978     sdiv specific cases:
3979     -----------------------------------------------------------------------
3980     codegen                       | pow-of-2               | Type
3981     -----------------------------------------------------------------------
3982     add + cmp + csel + asr        | Y                      | i64
3983     add + cmp + csel + asr        | Y                      | i32
3984     -----------------------------------------------------------------------
3985 
3986     srem specific cases:
3987     -----------------------------------------------------------------------
3988     codegen                       | pow-of-2               | Type
3989     -----------------------------------------------------------------------
3990     negs + and + and + csneg      | Y                      | i64
3991     negs + and + and + csneg      | Y                      | i32
3992     -----------------------------------------------------------------------
3993 
3994     other sdiv/srem cases:
3995     -------------------------------------------------------------------------
3996     common codegen            | + srem     | + sdiv     | pow-of-2  | Type
3997     -------------------------------------------------------------------------
3998     smulh + asr + add + add   | -          | -          | N         | i64
3999     smull + lsr + add + add   | -          | -          | N         | i32
4000     usra                      | and + sub  | sshr       | Y         | <2 x i64>
4001     2 * (scalar code)         | -          | -          | N         | <2 x i64>
4002     usra                      | bic + sub  | sshr + neg | Y         | <4 x i32>
4003     smull2 + smull + uzp2     | mls        | -          | N         | <4 x i32>
4004            + sshr  + usra     |            |            |           |
4005     -------------------------------------------------------------------------
4006     */
4007     if (Op2Info.isConstant() && Op2Info.isUniform()) {
4008       InstructionCost AddCost =
4009           getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
4010                                  Op1Info.getNoProps(), Op2Info.getNoProps());
4011       InstructionCost AsrCost =
4012           getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
4013                                  Op1Info.getNoProps(), Op2Info.getNoProps());
4014       InstructionCost MulCost =
4015           getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
4016                                  Op1Info.getNoProps(), Op2Info.getNoProps());
4017       // add/cmp/csel/csneg should have similar cost while asr/negs/and should
4018       // have similar cost.
4019       auto VT = TLI->getValueType(DL, Ty);
4020       if (VT.isScalarInteger() && VT.getSizeInBits() <= 64) {
4021         if (Op2Info.isPowerOf2() || Op2Info.isNegatedPowerOf2()) {
4022           // Neg can be folded into the asr instruction.
4023           return ISD == ISD::SDIV ? (3 * AddCost + AsrCost)
4024                                   : (3 * AsrCost + AddCost);
4025         } else {
4026           return MulCost + AsrCost + 2 * AddCost;
4027         }
4028       } else if (VT.isVector()) {
4029         InstructionCost UsraCost = 2 * AsrCost;
4030         if (Op2Info.isPowerOf2() || Op2Info.isNegatedPowerOf2()) {
4031           // Division with scalable types corresponds to native 'asrd'
4032           // instruction when SVE is available.
4033           // e.g. %1 = sdiv <vscale x 4 x i32> %a, splat (i32 8)
4034 
4035           // One more for the negation in SDIV
4036           InstructionCost Cost =
4037               (Op2Info.isNegatedPowerOf2() && ISD == ISD::SDIV) ? AsrCost : 0;
4038           if (Ty->isScalableTy() && ST->hasSVE())
4039             Cost += 2 * AsrCost;
4040           else {
4041             Cost +=
4042                 UsraCost +
4043                 (ISD == ISD::SDIV
4044                      ? (LT.second.getScalarType() == MVT::i64 ? 1 : 2) * AsrCost
4045                      : 2 * AddCost);
4046           }
4047           return Cost;
4048         } else if (LT.second == MVT::v2i64) {
4049           return VT.getVectorNumElements() *
4050                  getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind,
4051                                         Op1Info.getNoProps(),
4052                                         Op2Info.getNoProps());
4053         } else {
4054           // When SVE is available, we get:
4055           // smulh + lsr + add/sub + asr + add/sub.
4056           if (Ty->isScalableTy() && ST->hasSVE())
4057             return MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
4058           return 2 * MulCost + AddCost /*uzp2 cost*/ + AsrCost + UsraCost;
4059         }
4060       }
4061     }
4062     if (Op2Info.isConstant() && !Op2Info.isUniform() &&
4063         LT.second.isFixedLengthVector()) {
4064       // FIXME: When the constant vector is non-uniform, this may result in
4065       // loading the vector from constant pool or in some cases, may also result
4066       // in scalarization. For now, we are approximating this with the
4067       // scalarization cost.
4068       auto ExtractCost = 2 * getVectorInstrCost(Instruction::ExtractElement, Ty,
4069                                                 CostKind, -1, nullptr, nullptr);
4070       auto InsertCost = getVectorInstrCost(Instruction::InsertElement, Ty,
4071                                            CostKind, -1, nullptr, nullptr);
4072       unsigned NElts = cast<FixedVectorType>(Ty)->getNumElements();
4073       return ExtractCost + InsertCost +
4074              NElts * getArithmeticInstrCost(Opcode, Ty->getScalarType(),
4075                                             CostKind, Op1Info.getNoProps(),
4076                                             Op2Info.getNoProps());
4077     }
4078     [[fallthrough]];
4079   case ISD::UDIV:
4080   case ISD::UREM: {
4081     auto VT = TLI->getValueType(DL, Ty);
4082     if (Op2Info.isConstant()) {
4083       // If the operand is a power of 2 we can use the shift or and cost.
4084       if (ISD == ISD::UDIV && Op2Info.isPowerOf2())
4085         return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
4086                                       Op1Info.getNoProps(),
4087                                       Op2Info.getNoProps());
4088       if (ISD == ISD::UREM && Op2Info.isPowerOf2())
4089         return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
4090                                       Op1Info.getNoProps(),
4091                                       Op2Info.getNoProps());
4092 
4093       if (ISD == ISD::UDIV || ISD == ISD::UREM) {
4094         // Divides by a constant are expanded to MULHU + SUB + SRL + ADD + SRL.
4095         // The MULHU will be expanded to UMULL for the types not listed below,
4096         // and will become a pair of UMULL+UMULL2 for 128-bit vectors.
4097         bool HasMULH = VT == MVT::i64 || LT.second == MVT::nxv2i64 ||
4098                        LT.second == MVT::nxv4i32 || LT.second == MVT::nxv8i16 ||
4099                        LT.second == MVT::nxv16i8;
4100         bool Is128bit = LT.second.is128BitVector();
4101 
4102         InstructionCost MulCost =
4103             getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
4104                                    Op1Info.getNoProps(), Op2Info.getNoProps());
4105         InstructionCost AddCost =
4106             getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
4107                                    Op1Info.getNoProps(), Op2Info.getNoProps());
4108         InstructionCost ShrCost =
4109             getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
4110                                    Op1Info.getNoProps(), Op2Info.getNoProps());
4111         InstructionCost DivCost = MulCost * (Is128bit ? 2 : 1) + // UMULL/UMULH
4112                                   (HasMULH ? 0 : ShrCost) +      // UMULL shift
4113                                   AddCost * 2 + ShrCost;
4114         return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
4115       }
4116     }
4117 
4118     // div i128's are lowered as libcalls.  Pass nullptr as (u)divti3 calls are
4119     // emitted by the backend even when those functions are not declared in the
4120     // module.
4121     if (!VT.isVector() && VT.getSizeInBits() > 64)
4122       return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);
4123 
4124     InstructionCost Cost = BaseT::getArithmeticInstrCost(
4125         Opcode, Ty, CostKind, Op1Info, Op2Info);
4126     if (Ty->isVectorTy() && (ISD == ISD::SDIV || ISD == ISD::UDIV)) {
4127       if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
4128         // SDIV/UDIV operations can be lowered using SVE, in which case the
4129         // cost is lower.
4130         if (VT.isSimple() && isa<FixedVectorType>(Ty) &&
4131             Ty->getPrimitiveSizeInBits().getFixedValue() < 128) {
4132           static const CostTblEntry DivTbl[]{
4133               {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
4134               {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
4135               {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
4136               {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
4137               {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
4138               {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};
4139 
4140           const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
4141           if (nullptr != Entry)
4142             return Entry->Cost;
4143         }
4144         // For 8/16-bit elements, the cost is higher because the type
4145         // requires promotion and possibly splitting:
4146         if (LT.second.getScalarType() == MVT::i8)
4147           Cost *= 8;
4148         else if (LT.second.getScalarType() == MVT::i16)
4149           Cost *= 4;
4150         return Cost;
4151       } else {
4152         // If one of the operands is a uniform constant then the cost for each
4153         // element is the cost of insertion, extraction and division.
4154         // Insertion cost = 2, extraction cost = 2, division = cost of the
4155         // operation on the scalar type.
4156         if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
4157             (Op2Info.isConstant() && Op2Info.isUniform())) {
4158           if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
4159             InstructionCost DivCost = BaseT::getArithmeticInstrCost(
4160                 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
4161             return (4 + DivCost) * VTy->getNumElements();
4162           }
4163         }
4164         // On AArch64, without SVE, vector divisions are expanded
4165         // into scalar divisions of each pair of elements.
4166         Cost += getVectorInstrCost(Instruction::ExtractElement, Ty, CostKind,
4167                                    -1, nullptr, nullptr);
4168         Cost += getVectorInstrCost(Instruction::InsertElement, Ty, CostKind, -1,
4169                                    nullptr, nullptr);
4170       }
4171 
4172       // TODO: if one of the arguments is scalar, then it's not necessary to
4173       // double the cost of handling the vector elements.
4174       Cost += Cost;
4175     }
4176     return Cost;
4177   }
4178   case ISD::MUL:
4179     // When SVE is available, we can lower the v2i64 operation using
4180     // the SVE mul instruction, which has a lower cost.
4181     if (LT.second == MVT::v2i64 && ST->hasSVE())
4182       return LT.first;
4183 
4184     // When SVE is not available, there is no MUL.2d instruction,
4185     // which means mul <2 x i64> is expensive as elements are extracted
4186     // from the vectors and the muls scalarized.
4187     // As getScalarizationOverhead is a bit too pessimistic, we
4188     // estimate the cost for a i64 vector directly here, which is:
4189     // - four 2-cost i64 extracts,
4190     // - two 2-cost i64 inserts, and
4191     // - two 1-cost muls.
4192     // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
4193     // LT.first = 2 the cost is 28. If both operands are extensions it will not
4194     // need to scalarize so the cost can be cheaper (smull or umull).
4196     if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
4197       return LT.first;
4198     return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
4199            (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
4200             getVectorInstrCost(Instruction::ExtractElement, Ty, CostKind, -1,
4201                                nullptr, nullptr) *
4202                 2 +
4203             getVectorInstrCost(Instruction::InsertElement, Ty, CostKind, -1,
4204                                nullptr, nullptr));
4205   case ISD::ADD:
4206   case ISD::XOR:
4207   case ISD::OR:
4208   case ISD::AND:
4209   case ISD::SRL:
4210   case ISD::SRA:
4211   case ISD::SHL:
4212     // These nodes are marked as 'custom' for combining purposes only.
4213     // We know that they are legal. See LowerAdd in ISelLowering.
4214     return LT.first;
4215 
4216   case ISD::FNEG:
4217     // Scalar fmul(fneg) or fneg(fmul) can be converted to fnmul
4218     if ((Ty->isFloatTy() || Ty->isDoubleTy() ||
4219          (Ty->isHalfTy() && ST->hasFullFP16())) &&
4220         CxtI &&
4221         ((CxtI->hasOneUse() &&
4222           match(*CxtI->user_begin(), m_FMul(m_Value(), m_Value()))) ||
4223          match(CxtI->getOperand(0), m_FMul(m_Value(), m_Value()))))
4224       return 0;
4225     [[fallthrough]];
4226   case ISD::FADD:
4227   case ISD::FSUB:
4228     // Increase the cost for half and bfloat types if not architecturally
4229     // supported.
4230     if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
4231         (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
4232       return 2 * LT.first;
4233     if (!Ty->getScalarType()->isFP128Ty())
4234       return LT.first;
4235     [[fallthrough]];
4236   case ISD::FMUL:
4237   case ISD::FDIV:
4238     // These nodes are marked as 'custom' just to lower them to SVE.
4239     // We know said lowering will incur no additional cost.
4240     if (!Ty->getScalarType()->isFP128Ty())
4241       return 2 * LT.first;
4242 
4243     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
4244                                          Op2Info);
4245   case ISD::FREM:
4246     // Pass nullptr as fmod/fmodf calls are emitted by the backend even when
4247     // those functions are not declared in the module.
4248     if (!Ty->isVectorTy())
4249       return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);
4250     return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
4251                                          Op2Info);
4252   }
4253 }
4254 
4255 InstructionCost
4256 AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
4257                                           const SCEV *Ptr) const {
4258   // Address computations in vectorized code with non-consecutive addresses will
4259   // likely result in more instructions compared to scalar code where the
4260   // computation can more often be merged into the index mode. The resulting
4261   // extra micro-ops can significantly decrease throughput.
4262   unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
4263   int MaxMergeDistance = 64;
4264 
4265   if (Ty->isVectorTy() && SE &&
4266       !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
4267     return NumVectorInstToHideOverhead;
4268 
4269   // In many cases the address computation is not merged into the instruction
4270   // addressing mode.
4271   return 1;
4272 }
4273 
4274 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
4275     unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
4276     TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
4277     TTI::OperandValueInfo Op2Info, const Instruction *I) const {
4278   int ISD = TLI->InstructionOpcodeToISD(Opcode);
4279   // We don't lower some vector selects well when they are wider than the
4280   // register width. TODO: Improve this with different cost kinds.
4281   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
4282     // We would need this many instructions to hide the scalarization happening.
4283     const int AmortizationCost = 20;
4284 
4285     // If VecPred is not set, check if we can get a predicate from the context
4286     // instruction, if its type matches the requested ValTy.
4287     if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
4288       CmpPredicate CurrentPred;
4289       if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
4290                             m_Value())))
4291         VecPred = CurrentPred;
4292     }
4293     // Check if we have a compare/select chain that can be lowered using
4294     // a (F)CMxx & BFI pair.
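    // e.g. %c = fcmp ogt <4 x float> %a, %b
    //      %s = select <4 x i1> %c, <4 x float> %x, <4 x float> %y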
4295     if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
4296         VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
4297         VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
4298         VecPred == CmpInst::FCMP_UNE) {
4299       static const auto ValidMinMaxTys = {
4300           MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
4301           MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
4302       static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
4303 
4304       auto LT = getTypeLegalizationCost(ValTy);
4305       if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
4306           (ST->hasFullFP16() &&
4307            any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
4308         return LT.first;
4309     }
4310 
4311     static const TypeConversionCostTblEntry
4312     VectorSelectTbl[] = {
4313       { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
4314       { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
4315       { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
4316       { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
4317       { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
4318       { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
4319       { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
4320       { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
4321       { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
4322       { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
4323       { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
4324     };
4325 
4326     EVT SelCondTy = TLI->getValueType(DL, CondTy);
4327     EVT SelValTy = TLI->getValueType(DL, ValTy);
4328     if (SelCondTy.isSimple() && SelValTy.isSimple()) {
4329       if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
4330                                                      SelCondTy.getSimpleVT(),
4331                                                      SelValTy.getSimpleVT()))
4332         return Entry->Cost;
4333     }
4334   }
4335 
4336   if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
4337     Type *ValScalarTy = ValTy->getScalarType();
4338     if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
4339         ValScalarTy->isBFloatTy()) {
4340       auto *ValVTy = cast<FixedVectorType>(ValTy);
4341 
4342       // Without dedicated instructions we promote [b]f16 compares to f32.
4343       auto *PromotedTy =
4344           VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
4345 
4346       InstructionCost Cost = 0;
4347       // Promote operands to float vectors.
4348       Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
4349                                    TTI::CastContextHint::None, CostKind);
4350       // Compare float vectors.
4351       Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
4352                                  Op1Info, Op2Info);
4353       // During codegen we'll truncate the vector result from i32 to i16.
4354       Cost +=
4355           getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
4356                            VectorType::getInteger(PromotedTy),
4357                            TTI::CastContextHint::None, CostKind);
4358       return Cost;
4359     }
4360   }
4361 
4362   // Treat the icmp in icmp(and, 0) or icmp(and, -1/1) when it can be folded to
4363   // icmp(and, 0) as free, as we can make use of ands, but only if the
4364   // comparison is not unsigned. FIXME: Enable for non-throughput cost kinds
4365   // providing it will not cause performance regressions.
4366   if (CostKind == TTI::TCK_RecipThroughput && ValTy->isIntegerTy() &&
4367       ISD == ISD::SETCC && I && !CmpInst::isUnsigned(VecPred) &&
4368       TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
4369       match(I->getOperand(0), m_And(m_Value(), m_Value()))) {
4370     if (match(I->getOperand(1), m_Zero()))
4371       return 0;
4372 
4373     // x >= 1 / x < 1 -> x > 0 / x <= 0
4374     if (match(I->getOperand(1), m_One()) &&
4375         (VecPred == CmpInst::ICMP_SLT || VecPred == CmpInst::ICMP_SGE))
4376       return 0;
4377 
4378     // x <= -1 / x > -1 -> x > 0 / x <= 0
4379     if (match(I->getOperand(1), m_AllOnes()) &&
4380         (VecPred == CmpInst::ICMP_SLE || VecPred == CmpInst::ICMP_SGT))
4381       return 0;
4382   }
4383 
4384   // The base case handles scalable vectors fine for now, since it treats the
4385   // cost as 1 * legalization cost.
4386   return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
4387                                    Op1Info, Op2Info, I);
4388 }
4389 
4390 AArch64TTIImpl::TTI::MemCmpExpansionOptions
4391 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
4392   TTI::MemCmpExpansionOptions Options;
4393   if (ST->requiresStrictAlign()) {
4394     // TODO: Add cost modeling for strict align. Misaligned loads expand to
4395     // a bunch of instructions when strict align is enabled.
4396     return Options;
4397   }
4398   Options.AllowOverlappingLoads = true;
4399   Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
4400   Options.NumLoadsPerBlock = Options.MaxNumLoads;
4401   // TODO: Though vector loads usually perform well on AArch64, in some targets
4402   // they may wake up the FP unit, which raises the power consumption.  Perhaps
4403   // they could be used with no holds barred (-O3).
4404   Options.LoadSizes = {8, 4, 2, 1};
4405   Options.AllowedTailExpansions = {3, 5, 6};
4406   return Options;
4407 }
4408 
4409 bool AArch64TTIImpl::prefersVectorizedAddressing() const {
4410   return ST->hasSVE();
4411 }
4412 
4413 InstructionCost
4414 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
4415                                       Align Alignment, unsigned AddressSpace,
4416                                       TTI::TargetCostKind CostKind) const {
4417   if (useNeonVector(Src))
4418     return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
4419                                         CostKind);
4420   auto LT = getTypeLegalizationCost(Src);
4421   if (!LT.first.isValid())
4422     return InstructionCost::getInvalid();
4423 
4424   // Return an invalid cost for element types that we are unable to lower.
4425   auto *VT = cast<VectorType>(Src);
4426   if (VT->getElementType()->isIntegerTy(1))
4427     return InstructionCost::getInvalid();
4428 
4429   // The code-generator is currently not able to handle scalable vectors
4430   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4431   // it. This change will be removed when code-generation for these types is
4432   // sufficiently reliable.
4433   if (VT->getElementCount() == ElementCount::getScalable(1))
4434     return InstructionCost::getInvalid();
4435 
4436   return LT.first;
4437 }
4438 
4439 // This function returns gather/scatter overhead either from
4440 // a user-provided value or from per-target specialized values in \p ST.
4441 static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
4442                                             const AArch64Subtarget *ST) {
4443   assert((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
4444          "Should be called on only load or stores.");
4445   switch (Opcode) {
4446   case Instruction::Load:
4447     if (SVEGatherOverhead.getNumOccurrences() > 0)
4448       return SVEGatherOverhead;
4449     return ST->getGatherOverhead();
4450     break;
4451   case Instruction::Store:
4452     if (SVEScatterOverhead.getNumOccurrences() > 0)
4453       return SVEScatterOverhead;
4454     return ST->getScatterOverhead();
4455     break;
4456   default:
4457     llvm_unreachable("Shouldn't have reached here");
4458   }
4459 }
4460 
4461 InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
4462     unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
4463     Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
4464   if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
4465     return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
4466                                          Alignment, CostKind, I);
4467   auto *VT = cast<VectorType>(DataTy);
4468   auto LT = getTypeLegalizationCost(DataTy);
4469   if (!LT.first.isValid())
4470     return InstructionCost::getInvalid();
4471 
4472   // Return an invalid cost for element types that we are unable to lower.
4473   if (!LT.second.isVector() ||
4474       !isElementTypeLegalForScalableVector(VT->getElementType()) ||
4475       VT->getElementType()->isIntegerTy(1))
4476     return InstructionCost::getInvalid();
4477 
4478   // The code-generator is currently not able to handle scalable vectors
4479   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4480   // it. This change will be removed when code-generation for these types is
4481   // sufficiently reliable.
4482   if (VT->getElementCount() == ElementCount::getScalable(1))
4483     return InstructionCost::getInvalid();
4484 
4485   ElementCount LegalVF = LT.second.getVectorElementCount();
4486   InstructionCost MemOpCost =
4487       getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
4488                       {TTI::OK_AnyValue, TTI::OP_None}, I);
4489   // Add on an overhead cost for using gathers/scatters.
4490   MemOpCost *= getSVEGatherScatterOverhead(Opcode, ST);
4491   return LT.first * MemOpCost * getMaxNumElements(LegalVF);
4492 }
4493 
4494 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
4495   return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
4496 }
4497 
4498 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
4499                                                 Align Alignment,
4500                                                 unsigned AddressSpace,
4501                                                 TTI::TargetCostKind CostKind,
4502                                                 TTI::OperandValueInfo OpInfo,
4503                                                 const Instruction *I) const {
4504   EVT VT = TLI->getValueType(DL, Ty, true);
4505   // Type legalization can't handle structs
4506   if (VT == MVT::Other)
4507     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
4508                                   CostKind);
4509 
4510   auto LT = getTypeLegalizationCost(Ty);
4511   if (!LT.first.isValid())
4512     return InstructionCost::getInvalid();
4513 
4514   // The code-generator is currently not able to handle scalable vectors
4515   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4516   // it. This change will be removed when code-generation for these types is
4517   // sufficiently reliable.
4518   // We also only support full register predicate loads and stores.
4519   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
4520     if (VTy->getElementCount() == ElementCount::getScalable(1) ||
4521         (VTy->getElementType()->isIntegerTy(1) &&
4522          !VTy->getElementCount().isKnownMultipleOf(
4523              ElementCount::getScalable(16))))
4524       return InstructionCost::getInvalid();
4525 
4526   // TODO: consider latency as well for TCK_SizeAndLatency.
4527   if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
4528     return LT.first;
4529 
4530   if (CostKind != TTI::TCK_RecipThroughput)
4531     return 1;
4532 
4533   if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
4534       LT.second.is128BitVector() && Alignment < Align(16)) {
4535     // Unaligned stores are extremely inefficient. We don't split all
4536     // unaligned 128-bit stores because of the negative impact that has been
4537     // shown in practice on inlined block copy code.
4538     // We make such stores expensive so that we will only vectorize if there
4539     // are 6 other instructions getting vectorized.
4540     const int AmortizationCost = 6;
4541 
4542     return LT.first * 2 * AmortizationCost;
4543   }
4544 
4545   // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
4546   if (Ty->isPtrOrPtrVectorTy())
4547     return LT.first;
4548 
4549   if (useNeonVector(Ty)) {
4550     // Check truncating stores and extending loads.
4551     if (Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
4552       // v4i8 types are lowered to a scalar load/store and sshll/xtn.
4553       if (VT == MVT::v4i8)
4554         return 2;
4555       // Otherwise we need to scalarize.
4556       return cast<FixedVectorType>(Ty)->getNumElements() * 2;
4557     }
4558     EVT EltVT = VT.getVectorElementType();
4559     unsigned EltSize = EltVT.getScalarSizeInBits();
4560     if (!isPowerOf2_32(EltSize) || EltSize < 8 || EltSize > 64 ||
4561         VT.getVectorNumElements() >= (128 / EltSize) || Alignment != Align(1))
4562       return LT.first;
4563     // FIXME: v3i8 lowering currently is very inefficient, due to automatic
4564     // widening to v4i8, which produces suboptimal results.
4565     if (VT.getVectorNumElements() == 3 && EltVT == MVT::i8)
4566       return LT.first;
4567 
4568     // Check non-power-of-2 loads/stores for legal vector element types with
4569     // NEON. Non-power-of-2 memory ops will get broken down to a set of
4570     // operations on smaller power-of-2 ops, including ld1/st1.
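    // For example, a v7i8 access is split into v4i8 + v2i8 + v1i8 pieces,
    // giving a cost of 3.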
4571     LLVMContext &C = Ty->getContext();
4572     InstructionCost Cost(0);
4573     SmallVector<EVT> TypeWorklist;
4574     TypeWorklist.push_back(VT);
4575     while (!TypeWorklist.empty()) {
4576       EVT CurrVT = TypeWorklist.pop_back_val();
4577       unsigned CurrNumElements = CurrVT.getVectorNumElements();
4578       if (isPowerOf2_32(CurrNumElements)) {
4579         Cost += 1;
4580         continue;
4581       }
4582 
4583       unsigned PrevPow2 = NextPowerOf2(CurrNumElements) / 2;
4584       TypeWorklist.push_back(EVT::getVectorVT(C, EltVT, PrevPow2));
4585       TypeWorklist.push_back(
4586           EVT::getVectorVT(C, EltVT, CurrNumElements - PrevPow2));
4587     }
4588     return Cost;
4589   }
4590 
4591   return LT.first;
4592 }
4593 
4594 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
4595     unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
4596     Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
4597     bool UseMaskForCond, bool UseMaskForGaps) const {
4598   assert(Factor >= 2 && "Invalid interleave factor");
4599   auto *VecVTy = cast<VectorType>(VecTy);
4600 
4601   if (VecTy->isScalableTy() && !ST->hasSVE())
4602     return InstructionCost::getInvalid();
4603 
4604   // Scalable VFs will emit vector.[de]interleave intrinsics, and currently we
4605   // only have lowering for power-of-2 factors.
4606   // TODO: Add lowering for vector.[de]interleave3 intrinsics and support in
4607   // InterleavedAccessPass for ld3/st3
4608   if (VecTy->isScalableTy() && !isPowerOf2_32(Factor))
4609     return InstructionCost::getInvalid();
4610 
4611   // Vectorization for masked interleaved accesses is only enabled for scalable
4612   // VF.
4613   if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
4614     return InstructionCost::getInvalid();
4615 
4616   if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
4617     unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
4618     auto *SubVecTy =
4619         VectorType::get(VecVTy->getElementType(),
4620                         VecVTy->getElementCount().divideCoefficientBy(Factor));
4621 
4622     // ldN/stN only support legal vector types of size 64 or 128 in bits.
4623     // Accesses having vector types that are a multiple of 128 bits can be
4624     // matched to more than one ldN/stN instruction.
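    // For example, an interleaved access of <16 x i32> with Factor == 2 uses a
    // 256-bit <8 x i32> sub-vector, which maps to two ld2/st2 instructions,
    // giving a cost of 2 * 2 = 4.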
4625     bool UseScalable;
4626     if (MinElts % Factor == 0 &&
4627         TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
4628       return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
4629   }
4630 
4631   return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
4632                                            Alignment, AddressSpace, CostKind,
4633                                            UseMaskForCond, UseMaskForGaps);
4634 }
4635 
4636 InstructionCost
4637 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const {
4638   InstructionCost Cost = 0;
4639   TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
4640   for (auto *I : Tys) {
4641     if (!I->isVectorTy())
4642       continue;
4643     if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
4644         128)
4645       Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
4646               getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
4647   }
4648   return Cost;
4649 }
4650 
4651 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) const {
4652   return ST->getMaxInterleaveFactor();
4653 }
4654 
4655 // For Falkor, we want to avoid having too many strided loads in a loop since
4656 // that can exhaust the HW prefetcher resources.  We adjust the unroller
4657 // MaxCount preference below to attempt to ensure unrolling doesn't create too
4658 // many strided loads.
4659 static void
4660 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
4661                               TargetTransformInfo::UnrollingPreferences &UP) {
4662   enum { MaxStridedLoads = 7 };
4663   auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
4664     int StridedLoads = 0;
4665     // FIXME? We could make this more precise by looking at the CFG and
4666     // e.g. not counting loads in each side of an if-then-else diamond.
4667     for (const auto BB : L->blocks()) {
4668       for (auto &I : *BB) {
4669         LoadInst *LMemI = dyn_cast<LoadInst>(&I);
4670         if (!LMemI)
4671           continue;
4672 
4673         Value *PtrValue = LMemI->getPointerOperand();
4674         if (L->isLoopInvariant(PtrValue))
4675           continue;
4676 
4677         const SCEV *LSCEV = SE.getSCEV(PtrValue);
4678         const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
4679         if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
4680           continue;
4681 
4682         // FIXME? We could take pairing of unrolled load copies into account
4683         // by looking at the AddRec, but we would probably have to limit this
4684         // to loops with no stores or other memory optimization barriers.
4685         ++StridedLoads;
4686         // We've seen enough strided loads that seeing more won't make a
4687         // difference.
4688         if (StridedLoads > MaxStridedLoads / 2)
4689           return StridedLoads;
4690       }
4691     }
4692     return StridedLoads;
4693   };
4694 
4695   int StridedLoads = countStridedLoads(L, SE);
4696   LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
4697                     << " strided loads\n");
4698   // Pick the largest power of 2 unroll count that won't result in too many
4699   // strided loads.
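  // For example, with 3 strided loads detected this yields
  // 1 << Log2_32(7 / 3) == 2, and with a single strided load 1 << Log2_32(7) == 4.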
4700   if (StridedLoads) {
4701     UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
4702     LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
4703                       << UP.MaxCount << '\n');
4704   }
4705 }
4706 
4707 // This function returns true if the loop:
4708 //  1. Has a valid cost, and
4709 //  2. Has a cost within the supplied budget.
4710 // Otherwise it returns false.
4711 static bool isLoopSizeWithinBudget(Loop *L, const AArch64TTIImpl &TTI,
4712                                    InstructionCost Budget,
4713                                    unsigned *FinalSize) {
4714   // Estimate the size of the loop.
4715   InstructionCost LoopCost = 0;
4716 
4717   for (auto *BB : L->getBlocks()) {
4718     for (auto &I : *BB) {
4719       SmallVector<const Value *, 4> Operands(I.operand_values());
4720       InstructionCost Cost =
4721           TTI.getInstructionCost(&I, Operands, TTI::TCK_CodeSize);
4722       // This can happen with intrinsics that don't currently have a cost model
4723       // or for some operations that require SVE.
4724       if (!Cost.isValid())
4725         return false;
4726 
4727       LoopCost += Cost;
4728       if (LoopCost > Budget)
4729         return false;
4730     }
4731   }
4732 
4733   if (FinalSize)
4734     *FinalSize = LoopCost.getValue();
4735   return true;
4736 }
4737 
4738 static bool shouldUnrollMultiExitLoop(Loop *L, ScalarEvolution &SE,
4739                                       const AArch64TTIImpl &TTI) {
4740   // Only consider loops with unknown trip counts for which we can determine
4741   // a symbolic expression. Multi-exit loops with small known trip counts will
4742   // likely be unrolled anyway.
4743   const SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
4744   if (isa<SCEVConstant>(BTC) || isa<SCEVCouldNotCompute>(BTC))
4745     return false;
4746 
4747   // It might not be worth unrolling loops with low max trip counts. Restrict
4748   // this to max trip counts > 32 for now.
4749   unsigned MaxTC = SE.getSmallConstantMaxTripCount(L);
4750   if (MaxTC > 0 && MaxTC <= 32)
4751     return false;
4752 
4753   // Make sure the loop size is <= 5.
4754   if (!isLoopSizeWithinBudget(L, TTI, 5, nullptr))
4755     return false;
4756 
4757   // Small search loops with multiple exits can be highly beneficial to unroll.
4758   // We only care about loops with exactly two exiting blocks, although each
4759   // block could jump to the same exit block.
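  // For example, a search loop such as 'while (I != E && *I != X) ++I;' typically
  // has two exiting blocks: the bounds check and the element comparison.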
4760   ArrayRef<BasicBlock *> Blocks = L->getBlocks();
4761   if (Blocks.size() != 2)
4762     return false;
4763 
4764   if (any_of(Blocks, [](BasicBlock *BB) {
4765         return !isa<BranchInst>(BB->getTerminator());
4766       }))
4767     return false;
4768 
4769   return true;
4770 }
4771 
4772 /// For Apple CPUs, we want to runtime-unroll loops to make better use of the
4773 /// OOO engine's wide instruction window and various predictors.
4774 static void
4775 getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
4776                                  TargetTransformInfo::UnrollingPreferences &UP,
4777                                  const AArch64TTIImpl &TTI) {
4778   // Limit loops with structure that is highly likely to benefit from runtime
4779   // unrolling; that is, we exclude outer loops and loops with many blocks (i.e.
4780   // likely with complex control flow). Note that the heuristics here may be
4781   // overly conservative and we err on the side of avoiding runtime unrolling
4782   // rather than unroll excessively. They are all subject to further refinement.
4783   if (!L->isInnermost() || L->getNumBlocks() > 8)
4784     return;
4785 
4786   // Loops with multiple exits are handled by common code.
4787   if (!L->getExitBlock())
4788     return;
4789 
4790   const SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
4791   if (isa<SCEVConstant>(BTC) || isa<SCEVCouldNotCompute>(BTC) ||
4792       (SE.getSmallConstantMaxTripCount(L) > 0 &&
4793        SE.getSmallConstantMaxTripCount(L) <= 32))
4794     return;
4795 
4796   if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
4797     return;
4798 
4799   if (SE.getSymbolicMaxBackedgeTakenCount(L) != SE.getBackedgeTakenCount(L))
4800     return;
4801 
4802   // Limit to loops with trip counts that are cheap to expand.
4803   UP.SCEVExpansionBudget = 1;
4804 
4805   // Try to unroll small, single block loops, if they have load/store
4806   // dependencies, to expose more parallel memory access streams.
4807   BasicBlock *Header = L->getHeader();
4808   if (Header == L->getLoopLatch()) {
4809     // Estimate the size of the loop.
4810     unsigned Size;
4811     if (!isLoopSizeWithinBudget(L, TTI, 8, &Size))
4812       return;
4813 
4814     SmallPtrSet<Value *, 8> LoadedValues;
4815     SmallVector<StoreInst *> Stores;
4816     for (auto *BB : L->blocks()) {
4817       for (auto &I : *BB) {
4818         Value *Ptr = getLoadStorePointerOperand(&I);
4819         if (!Ptr)
4820           continue;
4821         const SCEV *PtrSCEV = SE.getSCEV(Ptr);
4822         if (SE.isLoopInvariant(PtrSCEV, L))
4823           continue;
4824         if (isa<LoadInst>(&I))
4825           LoadedValues.insert(&I);
4826         else
4827           Stores.push_back(cast<StoreInst>(&I));
4828       }
4829     }
4830 
4831     // Try to find an unroll count that maximizes the use of the instruction
4832     // window, i.e. trying to fetch as many instructions per cycle as possible.
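    // For example, with an estimated Size of 10 the candidate sizes are 10, 20,
    // 30 and 40 (50 would exceed the 48-instruction limit); 30 % 16 == 14 is the
    // largest remainder, so BestUC becomes 3.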
4833     unsigned MaxInstsPerLine = 16;
4834     unsigned UC = 1;
4835     unsigned BestUC = 1;
4836     unsigned SizeWithBestUC = BestUC * Size;
4837     while (UC <= 8) {
4838       unsigned SizeWithUC = UC * Size;
4839       if (SizeWithUC > 48)
4840         break;
4841       if ((SizeWithUC % MaxInstsPerLine) == 0 ||
4842           (SizeWithBestUC % MaxInstsPerLine) < (SizeWithUC % MaxInstsPerLine)) {
4843         BestUC = UC;
4844         SizeWithBestUC = BestUC * Size;
4845       }
4846       UC++;
4847     }
4848 
4849     if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
4850           return LoadedValues.contains(SI->getOperand(0));
4851         }))
4852       return;
4853 
4854     UP.Runtime = true;
4855     UP.DefaultUnrollRuntimeCount = BestUC;
4856     return;
4857   }
4858 
4859   // Try to runtime-unroll loops with early-continues depending on loop-varying
4860   // loads; this helps with branch-prediction for the early-continues.
4861   auto *Term = dyn_cast<BranchInst>(Header->getTerminator());
4862   auto *Latch = L->getLoopLatch();
4863   SmallVector<BasicBlock *> Preds(predecessors(Latch));
4864   if (!Term || !Term->isConditional() || Preds.size() == 1 ||
4865       !llvm::is_contained(Preds, Header) ||
4866       none_of(Preds, [L](BasicBlock *Pred) { return L->contains(Pred); }))
4867     return;
4868 
4869   std::function<bool(Instruction *, unsigned)> DependsOnLoopLoad =
4870       [&](Instruction *I, unsigned Depth) -> bool {
4871     if (isa<PHINode>(I) || L->isLoopInvariant(I) || Depth > 8)
4872       return false;
4873 
4874     if (isa<LoadInst>(I))
4875       return true;
4876 
4877     return any_of(I->operands(), [&](Value *V) {
4878       auto *I = dyn_cast<Instruction>(V);
4879       return I && DependsOnLoopLoad(I, Depth + 1);
4880     });
4881   };
4882   CmpPredicate Pred;
4883   Instruction *I;
4884   if (match(Term, m_Br(m_ICmp(Pred, m_Instruction(I), m_Value()), m_Value(),
4885                        m_Value())) &&
4886       DependsOnLoopLoad(I, 0)) {
4887     UP.Runtime = true;
4888   }
4889 }
4890 
4891 void AArch64TTIImpl::getUnrollingPreferences(
4892     Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP,
4893     OptimizationRemarkEmitter *ORE) const {
4894   // Enable partial unrolling and runtime unrolling.
4895   BaseT::getUnrollingPreferences(L, SE, UP, ORE);
4896 
4897   UP.UpperBound = true;
4898 
4899   // An inner loop is more likely to be hot, and its runtime check can be
4900   // hoisted out by the LICM pass so the overhead is lower; try a larger
4901   // threshold to unroll more loops.
4902   if (L->getLoopDepth() > 1)
4903     UP.PartialThreshold *= 2;
4904 
4905   // Disable partial & runtime unrolling on -Os.
4906   UP.PartialOptSizeThreshold = 0;
4907 
4908   // No need to unroll auto-vectorized loops
4909   if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
4910     return;
4911 
4912   // Scan the loop: don't unroll loops with calls as this could prevent
4913   // inlining.
4914   for (auto *BB : L->getBlocks()) {
4915     for (auto &I : *BB) {
4916       if (isa<CallBase>(I)) {
4917         if (isa<CallInst>(I) || isa<InvokeInst>(I))
4918           if (const Function *F = cast<CallBase>(I).getCalledFunction())
4919             if (!isLoweredToCall(F))
4920               continue;
4921         return;
4922       }
4923     }
4924   }
4925 
4926   // Apply subtarget-specific unrolling preferences.
4927   switch (ST->getProcFamily()) {
4928   case AArch64Subtarget::AppleA14:
4929   case AArch64Subtarget::AppleA15:
4930   case AArch64Subtarget::AppleA16:
4931   case AArch64Subtarget::AppleM4:
4932     getAppleRuntimeUnrollPreferences(L, SE, UP, *this);
4933     break;
4934   case AArch64Subtarget::Falkor:
4935     if (EnableFalkorHWPFUnrollFix)
4936       getFalkorUnrollingPreferences(L, SE, UP);
4937     break;
4938   default:
4939     break;
4940   }
4941 
4942   // If this is a small, multi-exit loop similar to something like std::find,
4943   // then there is typically a performance improvement achieved by unrolling.
4944   if (!L->getExitBlock() && shouldUnrollMultiExitLoop(L, SE, *this)) {
4945     UP.RuntimeUnrollMultiExit = true;
4946     UP.Runtime = true;
4947     // Limit unroll count.
4948     UP.DefaultUnrollRuntimeCount = 4;
4949     // Allow slightly more costly trip-count expansion to catch search loops
4950     // with pointer inductions.
4951     UP.SCEVExpansionBudget = 5;
4952     return;
4953   }
4954 
4955   // Enable runtime unrolling for in-order models.
4956   // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Generic, so by
4957   // checking for that case, we can ensure that the default behaviour is
4958   // unchanged.
4959   if (ST->getProcFamily() != AArch64Subtarget::Generic &&
4960       !ST->getSchedModel().isOutOfOrder()) {
4961     UP.Runtime = true;
4962     UP.Partial = true;
4963     UP.UnrollRemainder = true;
4964     UP.DefaultUnrollRuntimeCount = 4;
4965 
4966     UP.UnrollAndJam = true;
4967     UP.UnrollAndJamInnerLoopThreshold = 60;
4968   }
4969 }
4970 
4971 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
4972                                            TTI::PeelingPreferences &PP) const {
4973   BaseT::getPeelingPreferences(L, SE, PP);
4974 }
4975 
4976 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
4977                                                          Type *ExpectedType,
4978                                                          bool CanCreate) const {
4979   switch (Inst->getIntrinsicID()) {
4980   default:
4981     return nullptr;
4982   case Intrinsic::aarch64_neon_st2:
4983   case Intrinsic::aarch64_neon_st3:
4984   case Intrinsic::aarch64_neon_st4: {
4985     // Create a struct type
4986     StructType *ST = dyn_cast<StructType>(ExpectedType);
4987     if (!CanCreate || !ST)
4988       return nullptr;
4989     unsigned NumElts = Inst->arg_size() - 1;
4990     if (ST->getNumElements() != NumElts)
4991       return nullptr;
4992     for (unsigned i = 0, e = NumElts; i != e; ++i) {
4993       if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
4994         return nullptr;
4995     }
4996     Value *Res = PoisonValue::get(ExpectedType);
4997     IRBuilder<> Builder(Inst);
4998     for (unsigned i = 0, e = NumElts; i != e; ++i) {
4999       Value *L = Inst->getArgOperand(i);
5000       Res = Builder.CreateInsertValue(Res, L, i);
5001     }
5002     return Res;
5003   }
5004   case Intrinsic::aarch64_neon_ld2:
5005   case Intrinsic::aarch64_neon_ld3:
5006   case Intrinsic::aarch64_neon_ld4:
5007     if (Inst->getType() == ExpectedType)
5008       return Inst;
5009     return nullptr;
5010   }
5011 }
5012 
5013 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
5014                                         MemIntrinsicInfo &Info) const {
5015   switch (Inst->getIntrinsicID()) {
5016   default:
5017     break;
5018   case Intrinsic::aarch64_neon_ld2:
5019   case Intrinsic::aarch64_neon_ld3:
5020   case Intrinsic::aarch64_neon_ld4:
5021     Info.ReadMem = true;
5022     Info.WriteMem = false;
5023     Info.PtrVal = Inst->getArgOperand(0);
5024     break;
5025   case Intrinsic::aarch64_neon_st2:
5026   case Intrinsic::aarch64_neon_st3:
5027   case Intrinsic::aarch64_neon_st4:
5028     Info.ReadMem = false;
5029     Info.WriteMem = true;
5030     Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
5031     break;
5032   }
5033 
5034   switch (Inst->getIntrinsicID()) {
5035   default:
5036     return false;
5037   case Intrinsic::aarch64_neon_ld2:
5038   case Intrinsic::aarch64_neon_st2:
5039     Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
5040     break;
5041   case Intrinsic::aarch64_neon_ld3:
5042   case Intrinsic::aarch64_neon_st3:
5043     Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
5044     break;
5045   case Intrinsic::aarch64_neon_ld4:
5046   case Intrinsic::aarch64_neon_st4:
5047     Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
5048     break;
5049   }
5050   return true;
5051 }
5052 
5053 /// See if \p I should be considered for address type promotion. We check if \p
5054 /// I is a sext with the right type and used in memory accesses. If it is used in a
5055 /// "complex" getelementptr, we allow it to be promoted without finding other
5056 /// sext instructions that sign extended the same initial value. A getelementptr
5057 /// is considered as "complex" if it has more than 2 operands.
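/// For example, 'getelementptr [10 x i32], ptr %base, i64 0, i64 %idx' has three
/// operands and is treated as complex, whereas 'getelementptr i32, ptr %base,
/// i64 %idx' has only two and is not.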
5058 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
5059     const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const {
5060   bool Considerable = false;
5061   AllowPromotionWithoutCommonHeader = false;
5062   if (!isa<SExtInst>(&I))
5063     return false;
5064   Type *ConsideredSExtType =
5065       Type::getInt64Ty(I.getParent()->getParent()->getContext());
5066   if (I.getType() != ConsideredSExtType)
5067     return false;
5068   // See if the sext is the one with the right type and used in at least one
5069   // GetElementPtrInst.
5070   for (const User *U : I.users()) {
5071     if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
5072       Considerable = true;
5073       // A getelementptr is considered as "complex" if it has more than 2
5074       // operands. We will promote a SExt used in such a complex GEP as we
5075       // expect some computation to be merged if it is done on 64 bits.
5076       if (GEPInst->getNumOperands() > 2) {
5077         AllowPromotionWithoutCommonHeader = true;
5078         break;
5079       }
5080     }
5081   }
5082   return Considerable;
5083 }
5084 
5085 bool AArch64TTIImpl::isLegalToVectorizeReduction(
5086     const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
5087   if (!VF.isScalable())
5088     return true;
5089 
5090   Type *Ty = RdxDesc.getRecurrenceType();
5091   if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
5092     return false;
5093 
5094   switch (RdxDesc.getRecurrenceKind()) {
5095   case RecurKind::Add:
5096   case RecurKind::FAdd:
5097   case RecurKind::And:
5098   case RecurKind::Or:
5099   case RecurKind::Xor:
5100   case RecurKind::SMin:
5101   case RecurKind::SMax:
5102   case RecurKind::UMin:
5103   case RecurKind::UMax:
5104   case RecurKind::FMin:
5105   case RecurKind::FMax:
5106   case RecurKind::FMulAdd:
5107   case RecurKind::AnyOf:
5108     return true;
5109   default:
5110     return false;
5111   }
5112 }
5113 
5114 InstructionCost
5115 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
5116                                        FastMathFlags FMF,
5117                                        TTI::TargetCostKind CostKind) const {
5118   // The code-generator is currently not able to handle scalable vectors
5119   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
5120   // it. This change will be removed when code-generation for these types is
5121   // sufficiently reliable.
5122   if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
5123     if (VTy->getElementCount() == ElementCount::getScalable(1))
5124       return InstructionCost::getInvalid();
5125 
5126   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
5127 
5128   if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
5129     return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);
5130 
5131   InstructionCost LegalizationCost = 0;
5132   if (LT.first > 1) {
5133     Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
5134     IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
5135     LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
5136   }
5137 
5138   return LegalizationCost + /*Cost of horizontal reduction*/ 2;
5139 }
5140 
5141 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
5142     unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) const {
5143   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
5144   InstructionCost LegalizationCost = 0;
5145   if (LT.first > 1) {
5146     Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
5147     LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
5148     LegalizationCost *= LT.first - 1;
5149   }
5150 
5151   int ISD = TLI->InstructionOpcodeToISD(Opcode);
5152   assert(ISD && "Invalid opcode");
5153   // Add the final reduction cost for the legal horizontal reduction
5154   switch (ISD) {
5155   case ISD::ADD:
5156   case ISD::AND:
5157   case ISD::OR:
5158   case ISD::XOR:
5159   case ISD::FADD:
5160     return LegalizationCost + 2;
5161   default:
5162     return InstructionCost::getInvalid();
5163   }
5164 }
5165 
5166 InstructionCost
5167 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
5168                                            std::optional<FastMathFlags> FMF,
5169                                            TTI::TargetCostKind CostKind) const {
5170   // The code-generator is currently not able to handle scalable vectors
5171   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
5172   // it. This change will be removed when code-generation for these types is
5173   // sufficiently reliable.
5174   if (auto *VTy = dyn_cast<ScalableVectorType>(ValTy))
5175     if (VTy->getElementCount() == ElementCount::getScalable(1))
5176       return InstructionCost::getInvalid();
5177 
5178   if (TTI::requiresOrderedReduction(FMF)) {
5179     if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
5180       InstructionCost BaseCost =
5181           BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
5182       // Add on extra cost to reflect the extra overhead on some CPUs. We still
5183       // end up vectorizing for more computationally intensive loops.
5184       return BaseCost + FixedVTy->getNumElements();
5185     }
5186 
5187     if (Opcode != Instruction::FAdd)
5188       return InstructionCost::getInvalid();
5189 
5190     auto *VTy = cast<ScalableVectorType>(ValTy);
5191     InstructionCost Cost =
5192         getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
5193     Cost *= getMaxNumElements(VTy->getElementCount());
5194     return Cost;
5195   }
5196 
5197   if (isa<ScalableVectorType>(ValTy))
5198     return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);
5199 
5200   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
5201   MVT MTy = LT.second;
5202   int ISD = TLI->InstructionOpcodeToISD(Opcode);
5203   assert(ISD && "Invalid opcode");
5204 
5205   // Horizontal adds can use the 'addv' instruction. We model the cost of these
5206   // instructions as twice a normal vector add, plus 1 for each legalization
5207   // step (LT.first). This is the only arithmetic vector reduction operation for
5208   // which we have an instruction.
5209   // OR, XOR and AND costs should match the codegen from:
5210   // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
5211   // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
5212   // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
5213   static const CostTblEntry CostTblNoPairwise[]{
5214       {ISD::ADD, MVT::v8i8,   2},
5215       {ISD::ADD, MVT::v16i8,  2},
5216       {ISD::ADD, MVT::v4i16,  2},
5217       {ISD::ADD, MVT::v8i16,  2},
5218       {ISD::ADD, MVT::v2i32,  2},
5219       {ISD::ADD, MVT::v4i32,  2},
5220       {ISD::ADD, MVT::v2i64,  2},
5221       {ISD::OR,  MVT::v8i8,  15},
5222       {ISD::OR,  MVT::v16i8, 17},
5223       {ISD::OR,  MVT::v4i16,  7},
5224       {ISD::OR,  MVT::v8i16,  9},
5225       {ISD::OR,  MVT::v2i32,  3},
5226       {ISD::OR,  MVT::v4i32,  5},
5227       {ISD::OR,  MVT::v2i64,  3},
5228       {ISD::XOR, MVT::v8i8,  15},
5229       {ISD::XOR, MVT::v16i8, 17},
5230       {ISD::XOR, MVT::v4i16,  7},
5231       {ISD::XOR, MVT::v8i16,  9},
5232       {ISD::XOR, MVT::v2i32,  3},
5233       {ISD::XOR, MVT::v4i32,  5},
5234       {ISD::XOR, MVT::v2i64,  3},
5235       {ISD::AND, MVT::v8i8,  15},
5236       {ISD::AND, MVT::v16i8, 17},
5237       {ISD::AND, MVT::v4i16,  7},
5238       {ISD::AND, MVT::v8i16,  9},
5239       {ISD::AND, MVT::v2i32,  3},
5240       {ISD::AND, MVT::v4i32,  5},
5241       {ISD::AND, MVT::v2i64,  3},
5242   };
5243   switch (ISD) {
5244   default:
5245     break;
5246   case ISD::FADD:
5247     if (Type *EltTy = ValTy->getScalarType();
5248         // FIXME: For half types without fullfp16 support, this could extend and
5249         // use a fp32 faddp reduction but current codegen unrolls.
5250         MTy.isVector() && (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
5251                            (EltTy->isHalfTy() && ST->hasFullFP16()))) {
5252       const unsigned NElts = MTy.getVectorNumElements();
5253       if (ValTy->getElementCount().getFixedValue() >= 2 && NElts >= 2 &&
5254           isPowerOf2_32(NElts))
5255         // Reduction corresponding to series of fadd instructions is lowered to
5256         // series of faddp instructions. faddp has latency/throughput that
5257         // matches fadd instruction and hence, every faddp instruction can be
5258         // considered to have a relative cost = 1 with
5259         // CostKind = TCK_RecipThroughput.
5260         // An faddp will pairwise add vector elements, so the size of input
5261         // vector reduces by half every time, requiring
5262         // #(faddp instructions) = log2_32(NElts).
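        // For example, a v4f32 fadd reduction is lowered to Log2_32(4) == 2
        // faddp instructions.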
5263         return (LT.first - 1) + /*No of faddp instructions*/ Log2_32(NElts);
5264     }
5265     break;
5266   case ISD::ADD:
5267     if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
5268       return (LT.first - 1) + Entry->Cost;
5269     break;
5270   case ISD::XOR:
5271   case ISD::AND:
5272   case ISD::OR:
5273     const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
5274     if (!Entry)
5275       break;
5276     auto *ValVTy = cast<FixedVectorType>(ValTy);
5277     if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
5278         isPowerOf2_32(ValVTy->getNumElements())) {
5279       InstructionCost ExtraCost = 0;
5280       if (LT.first != 1) {
5281         // Type needs to be split, so there is an extra cost of LT.first - 1
5282         // arithmetic ops.
5283         auto *Ty = FixedVectorType::get(ValTy->getElementType(),
5284                                         MTy.getVectorNumElements());
5285         ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
5286         ExtraCost *= LT.first - 1;
5287       }
5288       // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
5289       auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
5290       return Cost + ExtraCost;
5291     }
5292     break;
5293   }
5294   return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
5295 }
5296 
5297 InstructionCost AArch64TTIImpl::getExtendedReductionCost(
5298     unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
5299     std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const {
5300   EVT VecVT = TLI->getValueType(DL, VecTy);
5301   EVT ResVT = TLI->getValueType(DL, ResTy);
5302 
5303   if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
5304       VecVT.getSizeInBits() >= 64) {
5305     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
5306 
5307     // The legal cases are:
5308     //   UADDLV 8/16/32->32
5309     //   UADDLP 32->64
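    // For example, an add reduction of <16 x i8> zero-extended to i32 legalizes
    // to v16i8 with a 32-bit result, so it is costed as (LT.first - 1) * 2 + 2 == 2.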
5310     unsigned RevVTSize = ResVT.getSizeInBits();
5311     if (((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
5312          RevVTSize <= 32) ||
5313         ((LT.second == MVT::v4i16 || LT.second == MVT::v8i16) &&
5314          RevVTSize <= 32) ||
5315         ((LT.second == MVT::v2i32 || LT.second == MVT::v4i32) &&
5316          RevVTSize <= 64))
5317       return (LT.first - 1) * 2 + 2;
5318   }
5319 
5320   return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, VecTy, FMF,
5321                                          CostKind);
5322 }
5323 
5324 InstructionCost
5325 AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
5326                                        VectorType *VecTy,
5327                                        TTI::TargetCostKind CostKind) const {
5328   EVT VecVT = TLI->getValueType(DL, VecTy);
5329   EVT ResVT = TLI->getValueType(DL, ResTy);
5330 
5331   if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
5332     std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
5333 
5334     // The legal cases with dotprod are
5335     //   UDOT 8->32
5336     // Which requires an additional uaddv to sum the i32 values.
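    // For example, a mul-acc reduction from <16 x i8> to i32 legalizes to v16i8
    // and is costed as LT.first + 2 == 3 (the dot product plus the final uaddv).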
5337     if ((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
5338          ResVT == MVT::i32)
5339       return LT.first + 2;
5340   }
5341 
5342   return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
5343 }
5344 
5345 InstructionCost
5346 AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
5347                               TTI::TargetCostKind CostKind) const {
5348   static const CostTblEntry ShuffleTbl[] = {
5349       { TTI::SK_Splice, MVT::nxv16i8,  1 },
5350       { TTI::SK_Splice, MVT::nxv8i16,  1 },
5351       { TTI::SK_Splice, MVT::nxv4i32,  1 },
5352       { TTI::SK_Splice, MVT::nxv2i64,  1 },
5353       { TTI::SK_Splice, MVT::nxv2f16,  1 },
5354       { TTI::SK_Splice, MVT::nxv4f16,  1 },
5355       { TTI::SK_Splice, MVT::nxv8f16,  1 },
5356       { TTI::SK_Splice, MVT::nxv2bf16, 1 },
5357       { TTI::SK_Splice, MVT::nxv4bf16, 1 },
5358       { TTI::SK_Splice, MVT::nxv8bf16, 1 },
5359       { TTI::SK_Splice, MVT::nxv2f32,  1 },
5360       { TTI::SK_Splice, MVT::nxv4f32,  1 },
5361       { TTI::SK_Splice, MVT::nxv2f64,  1 },
5362   };
5363 
5364   // The code-generator is currently not able to handle scalable vectors
5365   // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
5366   // it. This change will be removed when code-generation for these types is
5367   // sufficiently reliable.
5368   if (Tp->getElementCount() == ElementCount::getScalable(1))
5369     return InstructionCost::getInvalid();
5370 
5371   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
5372   Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
5373   EVT PromotedVT = LT.second.getScalarType() == MVT::i1
5374                        ? TLI->getPromotedVTForPredicate(EVT(LT.second))
5375                        : LT.second;
5376   Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
5377   InstructionCost LegalizationCost = 0;
5378   if (Index < 0) {
5379     LegalizationCost =
5380         getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
5381                            CmpInst::BAD_ICMP_PREDICATE, CostKind) +
5382         getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
5383                            CmpInst::BAD_ICMP_PREDICATE, CostKind);
5384   }
5385 
5386   // Predicated splices are promoted when lowering. See AArch64ISelLowering.cpp.
5387   // The cost is computed on the promoted type.
5388   if (LT.second.getScalarType() == MVT::i1) {
5389     LegalizationCost +=
5390         getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
5391                          TTI::CastContextHint::None, CostKind) +
5392         getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
5393                          TTI::CastContextHint::None, CostKind);
5394   }
5395   const auto *Entry =
5396       CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
5397   assert(Entry && "Illegal Type for Splice");
5398   LegalizationCost += Entry->Cost;
5399   return LegalizationCost * LT.first;
5400 }
5401 
5402 InstructionCost AArch64TTIImpl::getPartialReductionCost(
5403     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
5404     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
5405     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
5406     TTI::TargetCostKind CostKind) const {
5407   InstructionCost Invalid = InstructionCost::getInvalid();
5408   InstructionCost Cost(TTI::TCC_Basic);
5409 
5410   if (CostKind != TTI::TCK_RecipThroughput)
5411     return Invalid;
5412 
5413   // Sub opcodes currently only occur in chained cases.
5414   // Independent partial reduction subtractions are still costed as an add
5415   if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
5416       OpAExtend == TTI::PR_None)
5417     return Invalid;
5418 
5419   // We only support multiply binary operations for now, and for muls we
5420   // require the types being extended to be the same.
5421   // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5422   // only if the i8mm or sve/streaming features are available.
5423   if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5424                 OpBExtend == TTI::PR_None ||
5425                 (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5426                  !ST->isSVEorStreamingSVEAvailable())))
5427     return Invalid;
5428   assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5429          "Unexpected values for OpBExtend or InputTypeB");
5430 
5431   EVT InputEVT = EVT::getEVT(InputTypeA);
5432   EVT AccumEVT = EVT::getEVT(AccumType);
5433 
5434   unsigned VFMinValue = VF.getKnownMinValue();
5435 
5436   if (VF.isScalable()) {
5437     if (!ST->isSVEorStreamingSVEAvailable())
5438       return Invalid;
5439 
5440     // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5441     // since we can't lower that type.
5442     unsigned Scale =
5443         AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
5444     if (VFMinValue == Scale)
5445       return Invalid;
5446   }
5447   if (VF.isFixed() &&
5448       (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
5449     return Invalid;
5450 
5451   if (InputEVT == MVT::i8) {
5452     switch (VFMinValue) {
5453     default:
5454       return Invalid;
5455     case 8:
5456       if (AccumEVT == MVT::i32)
5457         Cost *= 2;
5458       else if (AccumEVT != MVT::i64)
5459         return Invalid;
5460       break;
5461     case 16:
5462       if (AccumEVT == MVT::i64)
5463         Cost *= 2;
5464       else if (AccumEVT != MVT::i32)
5465         return Invalid;
5466       break;
5467     }
5468   } else if (InputEVT == MVT::i16) {
5469     // FIXME: Allow i32 accumulator but increase cost, as we would extend
5470     //        it to i64.
5471     if (VFMinValue != 8 || AccumEVT != MVT::i64)
5472       return Invalid;
5473   } else
5474     return Invalid;
5475 
5476   return Cost;
5477 }
5478 
5479 InstructionCost
5480 AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
5481                                VectorType *SrcTy, ArrayRef<int> Mask,
5482                                TTI::TargetCostKind CostKind, int Index,
5483                                VectorType *SubTp, ArrayRef<const Value *> Args,
5484                                const Instruction *CxtI) const {
5485   assert((Mask.empty() || DstTy->isScalableTy() ||
5486           Mask.size() == DstTy->getElementCount().getKnownMinValue()) &&
5487          "Expected the Mask to match the return size if given");
5488   assert(SrcTy->getScalarType() == DstTy->getScalarType() &&
5489          "Expected the same scalar types");
5490   std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcTy);
5491 
5492   // If we have a Mask, and the LT is being legalized somehow, split the Mask
5493   // into smaller vectors and sum the cost of each shuffle.
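  // For example, an 8-element mask over <8 x i64> operands legalizes to v2i64,
  // so the mask is split into four 2-element sub-masks and each distinct
  // sub-shuffle is costed (or reused) individually.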
5494   if (!Mask.empty() && isa<FixedVectorType>(SrcTy) && LT.second.isVector() &&
5495       LT.second.getScalarSizeInBits() * Mask.size() > 128 &&
5496       SrcTy->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
5497       Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
5498     // Check for LD3/LD4 instructions, which are represented in llvm IR as
5499     // deinterleaving-shuffle(load). The shuffle cost could potentially be free,
5500     // but we model it with a cost of LT.first so that LD3/LD4 have a higher
5501     // cost than just the load.
5502     if (Args.size() >= 1 && isa<LoadInst>(Args[0]) &&
5503         (ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 3) ||
5504          ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 4)))
5505       return std::max<InstructionCost>(1, LT.first / 4);
5506 
5507     // Check for ST3/ST4 instructions, which are represented in llvm IR as
5508     // store(interleaving-shuffle). The shuffle cost could potentially be free,
5509     // but we model it with a cost of LT.first so that ST3/ST4 have a higher
5510     // cost than just the store.
5511     if (CxtI && CxtI->hasOneUse() && isa<StoreInst>(*CxtI->user_begin()) &&
5512         (ShuffleVectorInst::isInterleaveMask(
5513              Mask, 4, SrcTy->getElementCount().getKnownMinValue() * 2) ||
5514          ShuffleVectorInst::isInterleaveMask(
5515              Mask, 3, SrcTy->getElementCount().getKnownMinValue() * 2)))
5516       return LT.first;
5517 
5518     unsigned TpNumElts = Mask.size();
5519     unsigned LTNumElts = LT.second.getVectorNumElements();
5520     unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
5521     VectorType *NTp = VectorType::get(SrcTy->getScalarType(),
5522                                       LT.second.getVectorElementCount());
5523     InstructionCost Cost;
5524     std::map<std::tuple<unsigned, unsigned, SmallVector<int>>, InstructionCost>
5525         PreviousCosts;
5526     for (unsigned N = 0; N < NumVecs; N++) {
5527       SmallVector<int> NMask;
5528       // Split the existing mask into chunks of size LTNumElts. Track the source
5529       // sub-vectors to ensure the result has at most 2 inputs.
5530       unsigned Source1 = -1U, Source2 = -1U;
5531       unsigned NumSources = 0;
5532       for (unsigned E = 0; E < LTNumElts; E++) {
5533         int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
5534                                                       : PoisonMaskElem;
5535         if (MaskElt < 0) {
5536           NMask.push_back(PoisonMaskElem);
5537           continue;
5538         }
5539 
5540         // Calculate which source from the input this comes from and whether it
5541         // is new to us.
5542         unsigned Source = MaskElt / LTNumElts;
5543         if (NumSources == 0) {
5544           Source1 = Source;
5545           NumSources = 1;
5546         } else if (NumSources == 1 && Source != Source1) {
5547           Source2 = Source;
5548           NumSources = 2;
5549         } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
5550           NumSources++;
5551         }
5552 
5553         // Add to the new mask. For the NumSources>2 case these are not correct,
5554         // but are only used for the modular lane number.
5555         if (Source == Source1)
5556           NMask.push_back(MaskElt % LTNumElts);
5557         else if (Source == Source2)
5558           NMask.push_back(MaskElt % LTNumElts + LTNumElts);
5559         else
5560           NMask.push_back(MaskElt % LTNumElts);
5561       }
5562       // Check if we have already generated this sub-shuffle, which means we
5563       // will have already generated the output. For example a <16 x i32> splat
5564       // will be the same sub-splat 4 times, which only needs to be generated
5565       // once and reused.
5566       auto Result =
5567           PreviousCosts.insert({std::make_tuple(Source1, Source2, NMask), 0});
5568       // Check if it was already in the map (already costed).
5569       if (!Result.second)
5570         continue;
5571       // If the sub-mask has at most 2 input sub-vectors then re-cost it using
5572       // getShuffleCost. If not then cost it using the worst case as the number
5573       // of element moves into a new vector.
5574       InstructionCost NCost =
5575           NumSources <= 2
5576               ? getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
5577                                                : TTI::SK_PermuteTwoSrc,
5578                                NTp, NTp, NMask, CostKind, 0, nullptr, Args,
5579                                CxtI)
5580               : LTNumElts;
5581       Result.first->second = NCost;
5582       Cost += NCost;
5583     }
5584     return Cost;
5585   }
5586 
5587   Kind = improveShuffleKindFromMask(Kind, Mask, SrcTy, Index, SubTp);
5588   bool IsExtractSubvector = Kind == TTI::SK_ExtractSubvector;
5589   // A subvector extract can be implemented with an ext (or trivial extract, if
5590   // from lane 0). This currently only handles low or high extracts to prevent
5591   // SLP vectorizer regressions.
5592   if (IsExtractSubvector && LT.second.isFixedLengthVector()) {
5593     if (LT.second.is128BitVector() &&
5594         cast<FixedVectorType>(SubTp)->getNumElements() ==
5595             LT.second.getVectorNumElements() / 2) {
5596       if (Index == 0)
5597         return 0;
5598       if (Index == (int)LT.second.getVectorNumElements() / 2)
5599         return 1;
5600     }
5601     Kind = TTI::SK_PermuteSingleSrc;
5602   }
5603   // FIXME: This was added to keep the costs equal when adding DstTys. Update
5604   // the code to handle length-changing shuffles.
5605   if (Kind == TTI::SK_InsertSubvector) {
5606     LT = getTypeLegalizationCost(DstTy);
5607     SrcTy = DstTy;
5608   }
5609 
5610   // Segmented shuffle matching.
5611   if (Kind == TTI::SK_PermuteSingleSrc && isa<FixedVectorType>(SrcTy) &&
5612       !Mask.empty() && SrcTy->getPrimitiveSizeInBits().isNonZero() &&
5613       SrcTy->getPrimitiveSizeInBits().isKnownMultipleOf(
5614           AArch64::SVEBitsPerBlock)) {
5615 
5616     FixedVectorType *VTy = cast<FixedVectorType>(SrcTy);
5617     unsigned Segments =
5618         VTy->getPrimitiveSizeInBits() / AArch64::SVEBitsPerBlock;
5619     unsigned SegmentElts = VTy->getNumElements() / Segments;
5620 
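    // For example (illustrative), with a <8 x i32> source split into two 128-bit
    // segments, a mask like <1,1,1,1,5,5,5,5> broadcasts lane 1 within each
    // segment (dupq), while <0,1,2,3,0,1,2,3> repeats the first segment (mov).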
5621     // dupq zd.t, zn.t[idx]
5622     if ((ST->hasSVE2p1() || ST->hasSME2p1()) &&
5623         ST->isSVEorStreamingSVEAvailable() &&
5624         isDUPQMask(Mask, Segments, SegmentElts))
5625       return LT.first;
5626 
5627     // mov zd.q, vn
5628     if (ST->isSVEorStreamingSVEAvailable() &&
5629         isDUPFirstSegmentMask(Mask, Segments, SegmentElts))
5630       return LT.first;
5631   }
5632 
5633   // Check for broadcast loads, which are supported by the LD1R instruction.
5634   // In terms of code-size, the shuffle vector is free when a load + dup get
5635   // folded into a LD1R. That's what we check and return here. For performance
5636   // and reciprocal throughput, a LD1R is not completely free. In this case, we
5637   // return the cost for the broadcast below (i.e. 1 for most/all types), so
5638   // that we model the load + dup sequence slightly higher because LD1R is a
5639   // high latency instruction.
5640   if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) {
5641     bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
5642     if (IsLoad && LT.second.isVector() &&
5643         isLegalBroadcastLoad(SrcTy->getElementType(),
5644                              LT.second.getVectorElementCount()))
5645       return 0;
5646   }
5647 
5648   // If we have 4 elements for the shuffle and a Mask, get the cost straight
5649   // from the perfect shuffle tables.
5650   if (Mask.size() == 4 &&
5651       SrcTy->getElementCount() == ElementCount::getFixed(4) &&
5652       (SrcTy->getScalarSizeInBits() == 16 ||
5653        SrcTy->getScalarSizeInBits() == 32) &&
5654       all_of(Mask, [](int E) { return E < 8; }))
5655     return getPerfectShuffleCost(Mask);
5656 
5657   // Check for identity masks, which we can treat as free.
5658   if (!Mask.empty() && LT.second.isFixedLengthVector() &&
5659       (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
5660       all_of(enumerate(Mask), [](const auto &M) {
5661         return M.value() < 0 || M.value() == (int)M.index();
5662       }))
5663     return 0;
5664 
5665   // Check for other shuffles that are not SK_ kinds but we have native
5666   // instructions for, for example ZIP and UZP.
5667   unsigned Unused;
5668   if (LT.second.isFixedLengthVector() &&
5669       LT.second.getVectorNumElements() == Mask.size() &&
5670       (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
5671       (isZIPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
5672        isUZPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
5673        isREVMask(Mask, LT.second.getScalarSizeInBits(),
5674                  LT.second.getVectorNumElements(), 16) ||
5675        isREVMask(Mask, LT.second.getScalarSizeInBits(),
5676                  LT.second.getVectorNumElements(), 32) ||
5677        isREVMask(Mask, LT.second.getScalarSizeInBits(),
5678                  LT.second.getVectorNumElements(), 64) ||
5679        // Check for non-zero lane splats
5680        all_of(drop_begin(Mask),
5681               [&Mask](int M) { return M < 0 || M == Mask[0]; })))
5682     return 1;
5683 
5684   if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
5685       Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
5686       Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
5687     static const CostTblEntry ShuffleTbl[] = {
5688         // Broadcast shuffle kinds can be performed with 'dup'.
5689         {TTI::SK_Broadcast, MVT::v8i8, 1},
5690         {TTI::SK_Broadcast, MVT::v16i8, 1},
5691         {TTI::SK_Broadcast, MVT::v4i16, 1},
5692         {TTI::SK_Broadcast, MVT::v8i16, 1},
5693         {TTI::SK_Broadcast, MVT::v2i32, 1},
5694         {TTI::SK_Broadcast, MVT::v4i32, 1},
5695         {TTI::SK_Broadcast, MVT::v2i64, 1},
5696         {TTI::SK_Broadcast, MVT::v4f16, 1},
5697         {TTI::SK_Broadcast, MVT::v8f16, 1},
5698         {TTI::SK_Broadcast, MVT::v4bf16, 1},
5699         {TTI::SK_Broadcast, MVT::v8bf16, 1},
5700         {TTI::SK_Broadcast, MVT::v2f32, 1},
5701         {TTI::SK_Broadcast, MVT::v4f32, 1},
5702         {TTI::SK_Broadcast, MVT::v2f64, 1},
5703         // Transpose shuffle kinds can be performed with 'trn1/trn2' and
5704         // 'zip1/zip2' instructions.
5705         {TTI::SK_Transpose, MVT::v8i8, 1},
5706         {TTI::SK_Transpose, MVT::v16i8, 1},
5707         {TTI::SK_Transpose, MVT::v4i16, 1},
5708         {TTI::SK_Transpose, MVT::v8i16, 1},
5709         {TTI::SK_Transpose, MVT::v2i32, 1},
5710         {TTI::SK_Transpose, MVT::v4i32, 1},
5711         {TTI::SK_Transpose, MVT::v2i64, 1},
5712         {TTI::SK_Transpose, MVT::v4f16, 1},
5713         {TTI::SK_Transpose, MVT::v8f16, 1},
5714         {TTI::SK_Transpose, MVT::v4bf16, 1},
5715         {TTI::SK_Transpose, MVT::v8bf16, 1},
5716         {TTI::SK_Transpose, MVT::v2f32, 1},
5717         {TTI::SK_Transpose, MVT::v4f32, 1},
5718         {TTI::SK_Transpose, MVT::v2f64, 1},
5719         // Select shuffle kinds.
5720         // TODO: handle vXi8/vXi16.
5721         {TTI::SK_Select, MVT::v2i32, 1}, // mov.
5722         {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
5723         {TTI::SK_Select, MVT::v2i64, 1}, // mov.
5724         {TTI::SK_Select, MVT::v2f32, 1}, // mov.
5725         {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
5726         {TTI::SK_Select, MVT::v2f64, 1}, // mov.
5727         // PermuteSingleSrc shuffle kinds.
5728         {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
5729         {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
5730         {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
5731         {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
5732         {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
5733         {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
5734         {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
5735         {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
5736         {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same
5737         {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8},  // constpool + load + tbl
5738         {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8},  // constpool + load + tbl
5739         {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
5740         {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8},   // constpool + load + tbl
5741         {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8},  // constpool + load + tbl
5742         // Reverse can be lowered with `rev`.
5743         {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
5744         {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
5745         {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
5746         {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
5747         {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
5748         {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
5749         {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
5750         {TTI::SK_Reverse, MVT::v8bf16, 2}, // REV64; EXT
5751         {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
5752         {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
5753         {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
5754         {TTI::SK_Reverse, MVT::v4bf16, 1}, // REV64
5755         {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
5756         {TTI::SK_Reverse, MVT::v8i8, 1},  // REV64
5757         // Splice can all be lowered as `ext`.
5758         {TTI::SK_Splice, MVT::v2i32, 1},
5759         {TTI::SK_Splice, MVT::v4i32, 1},
5760         {TTI::SK_Splice, MVT::v2i64, 1},
5761         {TTI::SK_Splice, MVT::v2f32, 1},
5762         {TTI::SK_Splice, MVT::v4f32, 1},
5763         {TTI::SK_Splice, MVT::v2f64, 1},
5764         {TTI::SK_Splice, MVT::v8f16, 1},
5765         {TTI::SK_Splice, MVT::v8bf16, 1},
5766         {TTI::SK_Splice, MVT::v8i16, 1},
5767         {TTI::SK_Splice, MVT::v16i8, 1},
5768         {TTI::SK_Splice, MVT::v4f16, 1},
5769         {TTI::SK_Splice, MVT::v4bf16, 1},
5770         {TTI::SK_Splice, MVT::v4i16, 1},
5771         {TTI::SK_Splice, MVT::v8i8, 1},
5772         // Broadcast shuffle kinds for scalable vectors
5773         {TTI::SK_Broadcast, MVT::nxv16i8, 1},
5774         {TTI::SK_Broadcast, MVT::nxv8i16, 1},
5775         {TTI::SK_Broadcast, MVT::nxv4i32, 1},
5776         {TTI::SK_Broadcast, MVT::nxv2i64, 1},
5777         {TTI::SK_Broadcast, MVT::nxv2f16, 1},
5778         {TTI::SK_Broadcast, MVT::nxv4f16, 1},
5779         {TTI::SK_Broadcast, MVT::nxv8f16, 1},
5780         {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
5781         {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
5782         {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
5783         {TTI::SK_Broadcast, MVT::nxv2f32, 1},
5784         {TTI::SK_Broadcast, MVT::nxv4f32, 1},
5785         {TTI::SK_Broadcast, MVT::nxv2f64, 1},
5786         {TTI::SK_Broadcast, MVT::nxv16i1, 1},
5787         {TTI::SK_Broadcast, MVT::nxv8i1, 1},
5788         {TTI::SK_Broadcast, MVT::nxv4i1, 1},
5789         {TTI::SK_Broadcast, MVT::nxv2i1, 1},
5790         // Handle the cases for vector.reverse with scalable vectors
5791         {TTI::SK_Reverse, MVT::nxv16i8, 1},
5792         {TTI::SK_Reverse, MVT::nxv8i16, 1},
5793         {TTI::SK_Reverse, MVT::nxv4i32, 1},
5794         {TTI::SK_Reverse, MVT::nxv2i64, 1},
5795         {TTI::SK_Reverse, MVT::nxv2f16, 1},
5796         {TTI::SK_Reverse, MVT::nxv4f16, 1},
5797         {TTI::SK_Reverse, MVT::nxv8f16, 1},
5798         {TTI::SK_Reverse, MVT::nxv2bf16, 1},
5799         {TTI::SK_Reverse, MVT::nxv4bf16, 1},
5800         {TTI::SK_Reverse, MVT::nxv8bf16, 1},
5801         {TTI::SK_Reverse, MVT::nxv2f32, 1},
5802         {TTI::SK_Reverse, MVT::nxv4f32, 1},
5803         {TTI::SK_Reverse, MVT::nxv2f64, 1},
5804         {TTI::SK_Reverse, MVT::nxv16i1, 1},
5805         {TTI::SK_Reverse, MVT::nxv8i1, 1},
5806         {TTI::SK_Reverse, MVT::nxv4i1, 1},
5807         {TTI::SK_Reverse, MVT::nxv2i1, 1},
5808     };
5809     if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
5810       return LT.first * Entry->Cost;
5811   }
5812 
5813   if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(SrcTy))
5814     return getSpliceCost(SrcTy, Index, CostKind);
5815 
5816   // Inserting a subvector can often be done with either a D, S or H register
5817   // move, so long as the inserted vector is "aligned".
5818   if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
5819       LT.second.getSizeInBits() <= 128 && SubTp) {
5820     std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
5821     if (SubLT.second.isVector()) {
5822       int NumElts = LT.second.getVectorNumElements();
5823       int NumSubElts = SubLT.second.getVectorNumElements();
5824       if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
5825         return SubLT.first;
5826     }
5827   }
5828 
5829   // Restore optimal kind.
5830   if (IsExtractSubvector)
5831     Kind = TTI::SK_ExtractSubvector;
5832   return BaseT::getShuffleCost(Kind, DstTy, SrcTy, Mask, CostKind, Index, SubTp,
5833                                Args, CxtI);
5834 }
5835 
5836 static bool containsDecreasingPointers(Loop *TheLoop,
5837                                        PredicatedScalarEvolution *PSE) {
5838   const auto &Strides = DenseMap<Value *, const SCEV *>();
5839   for (BasicBlock *BB : TheLoop->blocks()) {
5840     // Scan the instructions in the block and look for addresses that are
5841     // consecutive and decreasing.
5842     for (Instruction &I : *BB) {
5843       if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
5844         Value *Ptr = getLoadStorePointerOperand(&I);
5845         Type *AccessTy = getLoadStoreType(&I);
5846         if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
5847                          /*ShouldCheckWrap=*/false)
5848                 .value_or(0) < 0)
5849           return true;
5850       }
5851     }
5852   }
5853   return false;
5854 }
5855 
5856 bool AArch64TTIImpl::preferFixedOverScalableIfEqualCost() const {
5857   if (SVEPreferFixedOverScalableIfEqualCost.getNumOccurrences())
5858     return SVEPreferFixedOverScalableIfEqualCost;
5859   return ST->useFixedOverScalableIfEqualCost();
5860 }
5861 
5862 unsigned AArch64TTIImpl::getEpilogueVectorizationMinVF() const {
5863   return ST->getEpilogueVectorizationMinVF();
5864 }
5865 
5866 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const {
5867   if (!ST->hasSVE())
5868     return false;
5869 
5870   // We don't currently support vectorisation with interleaving for SVE - with
5871   // such loops we're better off not using tail-folding. This gives us a chance
5872   // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
5873   if (TFI->IAI->hasGroups())
5874     return false;
5875 
5876   TailFoldingOpts Required = TailFoldingOpts::Disabled;
5877   if (TFI->LVL->getReductionVars().size())
5878     Required |= TailFoldingOpts::Reductions;
5879   if (TFI->LVL->getFixedOrderRecurrences().size())
5880     Required |= TailFoldingOpts::Recurrences;
5881 
5882   // We call this to discover whether any load/store pointers in the loop have
5883   // negative strides. This will require extra work to reverse the loop
5884   // predicate, which may be expensive.
5885   if (containsDecreasingPointers(TFI->LVL->getLoop(),
5886                                  TFI->LVL->getPredicatedScalarEvolution()))
5887     Required |= TailFoldingOpts::Reverse;
5888   if (Required == TailFoldingOpts::Disabled)
5889     Required |= TailFoldingOpts::Simple;
5890 
5891   if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(),
5892                                       Required))
5893     return false;
5894 
5895   // Don't tail-fold for tight loops where we would be better off interleaving
5896   // with an unpredicated loop.
5897   unsigned NumInsns = 0;
5898   for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) {
5899     NumInsns += BB->sizeWithoutDebug();
5900   }
5901 
5902   // We expect 4 of these to be an IV PHI, IV add, IV compare and branch.
5903   return NumInsns >= SVETailFoldInsnThreshold;
5904 }
5905 
5906 InstructionCost
5907 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
5908                                      StackOffset BaseOffset, bool HasBaseReg,
5909                                      int64_t Scale, unsigned AddrSpace) const {
5910   // Scaling factors are not free at all.
5911   // Operands                     | Rt Latency
5912   // -------------------------------------------
5913   // Rt, [Xn, Xm]                 | 4
5914   // -------------------------------------------
5915   // Rt, [Xn, Xm, lsl #imm]       | Rn: 4 Rm: 5
5916   // Rt, [Xn, Wm, <extend> #imm]  |
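  // (Illustrative, assuming a legal addressing mode; register names are
  // placeholders matching the table above.) An LSR formula of the form
  // BaseReg + 8 * IndexReg for an i64 access corresponds to
  // "ldr Xt, [Xn, Xm, lsl #3]" and is charged a cost of 1 below, whereas
  // Scale values of 0 or 1 are treated as free.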
5917   TargetLoweringBase::AddrMode AM;
5918   AM.BaseGV = BaseGV;
5919   AM.BaseOffs = BaseOffset.getFixed();
5920   AM.HasBaseReg = HasBaseReg;
5921   AM.Scale = Scale;
5922   AM.ScalableOffset = BaseOffset.getScalable();
5923   if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
5924     // AM.Scale represents reg2 * scale, so charge a cost of 1 if
5925     // it is not equal to 0 or 1.
5926     return AM.Scale != 0 && AM.Scale != 1;
5927   return InstructionCost::getInvalid();
5928 }
5929 
5930 bool AArch64TTIImpl::shouldTreatInstructionLikeSelect(
5931     const Instruction *I) const {
5932   if (EnableOrLikeSelectOpt) {
5933     // For the binary operators (e.g. or) we need to be more careful than
5934     // with selects; we only transform them if they are already at a natural
5935     // break point in the code - the end of a block with an unconditional
5936     // terminator.
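    // (Illustrative IR, not from the original source; values are
    // placeholders.) e.g.
    //   %c = or i1 %a, %b
    //   br label %exit
    // where the 'or' immediately precedes an unconditional terminator.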
5937     if (I->getOpcode() == Instruction::Or &&
5938         isa<BranchInst>(I->getNextNode()) &&
5939         cast<BranchInst>(I->getNextNode())->isUnconditional())
5940       return true;
5941 
5942     if (I->getOpcode() == Instruction::Add ||
5943         I->getOpcode() == Instruction::Sub)
5944       return true;
5945   }
5946   return BaseT::shouldTreatInstructionLikeSelect(I);
5947 }
5948 
5949 bool AArch64TTIImpl::isLSRCostLess(
5950     const TargetTransformInfo::LSRCost &C1,
5951     const TargetTransformInfo::LSRCost &C2) const {
5952   // The AArch64-specific part here is adding the number of instructions to
5953   // the comparison (though not as the first consideration, as some targets
5954   // do), along with changing the priority of the base additions.
5955   // TODO: Maybe a more nuanced tradeoff between instruction count
5956   // and number of registers? To be investigated at a later date.
5957   if (EnableLSRCostOpt)
5958     return std::tie(C1.NumRegs, C1.Insns, C1.NumBaseAdds, C1.AddRecCost,
5959                     C1.NumIVMuls, C1.ScaleCost, C1.ImmCost, C1.SetupCost) <
5960            std::tie(C2.NumRegs, C2.Insns, C2.NumBaseAdds, C2.AddRecCost,
5961                     C2.NumIVMuls, C2.ScaleCost, C2.ImmCost, C2.SetupCost);
5962 
5963   return TargetTransformInfoImplBase::isLSRCostLess(C1, C2);
5964 }
5965 
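/// (Illustrative example, not part of the original code.) A splat shuffle as
/// recognised below has every mask element referring to the same lane, e.g.
///   %s = shufflevector <4 x i32> %v, <4 x i32> poison, <4 x i32> zeroinitializer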
5966 static bool isSplatShuffle(Value *V) {
5967   if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
5968     return all_equal(Shuf->getShuffleMask());
5969   return false;
5970 }
5971 
5972 /// Check if both Op1 and Op2 are shufflevector extracts of either the lower
5973 /// or upper half of the vector elements.
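/// (Illustrative, not from the original source; both shuffles must extract
/// the same half unless splats are allowed.) e.g.
///   %x = shufflevector <8 x i16> %a, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
///   %y = shufflevector <8 x i16> %b, <8 x i16> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>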
5974 static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
5975                                      bool AllowSplat = false) {
5976   // Scalable types can't be extract shuffle vectors.
5977   if (Op1->getType()->isScalableTy() || Op2->getType()->isScalableTy())
5978     return false;
5979 
5980   auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
5981     auto *FullTy = FullV->getType();
5982     auto *HalfTy = HalfV->getType();
5983     return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
5984            2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
5985   };
5986 
5987   auto extractHalf = [](Value *FullV, Value *HalfV) {
5988     auto *FullVT = cast<FixedVectorType>(FullV->getType());
5989     auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
5990     return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
5991   };
5992 
5993   ArrayRef<int> M1, M2;
5994   Value *S1Op1 = nullptr, *S2Op1 = nullptr;
5995   if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
5996       !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
5997     return false;
5998 
5999   // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that
6000   // it is not checked as an extract below.
6001   if (AllowSplat && isSplatShuffle(Op1))
6002     S1Op1 = nullptr;
6003   if (AllowSplat && isSplatShuffle(Op2))
6004     S2Op1 = nullptr;
6005 
6006   // Check that the operands are half as wide as the result and we extract
6007   // half of the elements of the input vectors.
6008   if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
6009       (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
6010     return false;
6011 
6012   // Check the mask extracts either the lower or upper half of vector
6013   // elements.
6014   int M1Start = 0;
6015   int M2Start = 0;
6016   int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
6017   if ((S1Op1 &&
6018        !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
6019       (S2Op1 &&
6020        !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
6021     return false;
6022 
6023   if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
6024       (M2Start != 0 && M2Start != (NumElements / 2)))
6025     return false;
6026   if (S1Op1 && S2Op1 && M1Start != M2Start)
6027     return false;
6028 
6029   return true;
6030 }
6031 
6032 /// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
6033 /// of the vector elements.
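/// (Illustrative, not from the original source.) e.g.
///   %e1 = sext <4 x i16> %a to <4 x i32>
///   %e2 = sext <4 x i16> %b to <4 x i32>
/// where each extend doubles the element width.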
6034 static bool areExtractExts(Value *Ext1, Value *Ext2) {
6035   auto areExtDoubled = [](Instruction *Ext) {
6036     return Ext->getType()->getScalarSizeInBits() ==
6037            2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
6038   };
6039 
6040   if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
6041       !match(Ext2, m_ZExtOrSExt(m_Value())) ||
6042       !areExtDoubled(cast<Instruction>(Ext1)) ||
6043       !areExtDoubled(cast<Instruction>(Ext2)))
6044     return false;
6045 
6046   return true;
6047 }
6048 
6049 /// Check if Op could be used with vmull_high_p64 intrinsic.
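/// (Illustrative, not from the original source.) i.e. an extract of lane 1 of
/// a two-element vector, e.g.
///   %op = extractelement <2 x i64> %v, i64 1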
6050 static bool isOperandOfVmullHighP64(Value *Op) {
6051   Value *VectorOperand = nullptr;
6052   ConstantInt *ElementIndex = nullptr;
6053   return match(Op, m_ExtractElt(m_Value(VectorOperand),
6054                                 m_ConstantInt(ElementIndex))) &&
6055          ElementIndex->getValue() == 1 &&
6056          isa<FixedVectorType>(VectorOperand->getType()) &&
6057          cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
6058 }
6059 
6060 /// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
6061 static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
6062   return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
6063 }
6064 
6065 static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
6066   // Restrict ourselves to the form CodeGenPrepare typically constructs.
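  // (Illustrative, not from the original source; names are placeholders.)
  // The expected shape is a scalar base plus a vector of offsets, e.g.
  //   %ptrs = getelementptr i32, ptr %base, <vscale x 4 x i64> %offsets
  // where %offsets may itself be a zext/sext from <= 32-bit elements, in
  // which case the extend is sunk as well (see below).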
6067   auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
6068   if (!GEP || GEP->getNumOperands() != 2)
6069     return false;
6070 
6071   Value *Base = GEP->getOperand(0);
6072   Value *Offsets = GEP->getOperand(1);
6073 
6074   // We only care about scalar_base+vector_offsets.
6075   if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
6076     return false;
6077 
6078   // Sink extends that would allow us to use 32-bit offset vectors.
6079   if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
6080     auto *OffsetsInst = cast<Instruction>(Offsets);
6081     if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
6082         OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
6083       Ops.push_back(&GEP->getOperandUse(1));
6084   }
6085 
6086   // Sink the GEP.
6087   return true;
6088 }
6089 
6090 /// We want to sink the following cases:
6091 /// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
6092 /// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
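/// (Illustrative, not from the original source; values are placeholders.)
/// e.g. for
///   %vs  = call i64 @llvm.vscale.i64()
///   %off = shl i64 %vs, 4
///   %p   = getelementptr i8, ptr %A, i64 %off
/// the vscale call's use in the shl is reported so it can be sunk next to
/// its user for better isel.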
6093 static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
6094   if (match(Op, m_VScale()))
6095     return true;
6096   if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
6097       match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
6098     Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
6099     return true;
6100   }
6101   if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
6102       match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
6103     Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
6104     Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
6105     Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
6106     return true;
6107   }
6108   return false;
6109 }
6110 
6111 /// Check if sinking \p I's operands to I's basic block is profitable, because
6112 /// the operands can be folded into a target instruction, e.g.
6113 /// shufflevector extracts and/or sext/zext can be folded into (u,s)subl(2).
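/// (Illustrative, not from the original source; values are placeholders.)
/// e.g. for
///   %hi1 = shufflevector <8 x i16> %a, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
///   %hi2 = shufflevector <8 x i16> %b, <8 x i16> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
///   %e1  = sext <4 x i16> %hi1 to <4 x i32>
///   %e2  = sext <4 x i16> %hi2 to <4 x i32>
///   %sub = sub <4 x i32> %e1, %e2
/// sinking the shuffles and extends next to the sub lets instruction
/// selection form a single ssubl2.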
6114 bool AArch64TTIImpl::isProfitableToSinkOperands(
6115     Instruction *I, SmallVectorImpl<Use *> &Ops) const {
6116   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
6117     switch (II->getIntrinsicID()) {
6118     case Intrinsic::aarch64_neon_smull:
6119     case Intrinsic::aarch64_neon_umull:
6120       if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
6121                                    /*AllowSplat=*/true)) {
6122         Ops.push_back(&II->getOperandUse(0));
6123         Ops.push_back(&II->getOperandUse(1));
6124         return true;
6125       }
6126       [[fallthrough]];
6127 
6128     case Intrinsic::fma:
6129     case Intrinsic::fmuladd:
6130       if (isa<VectorType>(I->getType()) &&
6131           cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
6132           !ST->hasFullFP16())
6133         return false;
6134       [[fallthrough]];
6135     case Intrinsic::aarch64_neon_sqdmull:
6136     case Intrinsic::aarch64_neon_sqdmulh:
6137     case Intrinsic::aarch64_neon_sqrdmulh:
6138       // Sink splats for index lane variants
6139       if (isSplatShuffle(II->getOperand(0)))
6140         Ops.push_back(&II->getOperandUse(0));
6141       if (isSplatShuffle(II->getOperand(1)))
6142         Ops.push_back(&II->getOperandUse(1));
6143       return !Ops.empty();
6144     case Intrinsic::aarch64_neon_fmlal:
6145     case Intrinsic::aarch64_neon_fmlal2:
6146     case Intrinsic::aarch64_neon_fmlsl:
6147     case Intrinsic::aarch64_neon_fmlsl2:
6148       // Sink splats for index lane variants
6149       if (isSplatShuffle(II->getOperand(1)))
6150         Ops.push_back(&II->getOperandUse(1));
6151       if (isSplatShuffle(II->getOperand(2)))
6152         Ops.push_back(&II->getOperandUse(2));
6153       return !Ops.empty();
6154     case Intrinsic::aarch64_sve_ptest_first:
6155     case Intrinsic::aarch64_sve_ptest_last:
6156       if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
6157         if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
6158           Ops.push_back(&II->getOperandUse(0));
6159       return !Ops.empty();
6160     case Intrinsic::aarch64_sme_write_horiz:
6161     case Intrinsic::aarch64_sme_write_vert:
6162     case Intrinsic::aarch64_sme_writeq_horiz:
6163     case Intrinsic::aarch64_sme_writeq_vert: {
6164       auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
6165       if (!Idx || Idx->getOpcode() != Instruction::Add)
6166         return false;
6167       Ops.push_back(&II->getOperandUse(1));
6168       return true;
6169     }
6170     case Intrinsic::aarch64_sme_read_horiz:
6171     case Intrinsic::aarch64_sme_read_vert:
6172     case Intrinsic::aarch64_sme_readq_horiz:
6173     case Intrinsic::aarch64_sme_readq_vert:
6174     case Intrinsic::aarch64_sme_ld1b_vert:
6175     case Intrinsic::aarch64_sme_ld1h_vert:
6176     case Intrinsic::aarch64_sme_ld1w_vert:
6177     case Intrinsic::aarch64_sme_ld1d_vert:
6178     case Intrinsic::aarch64_sme_ld1q_vert:
6179     case Intrinsic::aarch64_sme_st1b_vert:
6180     case Intrinsic::aarch64_sme_st1h_vert:
6181     case Intrinsic::aarch64_sme_st1w_vert:
6182     case Intrinsic::aarch64_sme_st1d_vert:
6183     case Intrinsic::aarch64_sme_st1q_vert:
6184     case Intrinsic::aarch64_sme_ld1b_horiz:
6185     case Intrinsic::aarch64_sme_ld1h_horiz:
6186     case Intrinsic::aarch64_sme_ld1w_horiz:
6187     case Intrinsic::aarch64_sme_ld1d_horiz:
6188     case Intrinsic::aarch64_sme_ld1q_horiz:
6189     case Intrinsic::aarch64_sme_st1b_horiz:
6190     case Intrinsic::aarch64_sme_st1h_horiz:
6191     case Intrinsic::aarch64_sme_st1w_horiz:
6192     case Intrinsic::aarch64_sme_st1d_horiz:
6193     case Intrinsic::aarch64_sme_st1q_horiz: {
6194       auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
6195       if (!Idx || Idx->getOpcode() != Instruction::Add)
6196         return false;
6197       Ops.push_back(&II->getOperandUse(3));
6198       return true;
6199     }
6200     case Intrinsic::aarch64_neon_pmull:
6201       if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
6202         return false;
6203       Ops.push_back(&II->getOperandUse(0));
6204       Ops.push_back(&II->getOperandUse(1));
6205       return true;
6206     case Intrinsic::aarch64_neon_pmull64:
6207       if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
6208                                      II->getArgOperand(1)))
6209         return false;
6210       Ops.push_back(&II->getArgOperandUse(0));
6211       Ops.push_back(&II->getArgOperandUse(1));
6212       return true;
6213     case Intrinsic::masked_gather:
6214       if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
6215         return false;
6216       Ops.push_back(&II->getArgOperandUse(0));
6217       return true;
6218     case Intrinsic::masked_scatter:
6219       if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
6220         return false;
6221       Ops.push_back(&II->getArgOperandUse(1));
6222       return true;
6223     default:
6224       return false;
6225     }
6226   }
6227 
6228   auto ShouldSinkCondition = [](Value *Cond) -> bool {
6229     auto *II = dyn_cast<IntrinsicInst>(Cond);
6230     return II && II->getIntrinsicID() == Intrinsic::vector_reduce_or &&
6231            isa<ScalableVectorType>(II->getOperand(0)->getType());
6232   };
6233 
6234   switch (I->getOpcode()) {
6235   case Instruction::GetElementPtr:
6236   case Instruction::Add:
6237   case Instruction::Sub:
6238     // Sink vscales closer to uses for better isel
6239     for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
6240       if (shouldSinkVScale(I->getOperand(Op), Ops)) {
6241         Ops.push_back(&I->getOperandUse(Op));
6242         return true;
6243       }
6244     }
6245     break;
6246   case Instruction::Select: {
6247     if (!ShouldSinkCondition(I->getOperand(0)))
6248       return false;
6249 
6250     Ops.push_back(&I->getOperandUse(0));
6251     return true;
6252   }
6253   case Instruction::Br: {
6254     if (cast<BranchInst>(I)->isUnconditional())
6255       return false;
6256 
6257     if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition()))
6258       return false;
6259 
6260     Ops.push_back(&I->getOperandUse(0));
6261     return true;
6262   }
6263   default:
6264     break;
6265   }
6266 
6267   if (!I->getType()->isVectorTy())
6268     return false;
6269 
6270   switch (I->getOpcode()) {
6271   case Instruction::Sub:
6272   case Instruction::Add: {
6273     if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
6274       return false;
6275 
6276     // If the exts' operands extract either the lower or upper elements, we
6277     // can sink them too.
6278     auto Ext1 = cast<Instruction>(I->getOperand(0));
6279     auto Ext2 = cast<Instruction>(I->getOperand(1));
6280     if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
6281       Ops.push_back(&Ext1->getOperandUse(0));
6282       Ops.push_back(&Ext2->getOperandUse(0));
6283     }
6284 
6285     Ops.push_back(&I->getOperandUse(0));
6286     Ops.push_back(&I->getOperandUse(1));
6287 
6288     return true;
6289   }
6290   case Instruction::Or: {
6291     // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
6292     // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
6293     if (ST->hasNEON()) {
6294       Instruction *OtherAnd, *IA, *IB;
6295       Value *MaskValue;
6296       // MainAnd refers to the And instruction that has 'Not' as one of its operands
6297       if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
6298                           m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
6299                                            m_Instruction(IA)))))) {
6300         if (match(OtherAnd,
6301                   m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
6302           Instruction *MainAnd = I->getOperand(0) == OtherAnd
6303                                      ? cast<Instruction>(I->getOperand(1))
6304                                      : cast<Instruction>(I->getOperand(0));
6305 
6306           // Both Ands should be in the same basic block as the Or
6307           if (I->getParent() != MainAnd->getParent() ||
6308               I->getParent() != OtherAnd->getParent())
6309             return false;
6310 
6311           // Non-mask operands of both Ands should also be in the same basic block
6312           if (I->getParent() != IA->getParent() ||
6313               I->getParent() != IB->getParent())
6314             return false;
6315 
6316           Ops.push_back(
6317               &MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
6318           Ops.push_back(&I->getOperandUse(0));
6319           Ops.push_back(&I->getOperandUse(1));
6320 
6321           return true;
6322         }
6323       }
6324     }
6325 
6326     return false;
6327   }
6328   case Instruction::Mul: {
6329     auto ShouldSinkSplatForIndexedVariant = [](Value *V) {
6330       auto *Ty = cast<VectorType>(V->getType());
6331       // For SVE the lane-indexing is within 128-bits, so we can't fold splats.
6332       if (Ty->isScalableTy())
6333         return false;
6334 
6335       // Indexed variants of Mul exist for i16 and i32 element types only.
6336       return Ty->getScalarSizeInBits() == 16 || Ty->getScalarSizeInBits() == 32;
6337     };
6338 
6339     int NumZExts = 0, NumSExts = 0;
6340     for (auto &Op : I->operands()) {
6341       // Make sure we are not already sinking this operand
6342       if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
6343         continue;
6344 
6345       if (match(&Op, m_ZExtOrSExt(m_Value()))) {
6346         auto *Ext = cast<Instruction>(Op);
6347         auto *ExtOp = Ext->getOperand(0);
6348         if (isSplatShuffle(ExtOp) && ShouldSinkSplatForIndexedVariant(ExtOp))
6349           Ops.push_back(&Ext->getOperandUse(0));
6350         Ops.push_back(&Op);
6351 
6352         if (isa<SExtInst>(Ext))
6353           NumSExts++;
6354         else
6355           NumZExts++;
6356 
6357         continue;
6358       }
6359 
6360       ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
6361       if (!Shuffle)
6362         continue;
6363 
6364       // If the Shuffle is a splat and the operand is a zext/sext, sinking the
6365       // operand and the s/zext can help create indexed s/umull. This is
6366       // especially useful to prevent i64 mul being scalarized.
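      // (Illustrative, not from the original source; values are
      // placeholders.) e.g.
      //   %ze  = zext <2 x i32> %s to <2 x i64>
      //   %spl = shufflevector <2 x i64> %ze, <2 x i64> poison, <2 x i32> zeroinitializer
      //   %mul = mul <2 x i64> %v, %spl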
6367       if (isSplatShuffle(Shuffle) &&
6368           match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
6369         Ops.push_back(&Shuffle->getOperandUse(0));
6370         Ops.push_back(&Op);
6371         if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
6372           NumSExts++;
6373         else
6374           NumZExts++;
6375         continue;
6376       }
6377 
6378       Value *ShuffleOperand = Shuffle->getOperand(0);
6379       InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
6380       if (!Insert)
6381         continue;
6382 
6383       Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
6384       if (!OperandInstr)
6385         continue;
6386 
6387       ConstantInt *ElementConstant =
6388           dyn_cast<ConstantInt>(Insert->getOperand(2));
6389       // Check that the insertelement is inserting into element 0
6390       if (!ElementConstant || !ElementConstant->isZero())
6391         continue;
6392 
6393       unsigned Opcode = OperandInstr->getOpcode();
6394       if (Opcode == Instruction::SExt)
6395         NumSExts++;
6396       else if (Opcode == Instruction::ZExt)
6397         NumZExts++;
6398       else {
6399         // If we find that the top bits are known 0, then we can sink and allow
6400         // the backend to generate a umull.
6401         unsigned Bitwidth = I->getType()->getScalarSizeInBits();
6402         APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
6403         if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
6404           continue;
6405         NumZExts++;
6406       }
6407 
6408       // And(Load) is excluded to prevent CGP getting stuck in a loop of sinking
6409       // the And, just to hoist it again back to the load.
6410       if (!match(OperandInstr, m_And(m_Load(m_Value()), m_Value())))
6411         Ops.push_back(&Insert->getOperandUse(1));
6412       Ops.push_back(&Shuffle->getOperandUse(0));
6413       Ops.push_back(&Op);
6414     }
6415 
6416     // It is profitable to sink if we found two of the same type of extends.
6417     if (!Ops.empty() && (NumSExts == 2 || NumZExts == 2))
6418       return true;
6419 
6420     // Otherwise, see if we should sink splats for indexed variants.
6421     if (!ShouldSinkSplatForIndexedVariant(I))
6422       return false;
6423 
6424     Ops.clear();
6425     if (isSplatShuffle(I->getOperand(0)))
6426       Ops.push_back(&I->getOperandUse(0));
6427     if (isSplatShuffle(I->getOperand(1)))
6428       Ops.push_back(&I->getOperandUse(1));
6429 
6430     return !Ops.empty();
6431   }
6432   case Instruction::FMul: {
6433     // For SVE the lane-indexing is within 128-bits, so we can't fold splats.
6434     if (I->getType()->isScalableTy())
6435       return false;
6436 
6437     if (cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
6438         !ST->hasFullFP16())
6439       return false;
6440 
6441     // Sink splats for index lane variants
6442     if (isSplatShuffle(I->getOperand(0)))
6443       Ops.push_back(&I->getOperandUse(0));
6444     if (isSplatShuffle(I->getOperand(1)))
6445       Ops.push_back(&I->getOperandUse(1));
6446     return !Ops.empty();
6447   }
6448   default:
6449     return false;
6450   }
6451   return false;
6452 }
6453