1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "AArch64TargetTransformInfo.h"
10 #include "AArch64ExpandImm.h"
11 #include "AArch64PerfectShuffle.h"
12 #include "MCTargetDesc/AArch64AddressingModes.h"
13 #include "Utils/AArch64SMEAttributes.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/TargetTransformInfo.h"
17 #include "llvm/CodeGen/BasicTTIImpl.h"
18 #include "llvm/CodeGen/CostTable.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/IR/DerivedTypes.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Intrinsics.h"
23 #include "llvm/IR/IntrinsicsAArch64.h"
24 #include "llvm/IR/PatternMatch.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/TargetParser/AArch64TargetParser.h"
27 #include "llvm/Transforms/InstCombine/InstCombiner.h"
28 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
29 #include <algorithm>
30 #include <optional>
31 using namespace llvm;
32 using namespace llvm::PatternMatch;
33
34 #define DEBUG_TYPE "aarch64tti"
35
// Workaround toggle for the Falkor hardware-prefetcher unrolling fix
// (enabled by default, hidden from -help).
static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
                                               cl::init(true), cl::Hidden);

// When a fixed-width and a scalable vectorization plan have equal cost,
// prefer the fixed-width one (off by default).
static cl::opt<bool> SVEPreferFixedOverScalableIfEqualCost(
    "sve-prefer-fixed-over-scalable-if-equal", cl::Hidden);

// Additional cost charged for an SVE gather (default 10).
static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
                                           cl::Hidden);

// Additional cost charged for an SVE scatter (default 10).
static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
                                            cl::init(10), cl::Hidden);

// Instruction-count threshold used when considering SVE tail-folding
// (default 15).
static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
                                                  cl::init(15), cl::Hidden);

// Additional cost charged for NEON accesses with a non-constant stride
// (default 10).
static cl::opt<unsigned>
    NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
                               cl::Hidden);

static cl::opt<unsigned> CallPenaltyChangeSM(
    "call-penalty-sm-change", cl::init(5), cl::Hidden,
    cl::desc(
        "Penalty of calling a function that requires a change to PSTATE.SM"));

static cl::opt<unsigned> InlineCallPenaltyChangeSM(
    "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
    cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));

// Treat OR-like patterns as selects for costing purposes (on by default).
static cl::opt<bool> EnableOrLikeSelectOpt("enable-aarch64-or-like-select",
                                           cl::init(true), cl::Hidden);

// Enable the AArch64-specific LSR cost adjustment (on by default).
static cl::opt<bool> EnableLSRCostOpt("enable-aarch64-lsr-cost-opt",
                                      cl::init(true), cl::Hidden);

// A complete guess as to a reasonable cost.
static cl::opt<unsigned>
    BaseHistCntCost("aarch64-base-histcnt-cost", cl::init(8), cl::Hidden,
                    cl::desc("The cost of a histcnt instruction"));

static cl::opt<unsigned> DMBLookaheadThreshold(
    "dmb-lookahead-threshold", cl::init(10), cl::Hidden,
    cl::desc("The number of instructions to search for a redundant dmb"));
78
79 namespace {
80 class TailFoldingOption {
81 // These bitfields will only ever be set to something non-zero in operator=,
82 // when setting the -sve-tail-folding option. This option should always be of
83 // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here
84 // InitialBits is one of (disabled|all|simple). EnableBits represents
85 // additional flags we're enabling, and DisableBits for those flags we're
86 // disabling. The default flag is tracked in the variable NeedsDefault, since
87 // at the time of setting the option we may not know what the default value
88 // for the CPU is.
89 TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
90 TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
91 TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;
92
93 // This value needs to be initialised to true in case the user does not
94 // explicitly set the -sve-tail-folding option.
95 bool NeedsDefault = true;
96
setInitialBits(TailFoldingOpts Bits)97 void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }
98
setNeedsDefault(bool V)99 void setNeedsDefault(bool V) { NeedsDefault = V; }
100
setEnableBit(TailFoldingOpts Bit)101 void setEnableBit(TailFoldingOpts Bit) {
102 EnableBits |= Bit;
103 DisableBits &= ~Bit;
104 }
105
setDisableBit(TailFoldingOpts Bit)106 void setDisableBit(TailFoldingOpts Bit) {
107 EnableBits &= ~Bit;
108 DisableBits |= Bit;
109 }
110
getBits(TailFoldingOpts DefaultBits) const111 TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
112 TailFoldingOpts Bits = TailFoldingOpts::Disabled;
113
114 assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
115 "Initial bits should only include one of "
116 "(disabled|all|simple|default)");
117 Bits = NeedsDefault ? DefaultBits : InitialBits;
118 Bits |= EnableBits;
119 Bits &= ~DisableBits;
120
121 return Bits;
122 }
123
reportError(std::string Opt)124 void reportError(std::string Opt) {
125 errs() << "invalid argument '" << Opt
126 << "' to -sve-tail-folding=; the option should be of the form\n"
127 " (disabled|all|default|simple)[+(reductions|recurrences"
128 "|reverse|noreductions|norecurrences|noreverse)]\n";
129 report_fatal_error("Unrecognised tail-folding option");
130 }
131
132 public:
133
operator =(const std::string & Val)134 void operator=(const std::string &Val) {
135 // If the user explicitly sets -sve-tail-folding= then treat as an error.
136 if (Val.empty()) {
137 reportError("");
138 return;
139 }
140
141 // Since the user is explicitly setting the option we don't automatically
142 // need the default unless they require it.
143 setNeedsDefault(false);
144
145 SmallVector<StringRef, 4> TailFoldTypes;
146 StringRef(Val).split(TailFoldTypes, '+', -1, false);
147
148 unsigned StartIdx = 1;
149 if (TailFoldTypes[0] == "disabled")
150 setInitialBits(TailFoldingOpts::Disabled);
151 else if (TailFoldTypes[0] == "all")
152 setInitialBits(TailFoldingOpts::All);
153 else if (TailFoldTypes[0] == "default")
154 setNeedsDefault(true);
155 else if (TailFoldTypes[0] == "simple")
156 setInitialBits(TailFoldingOpts::Simple);
157 else {
158 StartIdx = 0;
159 setInitialBits(TailFoldingOpts::Disabled);
160 }
161
162 for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
163 if (TailFoldTypes[I] == "reductions")
164 setEnableBit(TailFoldingOpts::Reductions);
165 else if (TailFoldTypes[I] == "recurrences")
166 setEnableBit(TailFoldingOpts::Recurrences);
167 else if (TailFoldTypes[I] == "reverse")
168 setEnableBit(TailFoldingOpts::Reverse);
169 else if (TailFoldTypes[I] == "noreductions")
170 setDisableBit(TailFoldingOpts::Reductions);
171 else if (TailFoldTypes[I] == "norecurrences")
172 setDisableBit(TailFoldingOpts::Recurrences);
173 else if (TailFoldTypes[I] == "noreverse")
174 setDisableBit(TailFoldingOpts::Reverse);
175 else
176 reportError(Val);
177 }
178 }
179
satisfies(TailFoldingOpts DefaultBits,TailFoldingOpts Required) const180 bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
181 return (getBits(DefaultBits) & Required) == Required;
182 }
183 };
184 } // namespace
185
// Storage for the -sve-tail-folding option value (bound via cl::location
// below so the parsed state outlives the cl::opt object).
TailFoldingOption TailFoldingOptionLoc;

static cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
    "sve-tail-folding",
    cl::desc(
        "Control the use of vectorisation using tail-folding for SVE where the"
        " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
        "\ndisabled      (Initial) No loop types will vectorize using "
        "tail-folding"
        "\ndefault       (Initial) Uses the default tail-folding settings for "
        "the target CPU"
        "\nall           (Initial) All legal loop types will vectorize using "
        "tail-folding"
        "\nsimple        (Initial) Use tail-folding for simple loops (not "
        "reductions or recurrences)"
        "\nreductions    Use tail-folding for loops containing reductions"
        "\nnoreductions  Inverse of above"
        "\nrecurrences   Use tail-folding for loops containing fixed order "
        "recurrences"
        "\nnorecurrences Inverse of above"
        "\nreverse       Use tail-folding for loops requiring reversed "
        "predicates"
        "\nnoreverse     Inverse of above"),
    cl::location(TailFoldingOptionLoc));

// Experimental option that will only be fully functional when the
// code-generator is changed to use SVE instead of NEON for all fixed-width
// operations.
static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
    "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

// Experimental option that will only be fully functional when the cost-model
// and code-generator have been changed to avoid using scalable vector
// instructions that are not legal in streaming SVE mode.
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
    "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);
222
isSMEABIRoutineCall(const CallInst & CI)223 static bool isSMEABIRoutineCall(const CallInst &CI) {
224 const auto *F = CI.getCalledFunction();
225 return F && StringSwitch<bool>(F->getName())
226 .Case("__arm_sme_state", true)
227 .Case("__arm_tpidr2_save", true)
228 .Case("__arm_tpidr2_restore", true)
229 .Case("__arm_za_disable", true)
230 .Default(false);
231 }
232
233 /// Returns true if the function has explicit operations that can only be
234 /// lowered using incompatible instructions for the selected mode. This also
235 /// returns true if the function F may use or modify ZA state.
hasPossibleIncompatibleOps(const Function * F)236 static bool hasPossibleIncompatibleOps(const Function *F) {
237 for (const BasicBlock &BB : *F) {
238 for (const Instruction &I : BB) {
239 // Be conservative for now and assume that any call to inline asm or to
240 // intrinsics could could result in non-streaming ops (e.g. calls to
241 // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
242 // all native LLVM instructions can be lowered to compatible instructions.
243 if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
244 (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
245 isSMEABIRoutineCall(cast<CallInst>(I))))
246 return true;
247 }
248 }
249 return false;
250 }
251
getFeatureMask(const Function & F) const252 uint64_t AArch64TTIImpl::getFeatureMask(const Function &F) const {
253 StringRef AttributeStr =
254 isMultiversionedFunction(F) ? "fmv-features" : "target-features";
255 StringRef FeatureStr = F.getFnAttribute(AttributeStr).getValueAsString();
256 SmallVector<StringRef, 8> Features;
257 FeatureStr.split(Features, ",");
258 return AArch64::getFMVPriority(Features);
259 }
260
isMultiversionedFunction(const Function & F) const261 bool AArch64TTIImpl::isMultiversionedFunction(const Function &F) const {
262 return F.hasFnAttribute("fmv-features");
263 }
264
// Feature bits that are inverted before the subset check in
// areInlineCompatible. These represent restrictions rather than
// capabilities: e.g. a "+execute-only" callee can be inlined into a caller
// without "+execute-only", but not vice versa.
const FeatureBitset AArch64TTIImpl::InlineInverseFeatures = {
    AArch64::FeatureExecuteOnly,
};
268
areInlineCompatible(const Function * Caller,const Function * Callee) const269 bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
270 const Function *Callee) const {
271 SMECallAttrs CallAttrs(*Caller, *Callee);
272
273 // When inlining, we should consider the body of the function, not the
274 // interface.
275 if (CallAttrs.callee().hasStreamingBody()) {
276 CallAttrs.callee().set(SMEAttrs::SM_Compatible, false);
277 CallAttrs.callee().set(SMEAttrs::SM_Enabled, true);
278 }
279
280 if (CallAttrs.callee().isNewZA() || CallAttrs.callee().isNewZT0())
281 return false;
282
283 if (CallAttrs.requiresLazySave() || CallAttrs.requiresSMChange() ||
284 CallAttrs.requiresPreservingZT0() ||
285 CallAttrs.requiresPreservingAllZAState()) {
286 if (hasPossibleIncompatibleOps(Callee))
287 return false;
288 }
289
290 const TargetMachine &TM = getTLI()->getTargetMachine();
291 const FeatureBitset &CallerBits =
292 TM.getSubtargetImpl(*Caller)->getFeatureBits();
293 const FeatureBitset &CalleeBits =
294 TM.getSubtargetImpl(*Callee)->getFeatureBits();
295 // Adjust the feature bitsets by inverting some of the bits. This is needed
296 // for target features that represent restrictions rather than capabilities,
297 // for example a "+execute-only" callee can be inlined into a caller without
298 // "+execute-only", but not vice versa.
299 FeatureBitset EffectiveCallerBits = CallerBits ^ InlineInverseFeatures;
300 FeatureBitset EffectiveCalleeBits = CalleeBits ^ InlineInverseFeatures;
301
302 return (EffectiveCallerBits & EffectiveCalleeBits) == EffectiveCalleeBits;
303 }
304
areTypesABICompatible(const Function * Caller,const Function * Callee,const ArrayRef<Type * > & Types) const305 bool AArch64TTIImpl::areTypesABICompatible(
306 const Function *Caller, const Function *Callee,
307 const ArrayRef<Type *> &Types) const {
308 if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
309 return false;
310
311 // We need to ensure that argument promotion does not attempt to promote
312 // pointers to fixed-length vector types larger than 128 bits like
313 // <8 x float> (and pointers to aggregate types which have such fixed-length
314 // vector type members) into the values of the pointees. Such vector types
315 // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
316 // backend cannot lower such value arguments. The 128-bit fixed-length SVE
317 // types can be safely treated as 128-bit NEON types and they cannot be
318 // distinguished in IR.
319 if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
320 auto FVTy = dyn_cast<FixedVectorType>(Ty);
321 return FVTy &&
322 FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
323 }))
324 return false;
325
326 return true;
327 }
328
329 unsigned
getInlineCallPenalty(const Function * F,const CallBase & Call,unsigned DefaultCallPenalty) const330 AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
331 unsigned DefaultCallPenalty) const {
332 // This function calculates a penalty for executing Call in F.
333 //
334 // There are two ways this function can be called:
335 // (1) F:
336 // call from F -> G (the call here is Call)
337 //
338 // For (1), Call.getCaller() == F, so it will always return a high cost if
339 // a streaming-mode change is required (thus promoting the need to inline the
340 // function)
341 //
342 // (2) F:
343 // call from F -> G (the call here is not Call)
344 // G:
345 // call from G -> H (the call here is Call)
346 //
347 // For (2), if after inlining the body of G into F the call to H requires a
348 // streaming-mode change, and the call to G from F would also require a
349 // streaming-mode change, then there is benefit to do the streaming-mode
350 // change only once and avoid inlining of G into F.
351
352 SMEAttrs FAttrs(*F);
353 SMECallAttrs CallAttrs(Call);
354
355 if (SMECallAttrs(FAttrs, CallAttrs.callee()).requiresSMChange()) {
356 if (F == Call.getCaller()) // (1)
357 return CallPenaltyChangeSM * DefaultCallPenalty;
358 if (SMECallAttrs(FAttrs, CallAttrs.caller()).requiresSMChange()) // (2)
359 return InlineCallPenaltyChangeSM * DefaultCallPenalty;
360 }
361
362 return DefaultCallPenalty;
363 }
364
shouldMaximizeVectorBandwidth(TargetTransformInfo::RegisterKind K) const365 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
366 TargetTransformInfo::RegisterKind K) const {
367 assert(K != TargetTransformInfo::RGK_Scalar);
368 return (K == TargetTransformInfo::RGK_FixedWidthVector &&
369 ST->isNeonAvailable());
370 }
371
372 /// Calculate the cost of materializing a 64-bit value. This helper
373 /// method might only calculate a fraction of a larger immediate. Therefore it
374 /// is valid to return a cost of ZERO.
getIntImmCost(int64_t Val) const375 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) const {
376 // Check if the immediate can be encoded within an instruction.
377 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
378 return 0;
379
380 if (Val < 0)
381 Val = ~Val;
382
383 // Calculate how many moves we will need to materialize this constant.
384 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
385 AArch64_IMM::expandMOVImm(Val, 64, Insn);
386 return Insn.size();
387 }
388
389 /// Calculate the cost of materializing the given constant.
390 InstructionCost
getIntImmCost(const APInt & Imm,Type * Ty,TTI::TargetCostKind CostKind) const391 AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
392 TTI::TargetCostKind CostKind) const {
393 assert(Ty->isIntegerTy());
394
395 unsigned BitSize = Ty->getPrimitiveSizeInBits();
396 if (BitSize == 0)
397 return ~0U;
398
399 // Sign-extend all constants to a multiple of 64-bit.
400 APInt ImmVal = Imm;
401 if (BitSize & 0x3f)
402 ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);
403
404 // Split the constant into 64-bit chunks and calculate the cost for each
405 // chunk.
406 InstructionCost Cost = 0;
407 for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
408 APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
409 int64_t Val = Tmp.getSExtValue();
410 Cost += getIntImmCost(Val);
411 }
412 // We need at least one instruction to materialze the constant.
413 return std::max<InstructionCost>(1, Cost);
414 }
415
/// Return the cost of the immediate \p Imm appearing as operand \p Idx of an
/// instruction with opcode \p Opcode. Operand positions that the opcode can
/// typically fold are reported as free so constant hoisting leaves them
/// alone.
InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) const {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // ImmIdx records which operand index (if any) is checked against the
  // materialization cost below; ~0U means no operand position applies.
  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    // The stored value.
    ImmIdx = 0;
    break;
  // Binary/compare operations: the second operand is the candidate
  // immediate position.
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  // For these opcodes no operand position is special; fall through to the
  // plain materialization cost below.
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    // If materialization is no more expensive than the per-chunk baseline,
    // report the immediate as free so it is not hoisted.
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
483
/// Return the cost of the immediate \p Imm appearing as operand \p Idx of a
/// call to intrinsic \p IID. Immediates that the intrinsic's lowering can
/// absorb are reported as free for constant-hoisting purposes.
InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) const {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    // Operand 1 is the non-LHS value operand; cheap immediates there are
    // treated as free (same baseline comparison as getIntImmCostInst).
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  // Stackmap/patchpoint/statepoint: leading operands and any 64-bit-encodable
  // immediate are free.
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}
535
536 TargetTransformInfo::PopcntSupportKind
getPopcntSupport(unsigned TyWidth) const537 AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) const {
538 assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
539 if (TyWidth == 32 || TyWidth == 64)
540 return TTI::PSK_FastHardware;
541 // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
542 return TTI::PSK_Software;
543 }
544
isUnpackedVectorVT(EVT VecVT)545 static bool isUnpackedVectorVT(EVT VecVT) {
546 return VecVT.isScalableVector() &&
547 VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock;
548 }
549
getHistogramCost(const IntrinsicCostAttributes & ICA)550 static InstructionCost getHistogramCost(const IntrinsicCostAttributes &ICA) {
551 Type *BucketPtrsTy = ICA.getArgTypes()[0]; // Type of vector of pointers
552 Type *EltTy = ICA.getArgTypes()[1]; // Type of bucket elements
553 unsigned TotalHistCnts = 1;
554
555 unsigned EltSize = EltTy->getScalarSizeInBits();
556 // Only allow (up to 64b) integers or pointers
557 if ((!EltTy->isIntegerTy() && !EltTy->isPointerTy()) || EltSize > 64)
558 return InstructionCost::getInvalid();
559
560 // FIXME: We should be able to generate histcnt for fixed-length vectors
561 // using ptrue with a specific VL.
562 if (VectorType *VTy = dyn_cast<VectorType>(BucketPtrsTy)) {
563 unsigned EC = VTy->getElementCount().getKnownMinValue();
564 if (!isPowerOf2_64(EC) || !VTy->isScalableTy())
565 return InstructionCost::getInvalid();
566
567 // HistCnt only supports 32b and 64b element types
568 unsigned LegalEltSize = EltSize <= 32 ? 32 : 64;
569
570 if (EC == 2 || (LegalEltSize == 32 && EC == 4))
571 return InstructionCost(BaseHistCntCost);
572
573 unsigned NaturalVectorWidth = AArch64::SVEBitsPerBlock / LegalEltSize;
574 TotalHistCnts = EC / NaturalVectorWidth;
575 }
576
577 return InstructionCost(BaseHistCntCost * TotalHistCnts);
578 }
579
580 InstructionCost
getIntrinsicInstrCost(const IntrinsicCostAttributes & ICA,TTI::TargetCostKind CostKind) const581 AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
582 TTI::TargetCostKind CostKind) const {
583 // The code-generator is currently not able to handle scalable vectors
584 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
585 // it. This change will be removed when code-generation for these types is
586 // sufficiently reliable.
587 auto *RetTy = ICA.getReturnType();
588 if (auto *VTy = dyn_cast<ScalableVectorType>(RetTy))
589 if (VTy->getElementCount() == ElementCount::getScalable(1))
590 return InstructionCost::getInvalid();
591
592 switch (ICA.getID()) {
593 case Intrinsic::experimental_vector_histogram_add:
594 if (!ST->hasSVE2())
595 return InstructionCost::getInvalid();
596 return getHistogramCost(ICA);
597 case Intrinsic::umin:
598 case Intrinsic::umax:
599 case Intrinsic::smin:
600 case Intrinsic::smax: {
601 static const auto ValidMinMaxTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
602 MVT::v8i16, MVT::v2i32, MVT::v4i32,
603 MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
604 MVT::nxv2i64};
605 auto LT = getTypeLegalizationCost(RetTy);
606 // v2i64 types get converted to cmp+bif hence the cost of 2
607 if (LT.second == MVT::v2i64)
608 return LT.first * 2;
609 if (any_of(ValidMinMaxTys, [<](MVT M) { return M == LT.second; }))
610 return LT.first;
611 break;
612 }
613 case Intrinsic::sadd_sat:
614 case Intrinsic::ssub_sat:
615 case Intrinsic::uadd_sat:
616 case Intrinsic::usub_sat: {
617 static const auto ValidSatTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
618 MVT::v8i16, MVT::v2i32, MVT::v4i32,
619 MVT::v2i64};
620 auto LT = getTypeLegalizationCost(RetTy);
621 // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
622 // need to extend the type, as it uses shr(qadd(shl, shl)).
623 unsigned Instrs =
624 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
625 if (any_of(ValidSatTys, [<](MVT M) { return M == LT.second; }))
626 return LT.first * Instrs;
627 break;
628 }
629 case Intrinsic::abs: {
630 static const auto ValidAbsTys = {MVT::v8i8, MVT::v16i8, MVT::v4i16,
631 MVT::v8i16, MVT::v2i32, MVT::v4i32,
632 MVT::v2i64};
633 auto LT = getTypeLegalizationCost(RetTy);
634 if (any_of(ValidAbsTys, [<](MVT M) { return M == LT.second; }))
635 return LT.first;
636 break;
637 }
638 case Intrinsic::bswap: {
639 static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
640 MVT::v4i32, MVT::v2i64};
641 auto LT = getTypeLegalizationCost(RetTy);
642 if (any_of(ValidAbsTys, [<](MVT M) { return M == LT.second; }) &&
643 LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
644 return LT.first;
645 break;
646 }
647 case Intrinsic::stepvector: {
648 InstructionCost Cost = 1; // Cost of the `index' instruction
649 auto LT = getTypeLegalizationCost(RetTy);
650 // Legalisation of illegal vectors involves an `index' instruction plus
651 // (LT.first - 1) vector adds.
652 if (LT.first > 1) {
653 Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
654 InstructionCost AddCost =
655 getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
656 Cost += AddCost * (LT.first - 1);
657 }
658 return Cost;
659 }
660 case Intrinsic::vector_extract:
661 case Intrinsic::vector_insert: {
662 // If both the vector and subvector types are legal types and the index
663 // is 0, then this should be a no-op or simple operation; return a
664 // relatively low cost.
665
666 // If arguments aren't actually supplied, then we cannot determine the
667 // value of the index. We also want to skip predicate types.
668 if (ICA.getArgs().size() != ICA.getArgTypes().size() ||
669 ICA.getReturnType()->getScalarType()->isIntegerTy(1))
670 break;
671
672 LLVMContext &C = RetTy->getContext();
673 EVT VecVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
674 bool IsExtract = ICA.getID() == Intrinsic::vector_extract;
675 EVT SubVecVT = IsExtract ? getTLI()->getValueType(DL, RetTy)
676 : getTLI()->getValueType(DL, ICA.getArgTypes()[1]);
677 // Skip this if either the vector or subvector types are unpacked
678 // SVE types; they may get lowered to stack stores and loads.
679 if (isUnpackedVectorVT(VecVT) || isUnpackedVectorVT(SubVecVT))
680 break;
681
682 TargetLoweringBase::LegalizeKind SubVecLK =
683 getTLI()->getTypeConversion(C, SubVecVT);
684 TargetLoweringBase::LegalizeKind VecLK =
685 getTLI()->getTypeConversion(C, VecVT);
686 const Value *Idx = IsExtract ? ICA.getArgs()[1] : ICA.getArgs()[2];
687 const ConstantInt *CIdx = cast<ConstantInt>(Idx);
688 if (SubVecLK.first == TargetLoweringBase::TypeLegal &&
689 VecLK.first == TargetLoweringBase::TypeLegal && CIdx->isZero())
690 return TTI::TCC_Free;
691 break;
692 }
693 case Intrinsic::bitreverse: {
694 static const CostTblEntry BitreverseTbl[] = {
695 {Intrinsic::bitreverse, MVT::i32, 1},
696 {Intrinsic::bitreverse, MVT::i64, 1},
697 {Intrinsic::bitreverse, MVT::v8i8, 1},
698 {Intrinsic::bitreverse, MVT::v16i8, 1},
699 {Intrinsic::bitreverse, MVT::v4i16, 2},
700 {Intrinsic::bitreverse, MVT::v8i16, 2},
701 {Intrinsic::bitreverse, MVT::v2i32, 2},
702 {Intrinsic::bitreverse, MVT::v4i32, 2},
703 {Intrinsic::bitreverse, MVT::v1i64, 2},
704 {Intrinsic::bitreverse, MVT::v2i64, 2},
705 };
706 const auto LegalisationCost = getTypeLegalizationCost(RetTy);
707 const auto *Entry =
708 CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
709 if (Entry) {
710 // Cost Model is using the legal type(i32) that i8 and i16 will be
711 // converted to +1 so that we match the actual lowering cost
712 if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
713 TLI->getValueType(DL, RetTy, true) == MVT::i16)
714 return LegalisationCost.first * Entry->Cost + 1;
715
716 return LegalisationCost.first * Entry->Cost;
717 }
718 break;
719 }
720 case Intrinsic::ctpop: {
721 if (!ST->hasNEON()) {
722 // 32-bit or 64-bit ctpop without NEON is 12 instructions.
723 return getTypeLegalizationCost(RetTy).first * 12;
724 }
725 static const CostTblEntry CtpopCostTbl[] = {
726 {ISD::CTPOP, MVT::v2i64, 4},
727 {ISD::CTPOP, MVT::v4i32, 3},
728 {ISD::CTPOP, MVT::v8i16, 2},
729 {ISD::CTPOP, MVT::v16i8, 1},
730 {ISD::CTPOP, MVT::i64, 4},
731 {ISD::CTPOP, MVT::v2i32, 3},
732 {ISD::CTPOP, MVT::v4i16, 2},
733 {ISD::CTPOP, MVT::v8i8, 1},
734 {ISD::CTPOP, MVT::i32, 5},
735 };
736 auto LT = getTypeLegalizationCost(RetTy);
737 MVT MTy = LT.second;
738 if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
739 // Extra cost of +1 when illegal vector types are legalized by promoting
740 // the integer type.
741 int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
742 RetTy->getScalarSizeInBits()
743 ? 1
744 : 0;
745 return LT.first * Entry->Cost + ExtraCost;
746 }
747 break;
748 }
749 case Intrinsic::sadd_with_overflow:
750 case Intrinsic::uadd_with_overflow:
751 case Intrinsic::ssub_with_overflow:
752 case Intrinsic::usub_with_overflow:
753 case Intrinsic::smul_with_overflow:
754 case Intrinsic::umul_with_overflow: {
755 static const CostTblEntry WithOverflowCostTbl[] = {
756 {Intrinsic::sadd_with_overflow, MVT::i8, 3},
757 {Intrinsic::uadd_with_overflow, MVT::i8, 3},
758 {Intrinsic::sadd_with_overflow, MVT::i16, 3},
759 {Intrinsic::uadd_with_overflow, MVT::i16, 3},
760 {Intrinsic::sadd_with_overflow, MVT::i32, 1},
761 {Intrinsic::uadd_with_overflow, MVT::i32, 1},
762 {Intrinsic::sadd_with_overflow, MVT::i64, 1},
763 {Intrinsic::uadd_with_overflow, MVT::i64, 1},
764 {Intrinsic::ssub_with_overflow, MVT::i8, 3},
765 {Intrinsic::usub_with_overflow, MVT::i8, 3},
766 {Intrinsic::ssub_with_overflow, MVT::i16, 3},
767 {Intrinsic::usub_with_overflow, MVT::i16, 3},
768 {Intrinsic::ssub_with_overflow, MVT::i32, 1},
769 {Intrinsic::usub_with_overflow, MVT::i32, 1},
770 {Intrinsic::ssub_with_overflow, MVT::i64, 1},
771 {Intrinsic::usub_with_overflow, MVT::i64, 1},
772 {Intrinsic::smul_with_overflow, MVT::i8, 5},
773 {Intrinsic::umul_with_overflow, MVT::i8, 4},
774 {Intrinsic::smul_with_overflow, MVT::i16, 5},
775 {Intrinsic::umul_with_overflow, MVT::i16, 4},
776 {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
777 {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
778 {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
779 {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
780 };
781 EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
782 if (MTy.isSimple())
783 if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
784 MTy.getSimpleVT()))
785 return Entry->Cost;
786 break;
787 }
788 case Intrinsic::fptosi_sat:
789 case Intrinsic::fptoui_sat: {
790 if (ICA.getArgTypes().empty())
791 break;
792 bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
793 auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
794 EVT MTy = TLI->getValueType(DL, RetTy);
795 // Check for the legal types, which are where the size of the input and the
796 // output are the same, or we are using cvt f64->i32 or f32->i64.
797 if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
798 LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
799 LT.second == MVT::v2f64)) {
800 if ((LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
801 (LT.second == MVT::f64 && MTy == MVT::i32) ||
802 (LT.second == MVT::f32 && MTy == MVT::i64)))
803 return LT.first;
804 // Extending vector types v2f32->v2i64, fcvtl*2 + fcvt*2
805 if (LT.second.getScalarType() == MVT::f32 && MTy.isFixedLengthVector() &&
806 MTy.getScalarSizeInBits() == 64)
807 return LT.first * (MTy.getVectorNumElements() > 2 ? 4 : 2);
808 }
809 // Similarly for fp16 sizes. Without FullFP16 we generally need to fcvt to
810 // f32.
811 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
812 return LT.first + getIntrinsicInstrCost(
813 {ICA.getID(),
814 RetTy,
815 {ICA.getArgTypes()[0]->getWithNewType(
816 Type::getFloatTy(RetTy->getContext()))}},
817 CostKind);
818 if ((LT.second == MVT::f16 && MTy == MVT::i32) ||
819 (LT.second == MVT::f16 && MTy == MVT::i64) ||
820 ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
821 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))
822 return LT.first;
823 // Extending vector types v8f16->v8i32, fcvtl*2 + fcvt*2
824 if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
825 MTy.getScalarSizeInBits() == 32)
826 return LT.first * (MTy.getVectorNumElements() > 4 ? 4 : 2);
827 // Extending vector types v8f16->v8i32. These current scalarize but the
828 // codegen could be better.
829 if (LT.second.getScalarType() == MVT::f16 && MTy.isFixedLengthVector() &&
830 MTy.getScalarSizeInBits() == 64)
831 return MTy.getVectorNumElements() * 3;
832
833 // If we can we use a legal convert followed by a min+max
834 if ((LT.second.getScalarType() == MVT::f32 ||
835 LT.second.getScalarType() == MVT::f64 ||
836 LT.second.getScalarType() == MVT::f16) &&
837 LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
838 Type *LegalTy =
839 Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
840 if (LT.second.isVector())
841 LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
842 InstructionCost Cost = 1;
843 IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
844 LegalTy, {LegalTy, LegalTy});
845 Cost += getIntrinsicInstrCost(Attrs1, CostKind);
846 IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
847 LegalTy, {LegalTy, LegalTy});
848 Cost += getIntrinsicInstrCost(Attrs2, CostKind);
849 return LT.first * Cost +
850 ((LT.second.getScalarType() != MVT::f16 || ST->hasFullFP16()) ? 0
851 : 1);
852 }
853 // Otherwise we need to follow the default expansion that clamps the value
854 // using a float min/max with a fcmp+sel for nan handling when signed.
855 Type *FPTy = ICA.getArgTypes()[0]->getScalarType();
856 RetTy = RetTy->getScalarType();
857 if (LT.second.isVector()) {
858 FPTy = VectorType::get(FPTy, LT.second.getVectorElementCount());
859 RetTy = VectorType::get(RetTy, LT.second.getVectorElementCount());
860 }
861 IntrinsicCostAttributes Attrs1(Intrinsic::minnum, FPTy, {FPTy, FPTy});
862 InstructionCost Cost = getIntrinsicInstrCost(Attrs1, CostKind);
863 IntrinsicCostAttributes Attrs2(Intrinsic::maxnum, FPTy, {FPTy, FPTy});
864 Cost += getIntrinsicInstrCost(Attrs2, CostKind);
865 Cost +=
866 getCastInstrCost(IsSigned ? Instruction::FPToSI : Instruction::FPToUI,
867 RetTy, FPTy, TTI::CastContextHint::None, CostKind);
868 if (IsSigned) {
869 Type *CondTy = RetTy->getWithNewBitWidth(1);
870 Cost += getCmpSelInstrCost(BinaryOperator::FCmp, FPTy, CondTy,
871 CmpInst::FCMP_UNO, CostKind);
872 Cost += getCmpSelInstrCost(BinaryOperator::Select, RetTy, CondTy,
873 CmpInst::FCMP_UNO, CostKind);
874 }
875 return LT.first * Cost;
876 }
877 case Intrinsic::fshl:
878 case Intrinsic::fshr: {
879 if (ICA.getArgs().empty())
880 break;
881
882 // TODO: Add handling for fshl where third argument is not a constant.
883 const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
884 if (!OpInfoZ.isConstant())
885 break;
886
887 const auto LegalisationCost = getTypeLegalizationCost(RetTy);
888 if (OpInfoZ.isUniform()) {
889 static const CostTblEntry FshlTbl[] = {
890 {Intrinsic::fshl, MVT::v4i32, 2}, // shl + usra
891 {Intrinsic::fshl, MVT::v2i64, 2}, {Intrinsic::fshl, MVT::v16i8, 2},
892 {Intrinsic::fshl, MVT::v8i16, 2}, {Intrinsic::fshl, MVT::v2i32, 2},
893 {Intrinsic::fshl, MVT::v8i8, 2}, {Intrinsic::fshl, MVT::v4i16, 2}};
894 // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
895 // to avoid having to duplicate the costs.
896 const auto *Entry =
897 CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
898 if (Entry)
899 return LegalisationCost.first * Entry->Cost;
900 }
901
902 auto TyL = getTypeLegalizationCost(RetTy);
903 if (!RetTy->isIntegerTy())
904 break;
905
906 // Estimate cost manually, as types like i8 and i16 will get promoted to
907 // i32 and CostTableLookup will ignore the extra conversion cost.
908 bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
909 RetTy->getScalarSizeInBits() < 64) ||
910 (RetTy->getScalarSizeInBits() % 64 != 0);
911 unsigned ExtraCost = HigherCost ? 1 : 0;
912 if (RetTy->getScalarSizeInBits() == 32 ||
913 RetTy->getScalarSizeInBits() == 64)
914 ExtraCost = 0; // fhsl/fshr for i32 and i64 can be lowered to a single
915 // extr instruction.
916 else if (HigherCost)
917 ExtraCost = 1;
918 else
919 break;
920 return TyL.first + ExtraCost;
921 }
922 case Intrinsic::get_active_lane_mask: {
923 auto *RetTy = dyn_cast<FixedVectorType>(ICA.getReturnType());
924 if (RetTy) {
925 EVT RetVT = getTLI()->getValueType(DL, RetTy);
926 EVT OpVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
927 if (!getTLI()->shouldExpandGetActiveLaneMask(RetVT, OpVT) &&
928 !getTLI()->isTypeLegal(RetVT)) {
929 // We don't have enough context at this point to determine if the mask
930 // is going to be kept live after the block, which will force the vXi1
931 // type to be expanded to legal vectors of integers, e.g. v4i1->v4i32.
932 // For now, we just assume the vectorizer created this intrinsic and
933 // the result will be the input for a PHI. In this case the cost will
934 // be extremely high for fixed-width vectors.
935 // NOTE: getScalarizationOverhead returns a cost that's far too
936 // pessimistic for the actual generated codegen. In reality there are
937 // two instructions generated per lane.
938 return RetTy->getNumElements() * 2;
939 }
940 }
941 break;
942 }
943 case Intrinsic::experimental_vector_match: {
944 auto *NeedleTy = cast<FixedVectorType>(ICA.getArgTypes()[1]);
945 EVT SearchVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
946 unsigned SearchSize = NeedleTy->getNumElements();
947 if (!getTLI()->shouldExpandVectorMatch(SearchVT, SearchSize)) {
948 // Base cost for MATCH instructions. At least on the Neoverse V2 and
949 // Neoverse V3, these are cheap operations with the same latency as a
950 // vector ADD. In most cases, however, we also need to do an extra DUP.
951 // For fixed-length vectors we currently need an extra five--six
952 // instructions besides the MATCH.
953 InstructionCost Cost = 4;
954 if (isa<FixedVectorType>(RetTy))
955 Cost += 10;
956 return Cost;
957 }
958 break;
959 }
960 case Intrinsic::experimental_cttz_elts: {
961 EVT ArgVT = getTLI()->getValueType(DL, ICA.getArgTypes()[0]);
962 if (!getTLI()->shouldExpandCttzElements(ArgVT)) {
963 // This will consist of a SVE brkb and a cntp instruction. These
964 // typically have the same latency and half the throughput as a vector
965 // add instruction.
966 return 4;
967 }
968 break;
969 }
970 default:
971 break;
972 }
973 return BaseT::getIntrinsicInstrCost(ICA, CostKind);
974 }
975
976 /// The function will remove redundant reinterprets casting in the presence
977 /// of the control flow
processPhiNode(InstCombiner & IC,IntrinsicInst & II)978 static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
979 IntrinsicInst &II) {
980 SmallVector<Instruction *, 32> Worklist;
981 auto RequiredType = II.getType();
982
983 auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
984 assert(PN && "Expected Phi Node!");
985
986 // Don't create a new Phi unless we can remove the old one.
987 if (!PN->hasOneUse())
988 return std::nullopt;
989
990 for (Value *IncValPhi : PN->incoming_values()) {
991 auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
992 if (!Reinterpret ||
993 Reinterpret->getIntrinsicID() !=
994 Intrinsic::aarch64_sve_convert_to_svbool ||
995 RequiredType != Reinterpret->getArgOperand(0)->getType())
996 return std::nullopt;
997 }
998
999 // Create the new Phi
1000 IC.Builder.SetInsertPoint(PN);
1001 PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
1002 Worklist.push_back(PN);
1003
1004 for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
1005 auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
1006 NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
1007 Worklist.push_back(Reinterpret);
1008 }
1009
1010 // Cleanup Phi Node and reinterprets
1011 return IC.replaceInstUsesWith(II, NPN);
1012 }
1013
1014 // A collection of properties common to SVE intrinsics that allow for combines
1015 // to be written without needing to know the specific intrinsic.
1016 struct SVEIntrinsicInfo {
1017 //
1018 // Helper routines for common intrinsic definitions.
1019 //
1020
1021 // e.g. llvm.aarch64.sve.add pg, op1, op2
1022 // with IID ==> llvm.aarch64.sve.add_u
1023 static SVEIntrinsicInfo
defaultMergingOpSVEIntrinsicInfo1024 defaultMergingOp(Intrinsic::ID IID = Intrinsic::not_intrinsic) {
1025 return SVEIntrinsicInfo()
1026 .setGoverningPredicateOperandIdx(0)
1027 .setOperandIdxInactiveLanesTakenFrom(1)
1028 .setMatchingUndefIntrinsic(IID);
1029 }
1030
1031 // e.g. llvm.aarch64.sve.neg inactive, pg, op
defaultMergingUnaryOpSVEIntrinsicInfo1032 static SVEIntrinsicInfo defaultMergingUnaryOp() {
1033 return SVEIntrinsicInfo()
1034 .setGoverningPredicateOperandIdx(1)
1035 .setOperandIdxInactiveLanesTakenFrom(0)
1036 .setOperandIdxWithNoActiveLanes(0);
1037 }
1038
1039 // e.g. llvm.aarch64.sve.fcvtnt inactive, pg, op
defaultMergingUnaryNarrowingTopOpSVEIntrinsicInfo1040 static SVEIntrinsicInfo defaultMergingUnaryNarrowingTopOp() {
1041 return SVEIntrinsicInfo()
1042 .setGoverningPredicateOperandIdx(1)
1043 .setOperandIdxInactiveLanesTakenFrom(0);
1044 }
1045
1046 // e.g. llvm.aarch64.sve.add_u pg, op1, op2
defaultUndefOpSVEIntrinsicInfo1047 static SVEIntrinsicInfo defaultUndefOp() {
1048 return SVEIntrinsicInfo()
1049 .setGoverningPredicateOperandIdx(0)
1050 .setInactiveLanesAreNotDefined();
1051 }
1052
1053 // e.g. llvm.aarch64.sve.prf pg, ptr (GPIndex = 0)
1054 // llvm.aarch64.sve.st1 data, pg, ptr (GPIndex = 1)
defaultVoidOpSVEIntrinsicInfo1055 static SVEIntrinsicInfo defaultVoidOp(unsigned GPIndex) {
1056 return SVEIntrinsicInfo()
1057 .setGoverningPredicateOperandIdx(GPIndex)
1058 .setInactiveLanesAreUnused();
1059 }
1060
1061 // e.g. llvm.aarch64.sve.cmpeq pg, op1, op2
1062 // llvm.aarch64.sve.ld1 pg, ptr
defaultZeroingOpSVEIntrinsicInfo1063 static SVEIntrinsicInfo defaultZeroingOp() {
1064 return SVEIntrinsicInfo()
1065 .setGoverningPredicateOperandIdx(0)
1066 .setInactiveLanesAreUnused()
1067 .setResultIsZeroInitialized();
1068 }
1069
1070 // All properties relate to predication and thus having a general predicate
1071 // is the minimum requirement to say there is intrinsic info to act on.
operator boolSVEIntrinsicInfo1072 explicit operator bool() const { return hasGoverningPredicate(); }
1073
1074 //
1075 // Properties relating to the governing predicate.
1076 //
1077
hasGoverningPredicateSVEIntrinsicInfo1078 bool hasGoverningPredicate() const {
1079 return GoverningPredicateIdx != std::numeric_limits<unsigned>::max();
1080 }
1081
getGoverningPredicateOperandIdxSVEIntrinsicInfo1082 unsigned getGoverningPredicateOperandIdx() const {
1083 assert(hasGoverningPredicate() && "Propery not set!");
1084 return GoverningPredicateIdx;
1085 }
1086
setGoverningPredicateOperandIdxSVEIntrinsicInfo1087 SVEIntrinsicInfo &setGoverningPredicateOperandIdx(unsigned Index) {
1088 assert(!hasGoverningPredicate() && "Cannot set property twice!");
1089 GoverningPredicateIdx = Index;
1090 return *this;
1091 }
1092
1093 //
1094 // Properties relating to operations the intrinsic could be transformed into.
1095 // NOTE: This does not mean such a transformation is always possible, but the
1096 // knowledge makes it possible to reuse existing optimisations without needing
1097 // to embed specific handling for each intrinsic. For example, instruction
1098 // simplification can be used to optimise an intrinsic's active lanes.
1099 //
1100
hasMatchingUndefIntrinsicSVEIntrinsicInfo1101 bool hasMatchingUndefIntrinsic() const {
1102 return UndefIntrinsic != Intrinsic::not_intrinsic;
1103 }
1104
getMatchingUndefIntrinsicSVEIntrinsicInfo1105 Intrinsic::ID getMatchingUndefIntrinsic() const {
1106 assert(hasMatchingUndefIntrinsic() && "Propery not set!");
1107 return UndefIntrinsic;
1108 }
1109
setMatchingUndefIntrinsicSVEIntrinsicInfo1110 SVEIntrinsicInfo &setMatchingUndefIntrinsic(Intrinsic::ID IID) {
1111 assert(!hasMatchingUndefIntrinsic() && "Cannot set property twice!");
1112 UndefIntrinsic = IID;
1113 return *this;
1114 }
1115
hasMatchingIROpodeSVEIntrinsicInfo1116 bool hasMatchingIROpode() const { return IROpcode != 0; }
1117
getMatchingIROpodeSVEIntrinsicInfo1118 unsigned getMatchingIROpode() const {
1119 assert(hasMatchingIROpode() && "Propery not set!");
1120 return IROpcode;
1121 }
1122
setMatchingIROpcodeSVEIntrinsicInfo1123 SVEIntrinsicInfo &setMatchingIROpcode(unsigned Opcode) {
1124 assert(!hasMatchingIROpode() && "Cannot set property twice!");
1125 IROpcode = Opcode;
1126 return *this;
1127 }
1128
1129 //
1130 // Properties relating to the result of inactive lanes.
1131 //
1132
inactiveLanesTakenFromOperandSVEIntrinsicInfo1133 bool inactiveLanesTakenFromOperand() const {
1134 return ResultLanes == InactiveLanesTakenFromOperand;
1135 }
1136
getOperandIdxInactiveLanesTakenFromSVEIntrinsicInfo1137 unsigned getOperandIdxInactiveLanesTakenFrom() const {
1138 assert(inactiveLanesTakenFromOperand() && "Propery not set!");
1139 return OperandIdxForInactiveLanes;
1140 }
1141
setOperandIdxInactiveLanesTakenFromSVEIntrinsicInfo1142 SVEIntrinsicInfo &setOperandIdxInactiveLanesTakenFrom(unsigned Index) {
1143 assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1144 ResultLanes = InactiveLanesTakenFromOperand;
1145 OperandIdxForInactiveLanes = Index;
1146 return *this;
1147 }
1148
inactiveLanesAreNotDefinedSVEIntrinsicInfo1149 bool inactiveLanesAreNotDefined() const {
1150 return ResultLanes == InactiveLanesAreNotDefined;
1151 }
1152
setInactiveLanesAreNotDefinedSVEIntrinsicInfo1153 SVEIntrinsicInfo &setInactiveLanesAreNotDefined() {
1154 assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1155 ResultLanes = InactiveLanesAreNotDefined;
1156 return *this;
1157 }
1158
inactiveLanesAreUnusedSVEIntrinsicInfo1159 bool inactiveLanesAreUnused() const {
1160 return ResultLanes == InactiveLanesAreUnused;
1161 }
1162
setInactiveLanesAreUnusedSVEIntrinsicInfo1163 SVEIntrinsicInfo &setInactiveLanesAreUnused() {
1164 assert(ResultLanes == Uninitialized && "Cannot set property twice!");
1165 ResultLanes = InactiveLanesAreUnused;
1166 return *this;
1167 }
1168
1169 // NOTE: Whilst not limited to only inactive lanes, the common use case is:
1170 // inactiveLanesAreZeroed =
1171 // resultIsZeroInitialized() && inactiveLanesAreUnused()
resultIsZeroInitializedSVEIntrinsicInfo1172 bool resultIsZeroInitialized() const { return ResultIsZeroInitialized; }
1173
setResultIsZeroInitializedSVEIntrinsicInfo1174 SVEIntrinsicInfo &setResultIsZeroInitialized() {
1175 ResultIsZeroInitialized = true;
1176 return *this;
1177 }
1178
1179 //
1180 // The first operand of unary merging operations is typically only used to
1181 // set the result for inactive lanes. Knowing this allows us to deadcode the
1182 // operand when we can prove there are no inactive lanes.
1183 //
1184
hasOperandWithNoActiveLanesSVEIntrinsicInfo1185 bool hasOperandWithNoActiveLanes() const {
1186 return OperandIdxWithNoActiveLanes != std::numeric_limits<unsigned>::max();
1187 }
1188
getOperandIdxWithNoActiveLanesSVEIntrinsicInfo1189 unsigned getOperandIdxWithNoActiveLanes() const {
1190 assert(hasOperandWithNoActiveLanes() && "Propery not set!");
1191 return OperandIdxWithNoActiveLanes;
1192 }
1193
setOperandIdxWithNoActiveLanesSVEIntrinsicInfo1194 SVEIntrinsicInfo &setOperandIdxWithNoActiveLanes(unsigned Index) {
1195 assert(!hasOperandWithNoActiveLanes() && "Cannot set property twice!");
1196 OperandIdxWithNoActiveLanes = Index;
1197 return *this;
1198 }
1199
1200 private:
1201 unsigned GoverningPredicateIdx = std::numeric_limits<unsigned>::max();
1202
1203 Intrinsic::ID UndefIntrinsic = Intrinsic::not_intrinsic;
1204 unsigned IROpcode = 0;
1205
1206 enum PredicationStyle {
1207 Uninitialized,
1208 InactiveLanesTakenFromOperand,
1209 InactiveLanesAreNotDefined,
1210 InactiveLanesAreUnused
1211 } ResultLanes = Uninitialized;
1212
1213 bool ResultIsZeroInitialized = false;
1214 unsigned OperandIdxForInactiveLanes = std::numeric_limits<unsigned>::max();
1215 unsigned OperandIdxWithNoActiveLanes = std::numeric_limits<unsigned>::max();
1216 };
1217
// Classify \p II, returning an SVEIntrinsicInfo describing its predication
// behaviour. Intrinsics not handled by the switch (or not operating on
// scalable vectors) yield a default-constructed, falsey SVEIntrinsicInfo.
static SVEIntrinsicInfo constructSVEIntrinsicInfo(IntrinsicInst &II) {
  // Some SVE intrinsics do not use scalable vector types, but since they are
  // not relevant from an SVEIntrinsicInfo perspective, they are also ignored.
  if (!isa<ScalableVectorType>(II.getType()) &&
      all_of(II.args(), [&](const Value *V) {
        return !isa<ScalableVectorType>(V->getType());
      }))
    return SVEIntrinsicInfo();

  Intrinsic::ID IID = II.getIntrinsicID();
  switch (IID) {
  default:
    break;
  // Unary conversions: inactive lanes come from the passthru operand.
  case Intrinsic::aarch64_sve_fcvt_bf16f32_v2:
  case Intrinsic::aarch64_sve_fcvt_f16f32:
  case Intrinsic::aarch64_sve_fcvt_f16f64:
  case Intrinsic::aarch64_sve_fcvt_f32f16:
  case Intrinsic::aarch64_sve_fcvt_f32f64:
  case Intrinsic::aarch64_sve_fcvt_f64f16:
  case Intrinsic::aarch64_sve_fcvt_f64f32:
  case Intrinsic::aarch64_sve_fcvtlt_f32f16:
  case Intrinsic::aarch64_sve_fcvtlt_f64f32:
  case Intrinsic::aarch64_sve_fcvtx_f32f64:
  case Intrinsic::aarch64_sve_fcvtzs:
  case Intrinsic::aarch64_sve_fcvtzs_i32f16:
  case Intrinsic::aarch64_sve_fcvtzs_i32f64:
  case Intrinsic::aarch64_sve_fcvtzs_i64f16:
  case Intrinsic::aarch64_sve_fcvtzs_i64f32:
  case Intrinsic::aarch64_sve_fcvtzu:
  case Intrinsic::aarch64_sve_fcvtzu_i32f16:
  case Intrinsic::aarch64_sve_fcvtzu_i32f64:
  case Intrinsic::aarch64_sve_fcvtzu_i64f16:
  case Intrinsic::aarch64_sve_fcvtzu_i64f32:
  case Intrinsic::aarch64_sve_scvtf:
  case Intrinsic::aarch64_sve_scvtf_f16i32:
  case Intrinsic::aarch64_sve_scvtf_f16i64:
  case Intrinsic::aarch64_sve_scvtf_f32i64:
  case Intrinsic::aarch64_sve_scvtf_f64i32:
  case Intrinsic::aarch64_sve_ucvtf:
  case Intrinsic::aarch64_sve_ucvtf_f16i32:
  case Intrinsic::aarch64_sve_ucvtf_f16i64:
  case Intrinsic::aarch64_sve_ucvtf_f32i64:
  case Intrinsic::aarch64_sve_ucvtf_f64i32:
    return SVEIntrinsicInfo::defaultMergingUnaryOp();

  // Narrowing-top conversions: like unary merging but without a deadcode-able
  // passthru operand.
  case Intrinsic::aarch64_sve_fcvtnt_bf16f32_v2:
  case Intrinsic::aarch64_sve_fcvtnt_f16f32:
  case Intrinsic::aarch64_sve_fcvtnt_f32f64:
  case Intrinsic::aarch64_sve_fcvtxnt_f32f64:
    return SVEIntrinsicInfo::defaultMergingUnaryNarrowingTopOp();

  // Merging binary/ternary ops, each mapped to its _u (undef) sibling and,
  // where one exists, the equivalent IR opcode.
  case Intrinsic::aarch64_sve_fabd:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fabd_u);
  case Intrinsic::aarch64_sve_fadd:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fadd_u)
        .setMatchingIROpcode(Instruction::FAdd);
  case Intrinsic::aarch64_sve_fdiv:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fdiv_u)
        .setMatchingIROpcode(Instruction::FDiv);
  case Intrinsic::aarch64_sve_fmax:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmax_u);
  case Intrinsic::aarch64_sve_fmaxnm:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmaxnm_u);
  case Intrinsic::aarch64_sve_fmin:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmin_u);
  case Intrinsic::aarch64_sve_fminnm:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fminnm_u);
  case Intrinsic::aarch64_sve_fmla:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmla_u);
  case Intrinsic::aarch64_sve_fmls:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmls_u);
  case Intrinsic::aarch64_sve_fmul:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmul_u)
        .setMatchingIROpcode(Instruction::FMul);
  case Intrinsic::aarch64_sve_fmulx:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fmulx_u);
  case Intrinsic::aarch64_sve_fnmla:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fnmla_u);
  case Intrinsic::aarch64_sve_fnmls:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fnmls_u);
  case Intrinsic::aarch64_sve_fsub:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_fsub_u)
        .setMatchingIROpcode(Instruction::FSub);
  case Intrinsic::aarch64_sve_add:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_add_u)
        .setMatchingIROpcode(Instruction::Add);
  case Intrinsic::aarch64_sve_mla:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mla_u);
  case Intrinsic::aarch64_sve_mls:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mls_u);
  case Intrinsic::aarch64_sve_mul:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_mul_u)
        .setMatchingIROpcode(Instruction::Mul);
  case Intrinsic::aarch64_sve_sabd:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sabd_u);
  case Intrinsic::aarch64_sve_sdiv:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sdiv_u)
        .setMatchingIROpcode(Instruction::SDiv);
  case Intrinsic::aarch64_sve_smax:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smax_u);
  case Intrinsic::aarch64_sve_smin:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smin_u);
  case Intrinsic::aarch64_sve_smulh:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_smulh_u);
  case Intrinsic::aarch64_sve_sub:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sub_u)
        .setMatchingIROpcode(Instruction::Sub);
  case Intrinsic::aarch64_sve_uabd:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uabd_u);
  case Intrinsic::aarch64_sve_udiv:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_udiv_u)
        .setMatchingIROpcode(Instruction::UDiv);
  case Intrinsic::aarch64_sve_umax:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umax_u);
  case Intrinsic::aarch64_sve_umin:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umin_u);
  case Intrinsic::aarch64_sve_umulh:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_umulh_u);
  case Intrinsic::aarch64_sve_asr:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_asr_u)
        .setMatchingIROpcode(Instruction::AShr);
  case Intrinsic::aarch64_sve_lsl:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_lsl_u)
        .setMatchingIROpcode(Instruction::Shl);
  case Intrinsic::aarch64_sve_lsr:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_lsr_u)
        .setMatchingIROpcode(Instruction::LShr);
  case Intrinsic::aarch64_sve_and:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_and_u)
        .setMatchingIROpcode(Instruction::And);
  case Intrinsic::aarch64_sve_bic:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_bic_u);
  case Intrinsic::aarch64_sve_eor:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_eor_u)
        .setMatchingIROpcode(Instruction::Xor);
  case Intrinsic::aarch64_sve_orr:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_orr_u)
        .setMatchingIROpcode(Instruction::Or);
  case Intrinsic::aarch64_sve_sqsub:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_sqsub_u);
  case Intrinsic::aarch64_sve_uqsub:
    return SVEIntrinsicInfo::defaultMergingOp(Intrinsic::aarch64_sve_uqsub_u);

  // _u (undef-on-inactive) ops that have a direct IR opcode equivalent.
  case Intrinsic::aarch64_sve_add_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Add);
  case Intrinsic::aarch64_sve_and_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::And);
  case Intrinsic::aarch64_sve_asr_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::AShr);
  case Intrinsic::aarch64_sve_eor_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Xor);
  case Intrinsic::aarch64_sve_fadd_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::FAdd);
  case Intrinsic::aarch64_sve_fdiv_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::FDiv);
  case Intrinsic::aarch64_sve_fmul_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::FMul);
  case Intrinsic::aarch64_sve_fsub_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::FSub);
  case Intrinsic::aarch64_sve_lsl_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Shl);
  case Intrinsic::aarch64_sve_lsr_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::LShr);
  case Intrinsic::aarch64_sve_mul_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Mul);
  case Intrinsic::aarch64_sve_orr_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Or);
  case Intrinsic::aarch64_sve_sdiv_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::SDiv);
  case Intrinsic::aarch64_sve_sub_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::Sub);
  case Intrinsic::aarch64_sve_udiv_u:
    return SVEIntrinsicInfo::defaultUndefOp().setMatchingIROpcode(
        Instruction::UDiv);

  // Zeroing ops: predicate at operand 0, inactive lanes unused, result
  // zero-initialized (reductions, compares, loads, etc.).
  case Intrinsic::aarch64_sve_addqv:
  case Intrinsic::aarch64_sve_and_z:
  case Intrinsic::aarch64_sve_bic_z:
  case Intrinsic::aarch64_sve_brka_z:
  case Intrinsic::aarch64_sve_brkb_z:
  case Intrinsic::aarch64_sve_brkn_z:
  case Intrinsic::aarch64_sve_brkpa_z:
  case Intrinsic::aarch64_sve_brkpb_z:
  case Intrinsic::aarch64_sve_cntp:
  case Intrinsic::aarch64_sve_compact:
  case Intrinsic::aarch64_sve_eor_z:
  case Intrinsic::aarch64_sve_eorv:
  case Intrinsic::aarch64_sve_eorqv:
  case Intrinsic::aarch64_sve_nand_z:
  case Intrinsic::aarch64_sve_nor_z:
  case Intrinsic::aarch64_sve_orn_z:
  case Intrinsic::aarch64_sve_orr_z:
  case Intrinsic::aarch64_sve_orv:
  case Intrinsic::aarch64_sve_orqv:
  case Intrinsic::aarch64_sve_pnext:
  case Intrinsic::aarch64_sve_rdffr_z:
  case Intrinsic::aarch64_sve_saddv:
  case Intrinsic::aarch64_sve_uaddv:
  case Intrinsic::aarch64_sve_umaxv:
  case Intrinsic::aarch64_sve_umaxqv:
  case Intrinsic::aarch64_sve_cmpeq:
  case Intrinsic::aarch64_sve_cmpeq_wide:
  case Intrinsic::aarch64_sve_cmpge:
  case Intrinsic::aarch64_sve_cmpge_wide:
  case Intrinsic::aarch64_sve_cmpgt:
  case Intrinsic::aarch64_sve_cmpgt_wide:
  case Intrinsic::aarch64_sve_cmphi:
  case Intrinsic::aarch64_sve_cmphi_wide:
  case Intrinsic::aarch64_sve_cmphs:
  case Intrinsic::aarch64_sve_cmphs_wide:
  case Intrinsic::aarch64_sve_cmple_wide:
  case Intrinsic::aarch64_sve_cmplo_wide:
  case Intrinsic::aarch64_sve_cmpls_wide:
  case Intrinsic::aarch64_sve_cmplt_wide:
  case Intrinsic::aarch64_sve_cmpne:
  case Intrinsic::aarch64_sve_cmpne_wide:
  case Intrinsic::aarch64_sve_facge:
  case Intrinsic::aarch64_sve_facgt:
  case Intrinsic::aarch64_sve_fcmpeq:
  case Intrinsic::aarch64_sve_fcmpge:
  case Intrinsic::aarch64_sve_fcmpgt:
  case Intrinsic::aarch64_sve_fcmpne:
  case Intrinsic::aarch64_sve_fcmpuo:
  case Intrinsic::aarch64_sve_ld1:
  case Intrinsic::aarch64_sve_ld1_gather:
  case Intrinsic::aarch64_sve_ld1_gather_index:
  case Intrinsic::aarch64_sve_ld1_gather_scalar_offset:
  case Intrinsic::aarch64_sve_ld1_gather_sxtw:
  case Intrinsic::aarch64_sve_ld1_gather_sxtw_index:
  case Intrinsic::aarch64_sve_ld1_gather_uxtw:
  case Intrinsic::aarch64_sve_ld1_gather_uxtw_index:
  case Intrinsic::aarch64_sve_ld1q_gather_index:
  case Intrinsic::aarch64_sve_ld1q_gather_scalar_offset:
  case Intrinsic::aarch64_sve_ld1q_gather_vector_offset:
  case Intrinsic::aarch64_sve_ld1ro:
  case Intrinsic::aarch64_sve_ld1rq:
  case Intrinsic::aarch64_sve_ld1udq:
  case Intrinsic::aarch64_sve_ld1uwq:
  case Intrinsic::aarch64_sve_ld2_sret:
  case Intrinsic::aarch64_sve_ld2q_sret:
  case Intrinsic::aarch64_sve_ld3_sret:
  case Intrinsic::aarch64_sve_ld3q_sret:
  case Intrinsic::aarch64_sve_ld4_sret:
  case Intrinsic::aarch64_sve_ld4q_sret:
  case Intrinsic::aarch64_sve_ldff1:
  case Intrinsic::aarch64_sve_ldff1_gather:
  case Intrinsic::aarch64_sve_ldff1_gather_index:
  case Intrinsic::aarch64_sve_ldff1_gather_scalar_offset:
  case Intrinsic::aarch64_sve_ldff1_gather_sxtw:
  case Intrinsic::aarch64_sve_ldff1_gather_sxtw_index:
  case Intrinsic::aarch64_sve_ldff1_gather_uxtw:
  case Intrinsic::aarch64_sve_ldff1_gather_uxtw_index:
  case Intrinsic::aarch64_sve_ldnf1:
  case Intrinsic::aarch64_sve_ldnt1:
  case Intrinsic::aarch64_sve_ldnt1_gather:
  case Intrinsic::aarch64_sve_ldnt1_gather_index:
  case Intrinsic::aarch64_sve_ldnt1_gather_scalar_offset:
  case Intrinsic::aarch64_sve_ldnt1_gather_uxtw:
    return SVEIntrinsicInfo::defaultZeroingOp();

  // Void prefetches: predicate is operand 0.
  case Intrinsic::aarch64_sve_prf:
  case Intrinsic::aarch64_sve_prfb_gather_index:
  case Intrinsic::aarch64_sve_prfb_gather_scalar_offset:
  case Intrinsic::aarch64_sve_prfb_gather_sxtw_index:
  case Intrinsic::aarch64_sve_prfb_gather_uxtw_index:
  case Intrinsic::aarch64_sve_prfd_gather_index:
  case Intrinsic::aarch64_sve_prfd_gather_scalar_offset:
  case Intrinsic::aarch64_sve_prfd_gather_sxtw_index:
  case Intrinsic::aarch64_sve_prfd_gather_uxtw_index:
  case Intrinsic::aarch64_sve_prfh_gather_index:
  case Intrinsic::aarch64_sve_prfh_gather_scalar_offset:
  case Intrinsic::aarch64_sve_prfh_gather_sxtw_index:
  case Intrinsic::aarch64_sve_prfh_gather_uxtw_index:
  case Intrinsic::aarch64_sve_prfw_gather_index:
  case Intrinsic::aarch64_sve_prfw_gather_scalar_offset:
  case Intrinsic::aarch64_sve_prfw_gather_sxtw_index:
  case Intrinsic::aarch64_sve_prfw_gather_uxtw_index:
    return SVEIntrinsicInfo::defaultVoidOp(0);

  // Void stores: the predicate follows the data operand(s), so its index
  // grows with the number of data vectors (st1=1, st2=2, st3=3, st4=4).
  case Intrinsic::aarch64_sve_st1_scatter:
  case Intrinsic::aarch64_sve_st1_scatter_scalar_offset:
  case Intrinsic::aarch64_sve_st1_scatter_sxtw:
  case Intrinsic::aarch64_sve_st1_scatter_sxtw_index:
  case Intrinsic::aarch64_sve_st1_scatter_uxtw:
  case Intrinsic::aarch64_sve_st1_scatter_uxtw_index:
  case Intrinsic::aarch64_sve_st1dq:
  case Intrinsic::aarch64_sve_st1q_scatter_index:
  case Intrinsic::aarch64_sve_st1q_scatter_scalar_offset:
  case Intrinsic::aarch64_sve_st1q_scatter_vector_offset:
  case Intrinsic::aarch64_sve_st1wq:
  case Intrinsic::aarch64_sve_stnt1:
  case Intrinsic::aarch64_sve_stnt1_scatter:
  case Intrinsic::aarch64_sve_stnt1_scatter_index:
  case Intrinsic::aarch64_sve_stnt1_scatter_scalar_offset:
  case Intrinsic::aarch64_sve_stnt1_scatter_uxtw:
    return SVEIntrinsicInfo::defaultVoidOp(1);
  case Intrinsic::aarch64_sve_st2:
  case Intrinsic::aarch64_sve_st2q:
    return SVEIntrinsicInfo::defaultVoidOp(2);
  case Intrinsic::aarch64_sve_st3:
  case Intrinsic::aarch64_sve_st3q:
    return SVEIntrinsicInfo::defaultVoidOp(3);
  case Intrinsic::aarch64_sve_st4:
  case Intrinsic::aarch64_sve_st4q:
    return SVEIntrinsicInfo::defaultVoidOp(4);
  }

  // Unrecognised intrinsic: no predication info to act on.
  return SVEIntrinsicInfo();
}
1541
isAllActivePredicate(Value * Pred)1542 static bool isAllActivePredicate(Value *Pred) {
1543 // Look through convert.from.svbool(convert.to.svbool(...) chain.
1544 Value *UncastedPred;
1545 if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
1546 m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
1547 m_Value(UncastedPred)))))
1548 // If the predicate has the same or less lanes than the uncasted
1549 // predicate then we know the casting has no effect.
1550 if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
1551 cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
1552 Pred = UncastedPred;
1553 auto *C = dyn_cast<Constant>(Pred);
1554 return (C && C->isAllOnesValue());
1555 }
1556
1557 // Simplify `V` by only considering the operations that affect active lanes.
1558 // This function should only return existing Values or newly created Constants.
stripInactiveLanes(Value * V,const Value * Pg)1559 static Value *stripInactiveLanes(Value *V, const Value *Pg) {
1560 auto *Dup = dyn_cast<IntrinsicInst>(V);
1561 if (Dup && Dup->getIntrinsicID() == Intrinsic::aarch64_sve_dup &&
1562 Dup->getOperand(1) == Pg && isa<Constant>(Dup->getOperand(2)))
1563 return ConstantVector::getSplat(
1564 cast<VectorType>(V->getType())->getElementCount(),
1565 cast<Constant>(Dup->getOperand(2)));
1566
1567 return V;
1568 }
1569
// Simplify a predicated SVE intrinsic that has a matching IR binary opcode
// (per IInfo): canonicalise constant operands to the RHS, and fold the
// operation via simplifyBinOp while preserving the intrinsic's defined result
// and its inactive-lane semantics.
static std::optional<Instruction *>
simplifySVEIntrinsicBinOp(InstCombiner &IC, IntrinsicInst &II,
                          const SVEIntrinsicInfo &IInfo) {
  const unsigned Opc = IInfo.getMatchingIROpode();
  assert(Instruction::isBinaryOp(Opc) && "Expected a binary operation!");

  Value *Pg = II.getOperand(0);
  Value *Op1 = II.getOperand(1);
  Value *Op2 = II.getOperand(2);
  const DataLayout &DL = II.getDataLayout();

  // Canonicalise constants to the RHS. Only safe when inactive lanes are not
  // defined, since swapping operands would otherwise change which operand the
  // inactive lanes are taken from.
  if (Instruction::isCommutative(Opc) && IInfo.inactiveLanesAreNotDefined() &&
      isa<Constant>(Op1) && !isa<Constant>(Op2)) {
    IC.replaceOperand(II, 1, Op2);
    IC.replaceOperand(II, 2, Op1);
    return &II;
  }

  // Only active lanes matter when simplifying the operation.
  Op1 = stripInactiveLanes(Op1, Pg);
  Op2 = stripInactiveLanes(Op2, Pg);

  Value *SimpleII;
  if (auto FII = dyn_cast<FPMathOperator>(&II))
    SimpleII = simplifyBinOp(Opc, Op1, Op2, FII->getFastMathFlags(), DL);
  else
    SimpleII = simplifyBinOp(Opc, Op1, Op2, DL);

  // An SVE intrinsic's result is always defined. However, this is not the case
  // for its equivalent IR instruction (e.g. when shifting by an amount more
  // than the data's bitwidth). Simplifications to an undefined result must be
  // ignored to preserve the intrinsic's expected behaviour.
  if (!SimpleII || isa<UndefValue>(SimpleII))
    return std::nullopt;

  // Inactive lanes are undefined, so the folded value can be used as-is.
  if (IInfo.inactiveLanesAreNotDefined())
    return IC.replaceInstUsesWith(II, SimpleII);

  Value *Inactive = II.getOperand(IInfo.getOperandIdxInactiveLanesTakenFrom());

  // The intrinsic does nothing (e.g. sve.mul(pg, A, 1.0)).
  if (SimpleII == Inactive)
    return IC.replaceInstUsesWith(II, SimpleII);

  // Inactive lanes must be preserved.
  SimpleII = IC.Builder.CreateSelect(Pg, SimpleII, Inactive);
  return IC.replaceInstUsesWith(II, SimpleII);
}
1619
1620 // Use SVE intrinsic info to eliminate redundant operands and/or canonicalise
1621 // to operations with less strict inactive lane requirements.
1622 static std::optional<Instruction *>
simplifySVEIntrinsic(InstCombiner & IC,IntrinsicInst & II,const SVEIntrinsicInfo & IInfo)1623 simplifySVEIntrinsic(InstCombiner &IC, IntrinsicInst &II,
1624 const SVEIntrinsicInfo &IInfo) {
1625 if (!IInfo.hasGoverningPredicate())
1626 return std::nullopt;
1627
1628 auto *OpPredicate = II.getOperand(IInfo.getGoverningPredicateOperandIdx());
1629
1630 // If there are no active lanes.
1631 if (match(OpPredicate, m_ZeroInt())) {
1632 if (IInfo.inactiveLanesTakenFromOperand())
1633 return IC.replaceInstUsesWith(
1634 II, II.getOperand(IInfo.getOperandIdxInactiveLanesTakenFrom()));
1635
1636 if (IInfo.inactiveLanesAreUnused()) {
1637 if (IInfo.resultIsZeroInitialized())
1638 IC.replaceInstUsesWith(II, Constant::getNullValue(II.getType()));
1639
1640 return IC.eraseInstFromFunction(II);
1641 }
1642 }
1643
1644 // If there are no inactive lanes.
1645 if (isAllActivePredicate(OpPredicate)) {
1646 if (IInfo.hasOperandWithNoActiveLanes()) {
1647 unsigned OpIdx = IInfo.getOperandIdxWithNoActiveLanes();
1648 if (!isa<UndefValue>(II.getOperand(OpIdx)))
1649 return IC.replaceOperand(II, OpIdx, UndefValue::get(II.getType()));
1650 }
1651
1652 if (IInfo.hasMatchingUndefIntrinsic()) {
1653 auto *NewDecl = Intrinsic::getOrInsertDeclaration(
1654 II.getModule(), IInfo.getMatchingUndefIntrinsic(), {II.getType()});
1655 II.setCalledFunction(NewDecl);
1656 return &II;
1657 }
1658 }
1659
1660 // Operation specific simplifications.
1661 if (IInfo.hasMatchingIROpode() &&
1662 Instruction::isBinaryOp(IInfo.getMatchingIROpode()))
1663 return simplifySVEIntrinsicBinOp(IC, II, IInfo);
1664
1665 return std::nullopt;
1666 }
1667
1668 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _))))
1669 // => (binop (pred) (from_svbool _) (from_svbool _))
1670 //
1671 // The above transformation eliminates a `to_svbool` in the predicate
1672 // operand of bitwise operation `binop` by narrowing the vector width of
1673 // the operation. For example, it would convert a `<vscale x 16 x i1>
1674 // and` into a `<vscale x 4 x i1> and`. This is profitable because
1675 // to_svbool must zero the new lanes during widening, whereas
1676 // from_svbool is free.
1677 static std::optional<Instruction *>
tryCombineFromSVBoolBinOp(InstCombiner & IC,IntrinsicInst & II)1678 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
1679 auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
1680 if (!BinOp)
1681 return std::nullopt;
1682
1683 auto IntrinsicID = BinOp->getIntrinsicID();
1684 switch (IntrinsicID) {
1685 case Intrinsic::aarch64_sve_and_z:
1686 case Intrinsic::aarch64_sve_bic_z:
1687 case Intrinsic::aarch64_sve_eor_z:
1688 case Intrinsic::aarch64_sve_nand_z:
1689 case Intrinsic::aarch64_sve_nor_z:
1690 case Intrinsic::aarch64_sve_orn_z:
1691 case Intrinsic::aarch64_sve_orr_z:
1692 break;
1693 default:
1694 return std::nullopt;
1695 }
1696
1697 auto BinOpPred = BinOp->getOperand(0);
1698 auto BinOpOp1 = BinOp->getOperand(1);
1699 auto BinOpOp2 = BinOp->getOperand(2);
1700
1701 auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
1702 if (!PredIntr ||
1703 PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
1704 return std::nullopt;
1705
1706 auto PredOp = PredIntr->getOperand(0);
1707 auto PredOpTy = cast<VectorType>(PredOp->getType());
1708 if (PredOpTy != II.getType())
1709 return std::nullopt;
1710
1711 SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
1712 auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
1713 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
1714 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1715 if (BinOpOp1 == BinOpOp2)
1716 NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
1717 else
1718 NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
1719 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));
1720
1721 auto NarrowedBinOp =
1722 IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
1723 return IC.replaceInstUsesWith(II, NarrowedBinOp);
1724 }
1725
// Combine convert.from.svbool(X) by walking the chain of to/from-svbool
// conversions feeding it, looking for an earlier value of the result type
// that is provably equivalent to the final conversion.
static std::optional<Instruction *>
instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
  // If the reinterpret instruction operand is a PHI Node
  if (isa<PHINode>(II.getArgOperand(0)))
    return processPhiNode(IC, II);

  if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
    return BinOpCombine;

  // Ignore converts to/from svcount_t.
  if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
      isa<TargetExtType>(II.getType()))
    return std::nullopt;

  // NOTE(review): CandidatesForRemoval is collected below but not consumed in
  // this function; removal of dead conversions is left to later combines.
  SmallVector<Instruction *, 32> CandidatesForRemoval;
  Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;

  const auto *IVTy = cast<VectorType>(II.getType());

  // Walk the chain of conversions.
  while (Cursor) {
    // If the type of the cursor has fewer lanes than the final result, zeroing
    // must take place, which breaks the equivalence chain.
    const auto *CursorVTy = cast<VectorType>(Cursor->getType());
    if (CursorVTy->getElementCount().getKnownMinValue() <
        IVTy->getElementCount().getKnownMinValue())
      break;

    // If the cursor has the same type as I, it is a viable replacement.
    if (Cursor->getType() == IVTy)
      EarliestReplacement = Cursor;

    auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);

    // If this is not an SVE conversion intrinsic, this is the end of the chain.
    if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_to_svbool ||
                              IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_from_svbool))
      break;

    CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
    Cursor = IntrinsicCursor->getOperand(0);
  }

  // If no viable replacement in the conversion chain was found, there is
  // nothing to do.
  if (!EarliestReplacement)
    return std::nullopt;

  return IC.replaceInstUsesWith(II, EarliestReplacement);
}
1778
instCombineSVESel(InstCombiner & IC,IntrinsicInst & II)1779 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
1780 IntrinsicInst &II) {
1781 // svsel(ptrue, x, y) => x
1782 auto *OpPredicate = II.getOperand(0);
1783 if (isAllActivePredicate(OpPredicate))
1784 return IC.replaceInstUsesWith(II, II.getOperand(1));
1785
1786 auto Select =
1787 IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
1788 return IC.replaceInstUsesWith(II, Select);
1789 }
1790
instCombineSVEDup(InstCombiner & IC,IntrinsicInst & II)1791 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
1792 IntrinsicInst &II) {
1793 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
1794 if (!Pg)
1795 return std::nullopt;
1796
1797 if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
1798 return std::nullopt;
1799
1800 const auto PTruePattern =
1801 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
1802 if (PTruePattern != AArch64SVEPredPattern::vl1)
1803 return std::nullopt;
1804
1805 // The intrinsic is inserting into lane zero so use an insert instead.
1806 auto *IdxTy = Type::getInt64Ty(II.getContext());
1807 auto *Insert = InsertElementInst::Create(
1808 II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
1809 Insert->insertBefore(II.getIterator());
1810 Insert->takeName(&II);
1811
1812 return IC.replaceInstUsesWith(II, Insert);
1813 }
1814
instCombineSVEDupX(InstCombiner & IC,IntrinsicInst & II)1815 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
1816 IntrinsicInst &II) {
1817 // Replace DupX with a regular IR splat.
1818 auto *RetTy = cast<ScalableVectorType>(II.getType());
1819 Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
1820 II.getArgOperand(0));
1821 Splat->takeName(&II);
1822 return IC.replaceInstUsesWith(II, Splat);
1823 }
1824
// Fold sve.cmpne(all-active, dupq_lane(vector.insert(undef, const-vec, 0), 0),
// splat(0)) into a predicate built from ptrue when the constant's non-zero
// pattern corresponds to a supported predicate element size.
static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();

  // The fold is only attempted when every lane participates in the compare.
  if (!isAllActivePredicate(II.getArgOperand(0)))
    return std::nullopt;

  // Check that we have a compare of zero..
  auto *SplatValue =
      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
  if (!SplatValue || !SplatValue->isZero())
    return std::nullopt;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return std::nullopt;

  // Where the dupq is a lane 0 replicate of a vector insert
  auto *DupQLaneIdx = dyn_cast<ConstantInt>(DupQLane->getArgOperand(1));
  if (!DupQLaneIdx || !DupQLaneIdx->isZero())
    return std::nullopt;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
    return std::nullopt;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return std::nullopt;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return std::nullopt;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return std::nullopt;

  // The fixed vector's length must match the scalable result's minimum
  // element count so each constant element maps to one predicate lane.
  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return std::nullopt;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return std::nullopt;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  // Lowest set bit of Mask gives the predicate element size in bytes.
  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));

  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return std::nullopt;

  // Build ptrue at the derived granularity and reinterpret back to the
  // compare's result type via svbool.
  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                           {PredType}, {PTruePat});
  auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                 {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}
1917
// Simplify lasta/lastb: fold extraction from splats, distribute the
// extraction over single-use binary ops with a splat operand, and convert
// extractions at a provably-constant index into plain extractelement.
static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  auto IntrinsicID = II.getIntrinsicID();
  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;

  // lastX(splat(X)) --> X
  if (auto *SplatVal = getSplatValue(Vec))
    return IC.replaceInstUsesWith(II, SplatVal);

  // If x and/or y is a splat value then:
  // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
  Value *LHS, *RHS;
  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
    if (isSplatValue(LHS) || isSplatValue(RHS)) {
      auto *OldBinOp = cast<BinaryOperator>(Vec);
      auto OpC = OldBinOp->getOpcode();
      auto *NewLHS =
          IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
      auto *NewRHS =
          IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
      // Preserve the original binop's flags (e.g. FMF) on the new one.
      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), II.getIterator());
      return IC.replaceInstUsesWith(II, NewBinOp);
    }
  }

  // lasta with an all-false predicate reads lane 0.
  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(II.getIterator());
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return std::nullopt;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
  if (!MinNumElts)
    return std::nullopt;

  unsigned Idx = MinNumElts - 1;
  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return std::nullopt;

  // The intrinsic is extracting a fixed lane so use an extract instead.
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(II.getIterator());
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}
1991
instCombineSVECondLast(InstCombiner & IC,IntrinsicInst & II)1992 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
1993 IntrinsicInst &II) {
1994 // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
1995 // integer variant across a variety of micro-architectures. Replace scalar
1996 // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
1997 // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
1998 // depending on the micro-architecture, but has been observed as generally
1999 // being faster, particularly when the CLAST[AB] op is a loop-carried
2000 // dependency.
2001 Value *Pg = II.getArgOperand(0);
2002 Value *Fallback = II.getArgOperand(1);
2003 Value *Vec = II.getArgOperand(2);
2004 Type *Ty = II.getType();
2005
2006 if (!Ty->isIntegerTy())
2007 return std::nullopt;
2008
2009 Type *FPTy;
2010 switch (cast<IntegerType>(Ty)->getBitWidth()) {
2011 default:
2012 return std::nullopt;
2013 case 16:
2014 FPTy = IC.Builder.getHalfTy();
2015 break;
2016 case 32:
2017 FPTy = IC.Builder.getFloatTy();
2018 break;
2019 case 64:
2020 FPTy = IC.Builder.getDoubleTy();
2021 break;
2022 }
2023
2024 Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
2025 auto *FPVTy = VectorType::get(
2026 FPTy, cast<VectorType>(Vec->getType())->getElementCount());
2027 Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
2028 auto *FPII = IC.Builder.CreateIntrinsic(
2029 II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
2030 Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
2031 return IC.replaceInstUsesWith(II, FPIItoInt);
2032 }
2033
instCombineRDFFR(InstCombiner & IC,IntrinsicInst & II)2034 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
2035 IntrinsicInst &II) {
2036 LLVMContext &Ctx = II.getContext();
2037 // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
2038 // can work with RDFFR_PP for ptest elimination.
2039 auto *AllPat =
2040 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
2041 auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
2042 {II.getType()}, {AllPat});
2043 auto *RDFFR =
2044 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {PTrue});
2045 RDFFR->takeName(&II);
2046 return IC.replaceInstUsesWith(II, RDFFR);
2047 }
2048
2049 static std::optional<Instruction *>
instCombineSVECntElts(InstCombiner & IC,IntrinsicInst & II,unsigned NumElts)2050 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
2051 const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();
2052
2053 if (Pattern == AArch64SVEPredPattern::all) {
2054 Value *Cnt = IC.Builder.CreateElementCount(
2055 II.getType(), ElementCount::getScalable(NumElts));
2056 Cnt->takeName(&II);
2057 return IC.replaceInstUsesWith(II, Cnt);
2058 }
2059
2060 unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);
2061
2062 return MinNumElts && NumElts >= MinNumElts
2063 ? std::optional<Instruction *>(IC.replaceInstUsesWith(
2064 II, ConstantInt::get(II.getType(), MinNumElts)))
2065 : std::nullopt;
2066 }
2067
// Simplify SVE ptest intrinsics: canonicalise ptest_first/last(X, X) to
// ptest_any, look through matching convert.to.svbool casts on both operands,
// and rewrite ptest_any over flag-setting candidate operations.
static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  Value *PgVal = II.getArgOperand(0);
  Value *OpVal = II.getArgOperand(1);

  // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
  // Later optimizations prefer this form.
  if (PgVal == OpVal &&
      (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
       II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
    Value *Ops[] = {PgVal, OpVal};
    Type *Tys[] = {PgVal->getType()};

    auto *PTest =
        IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
    PTest->takeName(&II);

    return IC.replaceInstUsesWith(II, PTest);
  }

  IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
  IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);

  if (!Pg || !Op)
    return std::nullopt;

  Intrinsic::ID OpIID = Op->getIntrinsicID();

  // When both operands are convert.to.svbool of values with the same type,
  // ptest the uncasted values directly.
  if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
      Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
    Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
    Type *Tys[] = {Pg->getArgOperand(0)->getType()};

    auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);

    PTest->takeName(&II);
    return IC.replaceInstUsesWith(II, PTest);
  }

  // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X)).
  // Later optimizations may rewrite sequence to use the flag-setting variant
  // of instruction X to remove PTEST.
  if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
      ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
       (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
       (OpIID == Intrinsic::aarch64_sve_and_z) ||
       (OpIID == Intrinsic::aarch64_sve_bic_z) ||
       (OpIID == Intrinsic::aarch64_sve_eor_z) ||
       (OpIID == Intrinsic::aarch64_sve_nand_z) ||
       (OpIID == Intrinsic::aarch64_sve_nor_z) ||
       (OpIID == Intrinsic::aarch64_sve_orn_z) ||
       (OpIID == Intrinsic::aarch64_sve_orr_z))) {
    Value *Ops[] = {Pg->getArgOperand(0), Pg};
    Type *Tys[] = {Pg->getType()};

    auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
    PTest->takeName(&II);

    return IC.replaceInstUsesWith(II, PTest);
  }

  return std::nullopt;
}
2135
// Fuse a predicated multiply (MulOpc) feeding this add/sub-style intrinsic
// into a single multiply-accumulate intrinsic (FuseOpc). When
// MergeIntoAddendOp is true the multiply is operand 2 and the addend is
// operand 1 (as used for mla/fmla/fmls by the callers below); otherwise the
// roles are swapped (mad/fmad/fnmsb forms).
template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
static std::optional<Instruction *>
instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
                                  bool MergeIntoAddendOp) {
  Value *P = II.getOperand(0);
  Value *MulOp0, *MulOp1, *AddendOp, *Mul;
  if (MergeIntoAddendOp) {
    AddendOp = II.getOperand(1);
    Mul = II.getOperand(2);
  } else {
    AddendOp = II.getOperand(2);
    Mul = II.getOperand(1);
  }

  // The multiply must use the same governing predicate as this intrinsic.
  if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
                                      m_Value(MulOp1))))
    return std::nullopt;

  // Fusing consumes the multiply, so it must have no other users.
  if (!Mul->hasOneUse())
    return std::nullopt;

  Instruction *FMFSource = nullptr;
  if (II.getType()->isFPOrFPVectorTy()) {
    llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
    // Stop the combine when the flags on the inputs differ in case dropping
    // flags would lead to us missing out on more beneficial optimizations.
    if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
      return std::nullopt;
    // Fusing mul+add into a single rounding step requires contraction.
    if (!FAddFlags.allowContract())
      return std::nullopt;
    FMFSource = &II;
  }

  CallInst *Res;
  if (MergeIntoAddendOp)
    Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
                                     {P, AddendOp, MulOp0, MulOp1}, FMFSource);
  else
    Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
                                     {P, MulOp0, MulOp1, AddendOp}, FMFSource);

  return IC.replaceInstUsesWith(II, Res);
}
2179
2180 static std::optional<Instruction *>
instCombineSVELD1(InstCombiner & IC,IntrinsicInst & II,const DataLayout & DL)2181 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
2182 Value *Pred = II.getOperand(0);
2183 Value *PtrOp = II.getOperand(1);
2184 Type *VecTy = II.getType();
2185
2186 if (isAllActivePredicate(Pred)) {
2187 LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
2188 Load->copyMetadata(II);
2189 return IC.replaceInstUsesWith(II, Load);
2190 }
2191
2192 CallInst *MaskedLoad =
2193 IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
2194 Pred, ConstantAggregateZero::get(VecTy));
2195 MaskedLoad->copyMetadata(II);
2196 return IC.replaceInstUsesWith(II, MaskedLoad);
2197 }
2198
2199 static std::optional<Instruction *>
instCombineSVEST1(InstCombiner & IC,IntrinsicInst & II,const DataLayout & DL)2200 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
2201 Value *VecOp = II.getOperand(0);
2202 Value *Pred = II.getOperand(1);
2203 Value *PtrOp = II.getOperand(2);
2204
2205 if (isAllActivePredicate(Pred)) {
2206 StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
2207 Store->copyMetadata(II);
2208 return IC.eraseInstFromFunction(II);
2209 }
2210
2211 CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
2212 VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
2213 MaskedStore->copyMetadata(II);
2214 return IC.eraseInstFromFunction(II);
2215 }
2216
intrinsicIDToBinOpCode(unsigned Intrinsic)2217 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
2218 switch (Intrinsic) {
2219 case Intrinsic::aarch64_sve_fmul_u:
2220 return Instruction::BinaryOps::FMul;
2221 case Intrinsic::aarch64_sve_fadd_u:
2222 return Instruction::BinaryOps::FAdd;
2223 case Intrinsic::aarch64_sve_fsub_u:
2224 return Instruction::BinaryOps::FSub;
2225 default:
2226 return Instruction::BinaryOpsEnd;
2227 }
2228 }
2229
2230 static std::optional<Instruction *>
instCombineSVEVectorBinOp(InstCombiner & IC,IntrinsicInst & II)2231 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
2232 // Bail due to missing support for ISD::STRICT_ scalable vector operations.
2233 if (II.isStrictFP())
2234 return std::nullopt;
2235
2236 auto *OpPredicate = II.getOperand(0);
2237 auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
2238 if (BinOpCode == Instruction::BinaryOpsEnd ||
2239 !isAllActivePredicate(OpPredicate))
2240 return std::nullopt;
2241 auto BinOp = IC.Builder.CreateBinOpFMF(
2242 BinOpCode, II.getOperand(1), II.getOperand(2), II.getFastMathFlags());
2243 return IC.replaceInstUsesWith(II, BinOp);
2244 }
2245
instCombineSVEVectorAdd(InstCombiner & IC,IntrinsicInst & II)2246 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
2247 IntrinsicInst &II) {
2248 if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2249 Intrinsic::aarch64_sve_mla>(
2250 IC, II, true))
2251 return MLA;
2252 if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2253 Intrinsic::aarch64_sve_mad>(
2254 IC, II, false))
2255 return MAD;
2256 return std::nullopt;
2257 }
2258
2259 static std::optional<Instruction *>
instCombineSVEVectorFAdd(InstCombiner & IC,IntrinsicInst & II)2260 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
2261 if (auto FMLA =
2262 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2263 Intrinsic::aarch64_sve_fmla>(IC, II,
2264 true))
2265 return FMLA;
2266 if (auto FMAD =
2267 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2268 Intrinsic::aarch64_sve_fmad>(IC, II,
2269 false))
2270 return FMAD;
2271 if (auto FMLA =
2272 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2273 Intrinsic::aarch64_sve_fmla>(IC, II,
2274 true))
2275 return FMLA;
2276 return std::nullopt;
2277 }
2278
2279 static std::optional<Instruction *>
instCombineSVEVectorFAddU(InstCombiner & IC,IntrinsicInst & II)2280 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
2281 if (auto FMLA =
2282 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2283 Intrinsic::aarch64_sve_fmla>(IC, II,
2284 true))
2285 return FMLA;
2286 if (auto FMAD =
2287 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2288 Intrinsic::aarch64_sve_fmad>(IC, II,
2289 false))
2290 return FMAD;
2291 if (auto FMLA_U =
2292 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2293 Intrinsic::aarch64_sve_fmla_u>(
2294 IC, II, true))
2295 return FMLA_U;
2296 return instCombineSVEVectorBinOp(IC, II);
2297 }
2298
2299 static std::optional<Instruction *>
instCombineSVEVectorFSub(InstCombiner & IC,IntrinsicInst & II)2300 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
2301 if (auto FMLS =
2302 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2303 Intrinsic::aarch64_sve_fmls>(IC, II,
2304 true))
2305 return FMLS;
2306 if (auto FMSB =
2307 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2308 Intrinsic::aarch64_sve_fnmsb>(
2309 IC, II, false))
2310 return FMSB;
2311 if (auto FMLS =
2312 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2313 Intrinsic::aarch64_sve_fmls>(IC, II,
2314 true))
2315 return FMLS;
2316 return std::nullopt;
2317 }
2318
2319 static std::optional<Instruction *>
instCombineSVEVectorFSubU(InstCombiner & IC,IntrinsicInst & II)2320 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
2321 if (auto FMLS =
2322 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2323 Intrinsic::aarch64_sve_fmls>(IC, II,
2324 true))
2325 return FMLS;
2326 if (auto FMSB =
2327 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
2328 Intrinsic::aarch64_sve_fnmsb>(
2329 IC, II, false))
2330 return FMSB;
2331 if (auto FMLS_U =
2332 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
2333 Intrinsic::aarch64_sve_fmls_u>(
2334 IC, II, true))
2335 return FMLS_U;
2336 return instCombineSVEVectorBinOp(IC, II);
2337 }
2338
instCombineSVEVectorSub(InstCombiner & IC,IntrinsicInst & II)2339 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
2340 IntrinsicInst &II) {
2341 if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
2342 Intrinsic::aarch64_sve_mls>(
2343 IC, II, true))
2344 return MLS;
2345 return std::nullopt;
2346 }
2347
instCombineSVEUnpack(InstCombiner & IC,IntrinsicInst & II)2348 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC,
2349 IntrinsicInst &II) {
2350 Value *UnpackArg = II.getArgOperand(0);
2351 auto *RetTy = cast<ScalableVectorType>(II.getType());
2352 bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi ||
2353 II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo;
2354
2355 // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X))
2356 // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X))
2357 if (auto *ScalarArg = getSplatValue(UnpackArg)) {
2358 ScalarArg =
2359 IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned);
2360 Value *NewVal =
2361 IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg);
2362 NewVal->takeName(&II);
2363 return IC.replaceInstUsesWith(II, NewVal);
2364 }
2365
2366 return std::nullopt;
2367 }
instCombineSVETBL(InstCombiner & IC,IntrinsicInst & II)2368 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC,
2369 IntrinsicInst &II) {
2370 auto *OpVal = II.getOperand(0);
2371 auto *OpIndices = II.getOperand(1);
2372 VectorType *VTy = cast<VectorType>(II.getType());
2373
2374 // Check whether OpIndices is a constant splat value < minimal element count
2375 // of result.
2376 auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices));
2377 if (!SplatValue ||
2378 SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue()))
2379 return std::nullopt;
2380
2381 // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to
2382 // splat_vector(extractelement(OpVal, SplatValue)) for further optimization.
2383 auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue);
2384 auto *VectorSplat =
2385 IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract);
2386
2387 VectorSplat->takeName(&II);
2388 return IC.replaceInstUsesWith(II, VectorSplat);
2389 }
2390
instCombineSVEUzp1(InstCombiner & IC,IntrinsicInst & II)2391 static std::optional<Instruction *> instCombineSVEUzp1(InstCombiner &IC,
2392 IntrinsicInst &II) {
2393 Value *A, *B;
2394 Type *RetTy = II.getType();
2395 constexpr Intrinsic::ID FromSVB = Intrinsic::aarch64_sve_convert_from_svbool;
2396 constexpr Intrinsic::ID ToSVB = Intrinsic::aarch64_sve_convert_to_svbool;
2397
2398 // uzp1(to_svbool(A), to_svbool(B)) --> <A, B>
2399 // uzp1(from_svbool(to_svbool(A)), from_svbool(to_svbool(B))) --> <A, B>
2400 if ((match(II.getArgOperand(0),
2401 m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(A)))) &&
2402 match(II.getArgOperand(1),
2403 m_Intrinsic<FromSVB>(m_Intrinsic<ToSVB>(m_Value(B))))) ||
2404 (match(II.getArgOperand(0), m_Intrinsic<ToSVB>(m_Value(A))) &&
2405 match(II.getArgOperand(1), m_Intrinsic<ToSVB>(m_Value(B))))) {
2406 auto *TyA = cast<ScalableVectorType>(A->getType());
2407 if (TyA == B->getType() &&
2408 RetTy == ScalableVectorType::getDoubleElementsVectorType(TyA)) {
2409 auto *SubVec = IC.Builder.CreateInsertVector(
2410 RetTy, PoisonValue::get(RetTy), A, uint64_t(0));
2411 auto *ConcatVec = IC.Builder.CreateInsertVector(RetTy, SubVec, B,
2412 TyA->getMinNumElements());
2413 ConcatVec->takeName(&II);
2414 return IC.replaceInstUsesWith(II, ConcatVec);
2415 }
2416 }
2417
2418 return std::nullopt;
2419 }
2420
instCombineSVEZip(InstCombiner & IC,IntrinsicInst & II)2421 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC,
2422 IntrinsicInst &II) {
2423 // zip1(uzp1(A, B), uzp2(A, B)) --> A
2424 // zip2(uzp1(A, B), uzp2(A, B)) --> B
2425 Value *A, *B;
2426 if (match(II.getArgOperand(0),
2427 m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) &&
2428 match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>(
2429 m_Specific(A), m_Specific(B))))
2430 return IC.replaceInstUsesWith(
2431 II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B));
2432
2433 return std::nullopt;
2434 }
2435
2436 static std::optional<Instruction *>
instCombineLD1GatherIndex(InstCombiner & IC,IntrinsicInst & II)2437 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) {
2438 Value *Mask = II.getOperand(0);
2439 Value *BasePtr = II.getOperand(1);
2440 Value *Index = II.getOperand(2);
2441 Type *Ty = II.getType();
2442 Value *PassThru = ConstantAggregateZero::get(Ty);
2443
2444 // Contiguous gather => masked load.
2445 // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1))
2446 // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer)
2447 Value *IndexBase;
2448 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
2449 m_Value(IndexBase), m_SpecificInt(1)))) {
2450 Align Alignment =
2451 BasePtr->getPointerAlignment(II.getDataLayout());
2452
2453 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
2454 BasePtr, IndexBase);
2455 CallInst *MaskedLoad =
2456 IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru);
2457 MaskedLoad->takeName(&II);
2458 return IC.replaceInstUsesWith(II, MaskedLoad);
2459 }
2460
2461 return std::nullopt;
2462 }
2463
2464 static std::optional<Instruction *>
instCombineST1ScatterIndex(InstCombiner & IC,IntrinsicInst & II)2465 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) {
2466 Value *Val = II.getOperand(0);
2467 Value *Mask = II.getOperand(1);
2468 Value *BasePtr = II.getOperand(2);
2469 Value *Index = II.getOperand(3);
2470 Type *Ty = Val->getType();
2471
2472 // Contiguous scatter => masked store.
2473 // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1))
2474 // => (masked.store Value (gep BasePtr IndexBase) Align Mask)
2475 Value *IndexBase;
2476 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>(
2477 m_Value(IndexBase), m_SpecificInt(1)))) {
2478 Align Alignment =
2479 BasePtr->getPointerAlignment(II.getDataLayout());
2480
2481 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(),
2482 BasePtr, IndexBase);
2483 (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask);
2484
2485 return IC.eraseInstFromFunction(II);
2486 }
2487
2488 return std::nullopt;
2489 }
2490
instCombineSVESDIV(InstCombiner & IC,IntrinsicInst & II)2491 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC,
2492 IntrinsicInst &II) {
2493 Type *Int32Ty = IC.Builder.getInt32Ty();
2494 Value *Pred = II.getOperand(0);
2495 Value *Vec = II.getOperand(1);
2496 Value *DivVec = II.getOperand(2);
2497
2498 Value *SplatValue = getSplatValue(DivVec);
2499 ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue);
2500 if (!SplatConstantInt)
2501 return std::nullopt;
2502
2503 APInt Divisor = SplatConstantInt->getValue();
2504 const int64_t DivisorValue = Divisor.getSExtValue();
2505 if (DivisorValue == -1)
2506 return std::nullopt;
2507 if (DivisorValue == 1)
2508 IC.replaceInstUsesWith(II, Vec);
2509
2510 if (Divisor.isPowerOf2()) {
2511 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2512 auto ASRD = IC.Builder.CreateIntrinsic(
2513 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2514 return IC.replaceInstUsesWith(II, ASRD);
2515 }
2516 if (Divisor.isNegatedPowerOf2()) {
2517 Divisor.negate();
2518 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2());
2519 auto ASRD = IC.Builder.CreateIntrinsic(
2520 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2});
2521 auto NEG = IC.Builder.CreateIntrinsic(
2522 Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD});
2523 return IC.replaceInstUsesWith(II, NEG);
2524 }
2525
2526 return std::nullopt;
2527 }
2528
// Returns true when Vec is a repetition of its first half, recursively
// shrinking Vec in place to the smallest repeating prefix. nullptr entries
// stand for poison lanes; when AllowPoison is set they match any value and
// are back-filled from the corresponding lane in the other half.
bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) {
  size_t VecSize = Vec.size();
  if (VecSize == 1)
    return true;
  // Only power-of-two sizes can be halved repeatedly.
  if (!isPowerOf2_64(VecSize))
    return false;
  size_t HalfVecSize = VecSize / 2;

  // Compare lane i against lane i + VecSize/2 pairwise.
  for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize;
       RHS != Vec.end(); LHS++, RHS++) {
    if (*LHS != nullptr && *RHS != nullptr) {
      if (*LHS == *RHS)
        continue;
      else
        return false;
    }
    if (!AllowPoison)
      return false;
    // Canonicalise the known value into the first half so the caller is left
    // with a fully-populated prefix where possible.
    if (*LHS == nullptr && *RHS != nullptr)
      *LHS = *RHS;
  }

  Vec.resize(HalfVecSize);
  // Try to shrink further. The result is deliberately ignored: Vec already
  // holds a valid repeating pattern at this size, so this call is a
  // best-effort extra reduction.
  SimplifyValuePattern(Vec, AllowPoison);
  return true;
}
2555
// Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B)
// to dupqlane(f64(C)) where C is A concatenated with B
static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC,
                                                           IntrinsicInst &II) {
  // Only handle a dupq_lane fed by a vector.insert of a fixed-width vector
  // that was built up via an insertelement chain.
  Value *CurrentInsertElt = nullptr, *Default = nullptr;
  if (!match(II.getOperand(0),
             m_Intrinsic<Intrinsic::vector_insert>(
                 m_Value(Default), m_Value(CurrentInsertElt), m_Value())) ||
      !isa<FixedVectorType>(CurrentInsertElt->getType()))
    return std::nullopt;
  auto IIScalableTy = cast<ScalableVectorType>(II.getType());

  // Insert the scalars into a container ordered by InsertElement index
  SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr);
  while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) {
    auto Idx = cast<ConstantInt>(InsertElt->getOperand(2));
    Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1);
    CurrentInsertElt = InsertElt->getOperand(0);
  }

  // Poison lanes may only be treated as wildcards when both the chain base
  // and the vector.insert default value are poison.
  bool AllowPoison =
      isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default);
  if (!SimplifyValuePattern(Elts, AllowPoison))
    return std::nullopt;

  // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b)
  Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType());
  for (size_t I = 0; I < Elts.size(); I++) {
    if (Elts[I] == nullptr)
      continue;
    InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I],
                                                    IC.Builder.getInt64(I));
  }
  if (InsertEltChain == nullptr)
    return std::nullopt;

  // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64
  // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector
  // be bitcast to a type wide enough to fit the sequence, be splatted, and then
  // be narrowed back to the original type.
  unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size();
  unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() *
                                 IIScalableTy->getMinNumElements() /
                                 PatternWidth;

  IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth);
  auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount);
  // An all-zero shuffle mask broadcasts element 0 of the wide vector.
  auto *WideShuffleMaskTy =
      ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount);

  auto InsertSubvector = IC.Builder.CreateInsertVector(
      II.getType(), PoisonValue::get(II.getType()), InsertEltChain,
      uint64_t(0));
  auto WideBitcast =
      IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy);
  auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy);
  auto WideShuffle = IC.Builder.CreateShuffleVector(
      WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask);
  auto NarrowBitcast =
      IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType());

  return IC.replaceInstUsesWith(II, NarrowBitcast);
}
2619
instCombineMaxMinNM(InstCombiner & IC,IntrinsicInst & II)2620 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC,
2621 IntrinsicInst &II) {
2622 Value *A = II.getArgOperand(0);
2623 Value *B = II.getArgOperand(1);
2624 if (A == B)
2625 return IC.replaceInstUsesWith(II, A);
2626
2627 return std::nullopt;
2628 }
2629
// Rewrite a rounding shift (SRSHL) of a known-non-negative value by a
// known-non-negative amount into a plain LSL, where rounding cannot matter.
static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  Value *Pred = II.getOperand(0);
  Value *Vec = II.getOperand(1);
  Value *Shift = II.getOperand(2);

  // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic.
  Value *AbsPred, *MergedValue;
  if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>(
                      m_Value(MergedValue), m_Value(AbsPred), m_Value())) &&
      !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>(
                      m_Value(MergedValue), m_Value(AbsPred), m_Value())))

    return std::nullopt;

  // Transform is valid if any of the following are true:
  // * The ABS merge value is an undef or non-negative
  // * The ABS predicate is all active
  // * The ABS predicate and the SRSHL predicates are the same
  if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) &&
      AbsPred != Pred && !isAllActivePredicate(AbsPred))
    return std::nullopt;

  // Only valid when the shift amount is non-negative, otherwise the rounding
  // behaviour of SRSHL cannot be ignored.
  if (!match(Shift, m_NonNegative()))
    return std::nullopt;

  auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl,
                                        {II.getType()}, {Pred, Vec, Shift});

  return IC.replaceInstUsesWith(II, LSL);
}
2663
instCombineSVEInsr(InstCombiner & IC,IntrinsicInst & II)2664 static std::optional<Instruction *> instCombineSVEInsr(InstCombiner &IC,
2665 IntrinsicInst &II) {
2666 Value *Vec = II.getOperand(0);
2667
2668 if (getSplatValue(Vec) == II.getOperand(1))
2669 return IC.replaceInstUsesWith(II, Vec);
2670
2671 return std::nullopt;
2672 }
2673
// Remove a DMB that is followed (within a bounded lookahead, possibly across
// single-successor block boundaries) by an identical DMB with nothing that
// touches memory or has side effects in between.
static std::optional<Instruction *> instCombineDMB(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  // If this barrier is post-dominated by identical one we can remove it
  auto *NI = II.getNextNonDebugInstruction();
  unsigned LookaheadThreshold = DMBLookaheadThreshold;
  auto CanSkipOver = [](Instruction *I) {
    // Safe to hop over only if the instruction cannot observe or affect the
    // ordering the barrier provides.
    return !I->mayReadOrWriteMemory() && !I->mayHaveSideEffects();
  };
  while (LookaheadThreshold-- && CanSkipOver(NI)) {
    auto *NIBB = NI->getParent();
    NI = NI->getNextNonDebugInstruction();
    if (!NI) {
      // Fell off the end of the block: keep scanning only through a unique
      // successor, otherwise give up.
      if (auto *SuccBB = NIBB->getUniqueSuccessor())
        NI = &*SuccBB->getFirstNonPHIOrDbgOrLifetime();
      else
        break;
    }
  }
  auto *NextII = dyn_cast_or_null<IntrinsicInst>(NI);
  if (NextII && II.isIdenticalTo(NextII))
    return IC.eraseInstFromFunction(II);

  return std::nullopt;
}
2698
instCombinePTrue(InstCombiner & IC,IntrinsicInst & II)2699 static std::optional<Instruction *> instCombinePTrue(InstCombiner &IC,
2700 IntrinsicInst &II) {
2701 if (match(II.getOperand(0), m_ConstantInt<AArch64SVEPredPattern::all>()))
2702 return IC.replaceInstUsesWith(II, Constant::getAllOnesValue(II.getType()));
2703 return std::nullopt;
2704 }
2705
instCombineSVEUxt(InstCombiner & IC,IntrinsicInst & II,unsigned NumBits)2706 static std::optional<Instruction *> instCombineSVEUxt(InstCombiner &IC,
2707 IntrinsicInst &II,
2708 unsigned NumBits) {
2709 Value *Passthru = II.getOperand(0);
2710 Value *Pg = II.getOperand(1);
2711 Value *Op = II.getOperand(2);
2712
2713 // Convert UXT[BHW] to AND.
2714 if (isa<UndefValue>(Passthru) || isAllActivePredicate(Pg)) {
2715 auto *Ty = cast<VectorType>(II.getType());
2716 auto MaskValue = APInt::getLowBitsSet(Ty->getScalarSizeInBits(), NumBits);
2717 auto *Mask = ConstantInt::get(Ty, MaskValue);
2718 auto *And = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_and_u, {Ty},
2719 {Pg, Op, Mask});
2720 return IC.replaceInstUsesWith(II, And);
2721 }
2722
2723 return std::nullopt;
2724 }
2725
2726 static std::optional<Instruction *>
instCombineInStreamingMode(InstCombiner & IC,IntrinsicInst & II)2727 instCombineInStreamingMode(InstCombiner &IC, IntrinsicInst &II) {
2728 SMEAttrs FnSMEAttrs(*II.getFunction());
2729 bool IsStreaming = FnSMEAttrs.hasStreamingInterfaceOrBody();
2730 if (IsStreaming || !FnSMEAttrs.hasStreamingCompatibleInterface())
2731 return IC.replaceInstUsesWith(
2732 II, ConstantInt::getBool(II.getType(), IsStreaming));
2733 return std::nullopt;
2734 }
2735
// Top-level AArch64 intrinsic combiner: first run the generic SVE intrinsic
// simplifier driven by the per-intrinsic info table, then dispatch to the
// intrinsic-specific combines defined above.
std::optional<Instruction *>
AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC,
                                     IntrinsicInst &II) const {
  const SVEIntrinsicInfo &IInfo = constructSVEIntrinsicInfo(II);
  if (std::optional<Instruction *> I = simplifySVEIntrinsic(IC, II, IInfo))
    return I;

  Intrinsic::ID IID = II.getIntrinsicID();
  switch (IID) {
  default:
    break;
  case Intrinsic::aarch64_dmb:
    return instCombineDMB(IC, II);
  case Intrinsic::aarch64_neon_fmaxnm:
  case Intrinsic::aarch64_neon_fminnm:
    return instCombineMaxMinNM(IC, II);
  case Intrinsic::aarch64_sve_convert_from_svbool:
    return instCombineConvertFromSVBool(IC, II);
  case Intrinsic::aarch64_sve_dup:
    return instCombineSVEDup(IC, II);
  case Intrinsic::aarch64_sve_dup_x:
    return instCombineSVEDupX(IC, II);
  case Intrinsic::aarch64_sve_cmpne:
  case Intrinsic::aarch64_sve_cmpne_wide:
    return instCombineSVECmpNE(IC, II);
  case Intrinsic::aarch64_sve_rdffr:
    return instCombineRDFFR(IC, II);
  case Intrinsic::aarch64_sve_lasta:
  case Intrinsic::aarch64_sve_lastb:
    return instCombineSVELast(IC, II);
  case Intrinsic::aarch64_sve_clasta_n:
  case Intrinsic::aarch64_sve_clastb_n:
    return instCombineSVECondLast(IC, II);
  // cnt[dwhb]: the numeric argument is the number of elements per 128 bits.
  case Intrinsic::aarch64_sve_cntd:
    return instCombineSVECntElts(IC, II, 2);
  case Intrinsic::aarch64_sve_cntw:
    return instCombineSVECntElts(IC, II, 4);
  case Intrinsic::aarch64_sve_cnth:
    return instCombineSVECntElts(IC, II, 8);
  case Intrinsic::aarch64_sve_cntb:
    return instCombineSVECntElts(IC, II, 16);
  case Intrinsic::aarch64_sve_ptest_any:
  case Intrinsic::aarch64_sve_ptest_first:
  case Intrinsic::aarch64_sve_ptest_last:
    return instCombineSVEPTest(IC, II);
  case Intrinsic::aarch64_sve_fadd:
    return instCombineSVEVectorFAdd(IC, II);
  case Intrinsic::aarch64_sve_fadd_u:
    return instCombineSVEVectorFAddU(IC, II);
  case Intrinsic::aarch64_sve_fmul_u:
    return instCombineSVEVectorBinOp(IC, II);
  case Intrinsic::aarch64_sve_fsub:
    return instCombineSVEVectorFSub(IC, II);
  case Intrinsic::aarch64_sve_fsub_u:
    return instCombineSVEVectorFSubU(IC, II);
  case Intrinsic::aarch64_sve_add:
    return instCombineSVEVectorAdd(IC, II);
  // add_u/sub_u: fuse a multiply operand directly into mla_u/mls_u.
  case Intrinsic::aarch64_sve_add_u:
    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
                                             Intrinsic::aarch64_sve_mla_u>(
        IC, II, true);
  case Intrinsic::aarch64_sve_sub:
    return instCombineSVEVectorSub(IC, II);
  case Intrinsic::aarch64_sve_sub_u:
    return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u,
                                             Intrinsic::aarch64_sve_mls_u>(
        IC, II, true);
  case Intrinsic::aarch64_sve_tbl:
    return instCombineSVETBL(IC, II);
  case Intrinsic::aarch64_sve_uunpkhi:
  case Intrinsic::aarch64_sve_uunpklo:
  case Intrinsic::aarch64_sve_sunpkhi:
  case Intrinsic::aarch64_sve_sunpklo:
    return instCombineSVEUnpack(IC, II);
  case Intrinsic::aarch64_sve_uzp1:
    return instCombineSVEUzp1(IC, II);
  case Intrinsic::aarch64_sve_zip1:
  case Intrinsic::aarch64_sve_zip2:
    return instCombineSVEZip(IC, II);
  case Intrinsic::aarch64_sve_ld1_gather_index:
    return instCombineLD1GatherIndex(IC, II);
  case Intrinsic::aarch64_sve_st1_scatter_index:
    return instCombineST1ScatterIndex(IC, II);
  case Intrinsic::aarch64_sve_ld1:
    return instCombineSVELD1(IC, II, DL);
  case Intrinsic::aarch64_sve_st1:
    return instCombineSVEST1(IC, II, DL);
  case Intrinsic::aarch64_sve_sdiv:
    return instCombineSVESDIV(IC, II);
  case Intrinsic::aarch64_sve_sel:
    return instCombineSVESel(IC, II);
  case Intrinsic::aarch64_sve_srshl:
    return instCombineSVESrshl(IC, II);
  case Intrinsic::aarch64_sve_dupq_lane:
    return instCombineSVEDupqLane(IC, II);
  case Intrinsic::aarch64_sve_insr:
    return instCombineSVEInsr(IC, II);
  case Intrinsic::aarch64_sve_ptrue:
    return instCombinePTrue(IC, II);
  // uxt[bhw]: the numeric argument is the number of low bits preserved.
  case Intrinsic::aarch64_sve_uxtb:
    return instCombineSVEUxt(IC, II, 8);
  case Intrinsic::aarch64_sve_uxth:
    return instCombineSVEUxt(IC, II, 16);
  case Intrinsic::aarch64_sve_uxtw:
    return instCombineSVEUxt(IC, II, 32);
  case Intrinsic::aarch64_sme_in_streaming_mode:
    return instCombineInStreamingMode(IC, II);
  }

  return std::nullopt;
}
2847
// For the listed AArch64 narrowing intrinsics, lane i of the result comes
// only from lane i of operand 0, so the demanded-elements mask can be
// forwarded unchanged to that operand via SimplifyAndSetOp.
std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic(
    InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
    APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
    std::function<void(Instruction *, unsigned, APInt, APInt &)>
        SimplifyAndSetOp) const {
  switch (II.getIntrinsicID()) {
  default:
    break;
  case Intrinsic::aarch64_neon_fcvtxn:
  case Intrinsic::aarch64_neon_rshrn:
  case Intrinsic::aarch64_neon_sqrshrn:
  case Intrinsic::aarch64_neon_sqrshrun:
  case Intrinsic::aarch64_neon_sqshrn:
  case Intrinsic::aarch64_neon_sqshrun:
  case Intrinsic::aarch64_neon_sqxtn:
  case Intrinsic::aarch64_neon_sqxtun:
  case Intrinsic::aarch64_neon_uqrshrn:
  case Intrinsic::aarch64_neon_uqshrn:
  case Intrinsic::aarch64_neon_uqxtn:
    SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts);
    break;
  }

  // No direct replacement value; any simplification happened via the callback.
  return std::nullopt;
}
2873
enableScalableVectorization() const2874 bool AArch64TTIImpl::enableScalableVectorization() const {
2875 return ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2876 EnableScalableAutovecInStreamingMode);
2877 }
2878
2879 TypeSize
getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const2880 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const {
2881 switch (K) {
2882 case TargetTransformInfo::RGK_Scalar:
2883 return TypeSize::getFixed(64);
2884 case TargetTransformInfo::RGK_FixedWidthVector:
2885 if (ST->useSVEForFixedLengthVectors() &&
2886 (ST->isSVEAvailable() || EnableFixedwidthAutovecInStreamingMode))
2887 return TypeSize::getFixed(
2888 std::max(ST->getMinSVEVectorSizeInBits(), 128u));
2889 else if (ST->isNeonAvailable())
2890 return TypeSize::getFixed(128);
2891 else
2892 return TypeSize::getFixed(0);
2893 case TargetTransformInfo::RGK_ScalableVector:
2894 if (ST->isSVEAvailable() || (ST->isSVEorStreamingSVEAvailable() &&
2895 EnableScalableAutovecInStreamingMode))
2896 return TypeSize::getScalable(128);
2897 else
2898 return TypeSize::getScalable(0);
2899 }
2900 llvm_unreachable("Unsupported register kind");
2901 }
2902
// Returns true when (Opcode, Args) producing DstTy will lower to a NEON
// widening instruction (e.g. uaddl/saddw/smull), in which case the operand
// extend is folded into the instruction and is effectively free.
// SrcOverrideTy, if given, is used as the extend's source type instead of
// inferring it from the operands.
bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode,
                                           ArrayRef<const Value *> Args,
                                           Type *SrcOverrideTy) const {
  // A helper that returns a vector type from the given type. The number of
  // elements in type Ty determines the vector width.
  auto toVectorTy = [&](Type *ArgTy) {
    return VectorType::get(ArgTy->getScalarType(),
                           cast<VectorType>(DstTy)->getElementCount());
  };

  // Exit early if DstTy is not a vector type whose elements are one of [i16,
  // i32, i64]. SVE doesn't generally have the same set of instructions to
  // perform an extend with the add/sub/mul. There are SMULLB style
  // instructions, but they operate on top/bottom, requiring some sort of lane
  // interleaving to be used with zext/sext.
  unsigned DstEltSize = DstTy->getScalarSizeInBits();
  if (!useNeonVector(DstTy) || Args.size() != 2 ||
      (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64))
    return false;

  // Determine if the operation has a widening variant. We consider both the
  // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the
  // instructions.
  //
  // TODO: Add additional widening operations (e.g., shl, etc.) once we
  // verify that their extending operands are eliminated during code
  // generation.
  Type *SrcTy = SrcOverrideTy;
  switch (Opcode) {
  case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2).
  case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2).
    // The second operand needs to be an extend
    if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) {
      if (!SrcTy)
        SrcTy =
            toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType());
    } else
      return false;
    break;
  case Instruction::Mul: { // SMULL(2), UMULL(2)
    // Both operands need to be extends of the same type.
    if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) ||
        (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) {
      if (!SrcTy)
        SrcTy =
            toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType());
    } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) {
      // If one of the operands is a Zext and the other has enough zero bits to
      // be treated as unsigned, we can still general a umull, meaning the zext
      // is free.
      KnownBits Known =
          computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL);
      if (Args[0]->getType()->getScalarSizeInBits() -
              Known.Zero.countLeadingOnes() >
          DstTy->getScalarSizeInBits() / 2)
        return false;
      if (!SrcTy)
        SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(),
                                           DstTy->getScalarSizeInBits() / 2));
    } else
      return false;
    break;
  }
  default:
    return false;
  }

  // Legalize the destination type and ensure it can be used in a widening
  // operation.
  auto DstTyL = getTypeLegalizationCost(DstTy);
  if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits())
    return false;

  // Legalize the source type and ensure it can be used in a widening
  // operation.
  assert(SrcTy && "Expected some SrcTy");
  auto SrcTyL = getTypeLegalizationCost(SrcTy);
  unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits();
  if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits())
    return false;

  // Get the total number of vector elements in the legalized types.
  InstructionCost NumDstEls =
      DstTyL.first * DstTyL.second.getVectorMinNumElements();
  InstructionCost NumSrcEls =
      SrcTyL.first * SrcTyL.second.getVectorMinNumElements();

  // Return true if the legalized types have the same number of vector elements
  // and the destination element type size is twice that of the source type.
  return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize;
}
2994
// s/urhadd instructions implement the following pattern, making the
// extends free:
// %x = add ((zext i8 -> i16), 1)
// %y = (zext i8 -> i16)
// trunc i16 (lshr (add %x, %y), 1) -> i8
//
// Returns true when ExtUser (an add fed by the extend of Src) is part of
// such a rounding-halving-add expression, so the extend costs nothing.
bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst,
                                        Type *Src) const {
  // The source should be a legal vector type. Scalable sources additionally
  // require SVE2 (which provides the s/urhadd forms used there).
  if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) ||
      (Src->isScalableTy() && !ST->hasSVE2()))
    return false;

  if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse())
    return false;

  // Look for trunc/shl/add before trying to match the pattern.
  const Instruction *Add = ExtUser;
  auto *AddUser =
      dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
  if (AddUser && AddUser->getOpcode() == Instruction::Add)
    Add = AddUser;

  auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser());
  if (!Shr || Shr->getOpcode() != Instruction::LShr)
    return false;

  // The trunc must narrow back to the original source element width.
  auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser());
  if (!Trunc || Trunc->getOpcode() != Instruction::Trunc ||
      Src->getScalarSizeInBits() !=
          cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits())
    return false;

  // Try to match the whole pattern. Ext could be either the first or second
  // m_ZExtOrSExt matched.
  Instruction *Ex1, *Ex2;
  if (!(match(Add, m_c_Add(m_Instruction(Ex1),
                           m_c_Add(m_Instruction(Ex2), m_SpecificInt(1))))))
    return false;

  // Ensure both extends are of the same type
  if (match(Ex1, m_ZExtOrSExt(m_Value())) &&
      Ex1->getOpcode() == Ex2->getOpcode())
    return true;

  return false;
}
3042
getCastInstrCost(unsigned Opcode,Type * Dst,Type * Src,TTI::CastContextHint CCH,TTI::TargetCostKind CostKind,const Instruction * I) const3043 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst,
3044 Type *Src,
3045 TTI::CastContextHint CCH,
3046 TTI::TargetCostKind CostKind,
3047 const Instruction *I) const {
3048 int ISD = TLI->InstructionOpcodeToISD(Opcode);
3049 assert(ISD && "Invalid opcode");
3050 // If the cast is observable, and it is used by a widening instruction (e.g.,
3051 // uaddl, saddw, etc.), it may be free.
3052 if (I && I->hasOneUser()) {
3053 auto *SingleUser = cast<Instruction>(*I->user_begin());
3054 SmallVector<const Value *, 4> Operands(SingleUser->operand_values());
3055 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) {
3056 // For adds only count the second operand as free if both operands are
3057 // extends but not the same operation. (i.e both operands are not free in
3058 // add(sext, zext)).
3059 if (SingleUser->getOpcode() == Instruction::Add) {
3060 if (I == SingleUser->getOperand(1) ||
3061 (isa<CastInst>(SingleUser->getOperand(1)) &&
3062 cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode))
3063 return 0;
3064 } else // Others are free so long as isWideningInstruction returned true.
3065 return 0;
3066 }
3067
3068 // The cast will be free for the s/urhadd instructions
3069 if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) &&
3070 isExtPartOfAvgExpr(SingleUser, Dst, Src))
3071 return 0;
3072 }
3073
3074 // TODO: Allow non-throughput costs that aren't binary.
3075 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost {
3076 if (CostKind != TTI::TCK_RecipThroughput)
3077 return Cost == 0 ? 0 : 1;
3078 return Cost;
3079 };
3080
3081 EVT SrcTy = TLI->getValueType(DL, Src);
3082 EVT DstTy = TLI->getValueType(DL, Dst);
3083
3084 if (!SrcTy.isSimple() || !DstTy.isSimple())
3085 return AdjustCost(
3086 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
3087
3088 static const TypeConversionCostTblEntry BF16Tbl[] = {
3089 {ISD::FP_ROUND, MVT::bf16, MVT::f32, 1}, // bfcvt
3090 {ISD::FP_ROUND, MVT::bf16, MVT::f64, 1}, // bfcvt
3091 {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 1}, // bfcvtn
3092 {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 2}, // bfcvtn+bfcvtn2
3093 {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 2}, // bfcvtn+fcvtn
3094 {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 3}, // fcvtn+fcvtl2+bfcvtn
3095 {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+bfcvtn
3096 };
3097
3098 if (ST->hasBF16())
3099 if (const auto *Entry = ConvertCostTableLookup(
3100 BF16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3101 return AdjustCost(Entry->Cost);
3102
3103 // Symbolic constants for the SVE sitofp/uitofp entries in the table below
3104 // The cost of unpacking twice is artificially increased for now in order
3105 // to avoid regressions against NEON, which will use tbl instructions directly
3106 // instead of multiple layers of [s|u]unpk[lo|hi].
3107 // We use the unpacks in cases where the destination type is illegal and
3108 // requires splitting of the input, even if the input type itself is legal.
3109 const unsigned int SVE_EXT_COST = 1;
3110 const unsigned int SVE_FCVT_COST = 1;
3111 const unsigned int SVE_UNPACK_ONCE = 4;
3112 const unsigned int SVE_UNPACK_TWICE = 16;
3113
3114 static const TypeConversionCostTblEntry ConversionTbl[] = {
3115 {ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn
3116 {ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn
3117 {ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1}, // xtn
3118 {ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1}, // xtn
3119 {ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3}, // 2 xtn + 1 uzp1
3120 {ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1}, // xtn
3121 {ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2}, // 1 uzp1 + 1 xtn
3122 {ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1}, // 1 uzp1
3123 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1}, // 1 xtn
3124 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2}, // 1 uzp1 + 1 xtn
3125 {ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4}, // 3 x uzp1 + xtn
3126 {ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1}, // 1 uzp1
3127 {ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3}, // 3 x uzp1
3128 {ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2}, // 2 x uzp1
3129 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1}, // uzp1
3130 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3}, // (2 + 1) x uzp1
3131 {ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7}, // (4 + 2 + 1) x uzp1
3132 {ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1
3133 {ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1
3134 {ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1
3135
3136 // Truncations on nxvmiN
3137 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i8, 2},
3138 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 2},
3139 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 2},
3140 {ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 2},
3141 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i8, 2},
3142 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 2},
3143 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 2},
3144 {ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 5},
3145 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i8, 2},
3146 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 2},
3147 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 5},
3148 {ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 11},
3149 {ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 2},
3150 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i16, 0},
3151 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i32, 0},
3152 {ISD::TRUNCATE, MVT::nxv2i8, MVT::nxv2i64, 0},
3153 {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 0},
3154 {ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i64, 0},
3155 {ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 0},
3156 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i16, 0},
3157 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i32, 0},
3158 {ISD::TRUNCATE, MVT::nxv4i8, MVT::nxv4i64, 1},
3159 {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 0},
3160 {ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i64, 1},
3161 {ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 1},
3162 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i16, 0},
3163 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i32, 1},
3164 {ISD::TRUNCATE, MVT::nxv8i8, MVT::nxv8i64, 3},
3165 {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 1},
3166 {ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i64, 3},
3167 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i16, 1},
3168 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i32, 3},
3169 {ISD::TRUNCATE, MVT::nxv16i8, MVT::nxv16i64, 7},
3170
3171 // The number of shll instructions for the extension.
3172 {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3},
3173 {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3},
3174 {ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2},
3175 {ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2},
3176 {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3},
3177 {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3},
3178 {ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2},
3179 {ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2},
3180 {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7},
3181 {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7},
3182 {ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6},
3183 {ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6},
3184 {ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2},
3185 {ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2},
3186 {ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6},
3187 {ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6},
3188
3189 // FP Ext and trunc
3190 {ISD::FP_EXTEND, MVT::f64, MVT::f32, 1}, // fcvt
3191 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f32, 1}, // fcvtl
3192 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f32, 2}, // fcvtl+fcvtl2
3193 // FP16
3194 {ISD::FP_EXTEND, MVT::f32, MVT::f16, 1}, // fcvt
3195 {ISD::FP_EXTEND, MVT::f64, MVT::f16, 1}, // fcvt
3196 {ISD::FP_EXTEND, MVT::v4f32, MVT::v4f16, 1}, // fcvtl
3197 {ISD::FP_EXTEND, MVT::v8f32, MVT::v8f16, 2}, // fcvtl+fcvtl2
3198 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2f16, 2}, // fcvtl+fcvtl
3199 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4f16, 3}, // fcvtl+fcvtl2+fcvtl
3200 {ISD::FP_EXTEND, MVT::v8f64, MVT::v8f16, 6}, // 2 * fcvtl+fcvtl2+fcvtl
3201 // BF16 (uses shift)
3202 {ISD::FP_EXTEND, MVT::f32, MVT::bf16, 1}, // shl
3203 {ISD::FP_EXTEND, MVT::f64, MVT::bf16, 2}, // shl+fcvt
3204 {ISD::FP_EXTEND, MVT::v4f32, MVT::v4bf16, 1}, // shll
3205 {ISD::FP_EXTEND, MVT::v8f32, MVT::v8bf16, 2}, // shll+shll2
3206 {ISD::FP_EXTEND, MVT::v2f64, MVT::v2bf16, 2}, // shll+fcvtl
3207 {ISD::FP_EXTEND, MVT::v4f64, MVT::v4bf16, 3}, // shll+fcvtl+fcvtl2
3208 {ISD::FP_EXTEND, MVT::v8f64, MVT::v8bf16, 6}, // 2 * shll+fcvtl+fcvtl2
3209 // FP Ext and trunc
3210 {ISD::FP_ROUND, MVT::f32, MVT::f64, 1}, // fcvt
3211 {ISD::FP_ROUND, MVT::v2f32, MVT::v2f64, 1}, // fcvtn
3212 {ISD::FP_ROUND, MVT::v4f32, MVT::v4f64, 2}, // fcvtn+fcvtn2
3213 // FP16
3214 {ISD::FP_ROUND, MVT::f16, MVT::f32, 1}, // fcvt
3215 {ISD::FP_ROUND, MVT::f16, MVT::f64, 1}, // fcvt
3216 {ISD::FP_ROUND, MVT::v4f16, MVT::v4f32, 1}, // fcvtn
3217 {ISD::FP_ROUND, MVT::v8f16, MVT::v8f32, 2}, // fcvtn+fcvtn2
3218 {ISD::FP_ROUND, MVT::v2f16, MVT::v2f64, 2}, // fcvtn+fcvtn
3219 {ISD::FP_ROUND, MVT::v4f16, MVT::v4f64, 3}, // fcvtn+fcvtn2+fcvtn
3220 {ISD::FP_ROUND, MVT::v8f16, MVT::v8f64, 6}, // 2 * fcvtn+fcvtn2+fcvtn
3221 // BF16 (more complex, with +bf16 is handled above)
3222 {ISD::FP_ROUND, MVT::bf16, MVT::f32, 8}, // Expansion is ~8 insns
3223 {ISD::FP_ROUND, MVT::bf16, MVT::f64, 9}, // fcvtn + above
3224 {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f32, 8},
3225 {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f32, 8},
3226 {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f32, 15},
3227 {ISD::FP_ROUND, MVT::v2bf16, MVT::v2f64, 9},
3228 {ISD::FP_ROUND, MVT::v4bf16, MVT::v4f64, 10},
3229 {ISD::FP_ROUND, MVT::v8bf16, MVT::v8f64, 19},
3230
3231 // LowerVectorINT_TO_FP:
3232 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
3233 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
3234 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
3235 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1},
3236 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1},
3237 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1},
3238
3239 // SVE: to nxv2f16
3240 {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i8,
3241 SVE_EXT_COST + SVE_FCVT_COST},
3242 {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i16, SVE_FCVT_COST},
3243 {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i32, SVE_FCVT_COST},
3244 {ISD::SINT_TO_FP, MVT::nxv2f16, MVT::nxv2i64, SVE_FCVT_COST},
3245 {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i8,
3246 SVE_EXT_COST + SVE_FCVT_COST},
3247 {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i16, SVE_FCVT_COST},
3248 {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i32, SVE_FCVT_COST},
3249 {ISD::UINT_TO_FP, MVT::nxv2f16, MVT::nxv2i64, SVE_FCVT_COST},
3250
3251 // SVE: to nxv4f16
3252 {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i8,
3253 SVE_EXT_COST + SVE_FCVT_COST},
3254 {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i16, SVE_FCVT_COST},
3255 {ISD::SINT_TO_FP, MVT::nxv4f16, MVT::nxv4i32, SVE_FCVT_COST},
3256 {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i8,
3257 SVE_EXT_COST + SVE_FCVT_COST},
3258 {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i16, SVE_FCVT_COST},
3259 {ISD::UINT_TO_FP, MVT::nxv4f16, MVT::nxv4i32, SVE_FCVT_COST},
3260
3261 // SVE: to nxv8f16
3262 {ISD::SINT_TO_FP, MVT::nxv8f16, MVT::nxv8i8,
3263 SVE_EXT_COST + SVE_FCVT_COST},
3264 {ISD::SINT_TO_FP, MVT::nxv8f16, MVT::nxv8i16, SVE_FCVT_COST},
3265 {ISD::UINT_TO_FP, MVT::nxv8f16, MVT::nxv8i8,
3266 SVE_EXT_COST + SVE_FCVT_COST},
3267 {ISD::UINT_TO_FP, MVT::nxv8f16, MVT::nxv8i16, SVE_FCVT_COST},
3268
3269 // SVE: to nxv16f16
3270 {ISD::SINT_TO_FP, MVT::nxv16f16, MVT::nxv16i8,
3271 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3272 {ISD::UINT_TO_FP, MVT::nxv16f16, MVT::nxv16i8,
3273 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3274
3275 // Complex: to v2f32
3276 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
3277 {ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
3278 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3},
3279 {ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3},
3280
3281 // SVE: to nxv2f32
3282 {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i8,
3283 SVE_EXT_COST + SVE_FCVT_COST},
3284 {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i16, SVE_FCVT_COST},
3285 {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, SVE_FCVT_COST},
3286 {ISD::SINT_TO_FP, MVT::nxv2f32, MVT::nxv2i64, SVE_FCVT_COST},
3287 {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i8,
3288 SVE_EXT_COST + SVE_FCVT_COST},
3289 {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i16, SVE_FCVT_COST},
3290 {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i32, SVE_FCVT_COST},
3291 {ISD::UINT_TO_FP, MVT::nxv2f32, MVT::nxv2i64, SVE_FCVT_COST},
3292
3293 // Complex: to v4f32
3294 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4},
3295 {ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
3296 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3},
3297 {ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2},
3298
3299 // SVE: to nxv4f32
3300 {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i8,
3301 SVE_EXT_COST + SVE_FCVT_COST},
3302 {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i16, SVE_FCVT_COST},
3303 {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, SVE_FCVT_COST},
3304 {ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i8,
3305 SVE_EXT_COST + SVE_FCVT_COST},
3306 {ISD::UINT_TO_FP, MVT::nxv4f32, MVT::nxv4i16, SVE_FCVT_COST},
3307 {ISD::SINT_TO_FP, MVT::nxv4f32, MVT::nxv4i32, SVE_FCVT_COST},
3308
3309 // Complex: to v8f32
3310 {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
3311 {ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
3312 {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10},
3313 {ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4},
3314
3315 // SVE: to nxv8f32
3316 {ISD::SINT_TO_FP, MVT::nxv8f32, MVT::nxv8i8,
3317 SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3318 {ISD::SINT_TO_FP, MVT::nxv8f32, MVT::nxv8i16,
3319 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3320 {ISD::UINT_TO_FP, MVT::nxv8f32, MVT::nxv8i8,
3321 SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3322 {ISD::UINT_TO_FP, MVT::nxv8f32, MVT::nxv8i16,
3323 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3324
3325 // SVE: to nxv16f32
3326 {ISD::SINT_TO_FP, MVT::nxv16f32, MVT::nxv16i8,
3327 SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3328 {ISD::UINT_TO_FP, MVT::nxv16f32, MVT::nxv16i8,
3329 SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3330
3331 // Complex: to v16f32
3332 {ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
3333 {ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21},
3334
3335 // Complex: to v2f64
3336 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
3337 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
3338 {ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
3339 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4},
3340 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4},
3341 {ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2},
3342
3343 // SVE: to nxv2f64
3344 {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i8,
3345 SVE_EXT_COST + SVE_FCVT_COST},
3346 {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i16, SVE_FCVT_COST},
3347 {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i32, SVE_FCVT_COST},
3348 {ISD::SINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, SVE_FCVT_COST},
3349 {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i8,
3350 SVE_EXT_COST + SVE_FCVT_COST},
3351 {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i16, SVE_FCVT_COST},
3352 {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i32, SVE_FCVT_COST},
3353 {ISD::UINT_TO_FP, MVT::nxv2f64, MVT::nxv2i64, SVE_FCVT_COST},
3354
3355 // Complex: to v4f64
3356 {ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
3357 {ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4},
3358
3359 // SVE: to nxv4f64
3360 {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i8,
3361 SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3362 {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i16,
3363 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3364 {ISD::SINT_TO_FP, MVT::nxv4f64, MVT::nxv4i32,
3365 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3366 {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i8,
3367 SVE_EXT_COST + SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3368 {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i16,
3369 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3370 {ISD::UINT_TO_FP, MVT::nxv4f64, MVT::nxv4i32,
3371 SVE_UNPACK_ONCE + 2 * SVE_FCVT_COST},
3372
3373 // SVE: to nxv8f64
3374 {ISD::SINT_TO_FP, MVT::nxv8f64, MVT::nxv8i8,
3375 SVE_EXT_COST + SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3376 {ISD::SINT_TO_FP, MVT::nxv8f64, MVT::nxv8i16,
3377 SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3378 {ISD::UINT_TO_FP, MVT::nxv8f64, MVT::nxv8i8,
3379 SVE_EXT_COST + SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3380 {ISD::UINT_TO_FP, MVT::nxv8f64, MVT::nxv8i16,
3381 SVE_UNPACK_TWICE + 4 * SVE_FCVT_COST},
3382
3383 // LowerVectorFP_TO_INT
3384 {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1},
3385 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1},
3386 {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1},
3387 {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1},
3388 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1},
3389 {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1},
3390
3391 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext).
3392 {ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2},
3393 {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1},
3394 {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1},
3395 {ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2},
3396 {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1},
3397 {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1},
3398
3399 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2
3400 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2},
3401 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2},
3402 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2},
3403 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2},
3404
3405 // Complex, from nxv2f32.
3406 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1},
3407 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1},
3408 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1},
3409 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1},
3410 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1},
3411 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1},
3412 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1},
3413 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1},
3414
3415 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2.
3416 {ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2},
3417 {ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2},
3418 {ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2},
3419 {ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2},
3420 {ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2},
3421 {ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2},
3422
3423 // Complex, from nxv2f64.
3424 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1},
3425 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1},
3426 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1},
3427 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1},
3428 {ISD::FP_TO_SINT, MVT::nxv2i1, MVT::nxv2f64, 1},
3429 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1},
3430 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1},
3431 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1},
3432 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1},
3433 {ISD::FP_TO_UINT, MVT::nxv2i1, MVT::nxv2f64, 1},
3434
3435 // Complex, from nxv4f32.
3436 {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4},
3437 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1},
3438 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1},
3439 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1},
3440 {ISD::FP_TO_SINT, MVT::nxv4i1, MVT::nxv4f32, 1},
3441 {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4},
3442 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1},
3443 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1},
3444 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1},
3445 {ISD::FP_TO_UINT, MVT::nxv4i1, MVT::nxv4f32, 1},
3446
3447 // Complex, from nxv8f64. Illegal -> illegal conversions not required.
3448 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7},
3449 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7},
3450 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7},
3451 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7},
3452
3453 // Complex, from nxv4f64. Illegal -> illegal conversions not required.
3454 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3},
3455 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3},
3456 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3},
3457 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3},
3458 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3},
3459 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3},
3460
3461 // Complex, from nxv8f32. Illegal -> illegal conversions not required.
3462 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3},
3463 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3},
3464 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3},
3465 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3},
3466
3467 // Complex, from nxv8f16.
3468 {ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10},
3469 {ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4},
3470 {ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1},
3471 {ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1},
3472 {ISD::FP_TO_SINT, MVT::nxv8i1, MVT::nxv8f16, 1},
3473 {ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10},
3474 {ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4},
3475 {ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1},
3476 {ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1},
3477 {ISD::FP_TO_UINT, MVT::nxv8i1, MVT::nxv8f16, 1},
3478
3479 // Complex, from nxv4f16.
3480 {ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4},
3481 {ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3482 {ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3483 {ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3484 {ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4},
3485 {ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1},
3486 {ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1},
3487 {ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1},
3488
3489 // Complex, from nxv2f16.
3490 {ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3491 {ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3492 {ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3493 {ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3494 {ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1},
3495 {ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1},
3496 {ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1},
3497 {ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1},
3498
3499 // Truncate from nxvmf32 to nxvmf16.
3500 {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1},
3501 {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1},
3502 {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3},
3503
3504 // Truncate from nxvmf64 to nxvmf16.
3505 {ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1},
3506 {ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3},
3507 {ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7},
3508
3509 // Truncate from nxvmf64 to nxvmf32.
3510 {ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1},
3511 {ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3},
3512 {ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6},
3513
3514 // Extend from nxvmf16 to nxvmf32.
3515 {ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1},
3516 {ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1},
3517 {ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2},
3518
3519 // Extend from nxvmf16 to nxvmf64.
3520 {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1},
3521 {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2},
3522 {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4},
3523
3524 // Extend from nxvmf32 to nxvmf64.
3525 {ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1},
3526 {ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2},
3527 {ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6},
3528
3529 // Bitcasts from float to integer
3530 {ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0},
3531 {ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0},
3532 {ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0},
3533
3534 // Bitcasts from integer to float
3535 {ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0},
3536 {ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0},
3537 {ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0},
3538
3539 // Add cost for extending to illegal -too wide- scalable vectors.
3540 // zero/sign extend are implemented by multiple unpack operations,
3541 // where each operation has a cost of 1.
3542 {ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3543 {ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3544 {ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3545 {ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3546 {ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3547 {ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3548
3549 {ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2},
3550 {ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6},
3551 {ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14},
3552 {ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2},
3553 {ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6},
3554 {ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2},
3555 };
3556
3557 // We have to estimate a cost of fixed length operation upon
3558 // SVE registers(operations) with the number of registers required
3559 // for a fixed type to be represented upon SVE registers.
3560 EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy;
3561 if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() &&
3562 SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() &&
3563 ST->useSVEForFixedLengthVectors(WiderTy)) {
3564 std::pair<InstructionCost, MVT> LT =
3565 getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext()));
3566 unsigned NumElements =
3567 AArch64::SVEBitsPerBlock / LT.second.getScalarSizeInBits();
3568 return AdjustCost(
3569 LT.first *
3570 getCastInstrCost(
3571 Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements),
3572 ScalableVectorType::get(Src->getScalarType(), NumElements), CCH,
3573 CostKind, I));
3574 }
3575
3576 if (const auto *Entry = ConvertCostTableLookup(
3577 ConversionTbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3578 return AdjustCost(Entry->Cost);
3579
3580 static const TypeConversionCostTblEntry FP16Tbl[] = {
3581 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs
3582 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1},
3583 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs
3584 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1},
3585 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs
3586 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2},
3587 {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn
3588 {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2},
3589 {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs
3590 {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1},
3591 {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs
3592 {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4},
3593 {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn
3594 {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3},
3595 {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs
3596 {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2},
3597 {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs
3598 {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8},
3599 {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // ushll + ucvtf
3600 {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // sshll + scvtf
3601 {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf
3602 {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf
3603 };
3604
3605 if (ST->hasFullFP16())
3606 if (const auto *Entry = ConvertCostTableLookup(
3607 FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT()))
3608 return AdjustCost(Entry->Cost);
3609
3610 // INT_TO_FP of i64->f32 will scalarize, which is required to avoid
3611 // double-rounding issues.
3612 if ((ISD == ISD::SINT_TO_FP || ISD == ISD::UINT_TO_FP) &&
3613 DstTy.getScalarType() == MVT::f32 && SrcTy.getScalarSizeInBits() > 32 &&
3614 isa<FixedVectorType>(Dst) && isa<FixedVectorType>(Src))
3615 return AdjustCost(
3616 cast<FixedVectorType>(Dst)->getNumElements() *
3617 getCastInstrCost(Opcode, Dst->getScalarType(), Src->getScalarType(),
3618 CCH, CostKind) +
3619 BaseT::getScalarizationOverhead(cast<FixedVectorType>(Src), false, true,
3620 CostKind) +
3621 BaseT::getScalarizationOverhead(cast<FixedVectorType>(Dst), true, false,
3622 CostKind));
3623
3624 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3625 CCH == TTI::CastContextHint::Masked &&
3626 ST->isSVEorStreamingSVEAvailable() &&
3627 TLI->getTypeAction(Src->getContext(), SrcTy) ==
3628 TargetLowering::TypePromoteInteger &&
3629 TLI->getTypeAction(Dst->getContext(), DstTy) ==
3630 TargetLowering::TypeSplitVector) {
3631 // The standard behaviour in the backend for these cases is to split the
3632 // extend up into two parts:
3633 // 1. Perform an extending load or masked load up to the legal type.
3634 // 2. Extend the loaded data to the final type.
3635 std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src);
3636 Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext());
3637 InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost(
3638 Opcode, LegalTy, Src, CCH, CostKind, I);
3639 InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost(
3640 Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I);
3641 return Part1 + Part2;
3642 }
3643
3644 // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal,
3645 // but we also want to include the TTI::CastContextHint::Masked case too.
3646 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) &&
3647 CCH == TTI::CastContextHint::Masked &&
3648 ST->isSVEorStreamingSVEAvailable() && TLI->isTypeLegal(DstTy))
3649 CCH = TTI::CastContextHint::Normal;
3650
3651 return AdjustCost(
3652 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I));
3653 }
3654
3655 InstructionCost
getExtractWithExtendCost(unsigned Opcode,Type * Dst,VectorType * VecTy,unsigned Index,TTI::TargetCostKind CostKind) const3656 AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, Type *Dst,
3657 VectorType *VecTy, unsigned Index,
3658 TTI::TargetCostKind CostKind) const {
3659
3660 // Make sure we were given a valid extend opcode.
3661 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) &&
3662 "Invalid opcode");
3663
3664 // We are extending an element we extract from a vector, so the source type
3665 // of the extend is the element type of the vector.
3666 auto *Src = VecTy->getElementType();
3667
3668 // Sign- and zero-extends are for integer types only.
3669 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type");
3670
3671 // Get the cost for the extract. We compute the cost (if any) for the extend
3672 // below.
3673 InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy,
3674 CostKind, Index, nullptr, nullptr);
3675
3676 // Legalize the types.
3677 auto VecLT = getTypeLegalizationCost(VecTy);
3678 auto DstVT = TLI->getValueType(DL, Dst);
3679 auto SrcVT = TLI->getValueType(DL, Src);
3680
3681 // If the resulting type is still a vector and the destination type is legal,
3682 // we may get the extension for free. If not, get the default cost for the
3683 // extend.
3684 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT))
3685 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3686 CostKind);
3687
3688 // The destination type should be larger than the element type. If not, get
3689 // the default cost for the extend.
3690 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits())
3691 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3692 CostKind);
3693
3694 switch (Opcode) {
3695 default:
3696 llvm_unreachable("Opcode should be either SExt or ZExt");
3697
3698 // For sign-extends, we only need a smov, which performs the extension
3699 // automatically.
3700 case Instruction::SExt:
3701 return Cost;
3702
3703 // For zero-extends, the extend is performed automatically by a umov unless
3704 // the destination type is i64 and the element type is i8 or i16.
3705 case Instruction::ZExt:
3706 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u)
3707 return Cost;
3708 }
3709
3710 // If we are unable to perform the extend for free, get the default cost.
3711 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None,
3712 CostKind);
3713 }
3714
getCFInstrCost(unsigned Opcode,TTI::TargetCostKind CostKind,const Instruction * I) const3715 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode,
3716 TTI::TargetCostKind CostKind,
3717 const Instruction *I) const {
3718 if (CostKind != TTI::TCK_RecipThroughput)
3719 return Opcode == Instruction::PHI ? 0 : 1;
3720 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind");
3721 // Branches are assumed to be predicted.
3722 return 0;
3723 }
3724
getVectorInstrCostHelper(unsigned Opcode,Type * Val,TTI::TargetCostKind CostKind,unsigned Index,const Instruction * I,Value * Scalar,ArrayRef<std::tuple<Value *,User *,int>> ScalarUserAndIdx) const3725 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(
3726 unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
3727 const Instruction *I, Value *Scalar,
3728 ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
3729 assert(Val->isVectorTy() && "This must be a vector type");
3730
3731 if (Index != -1U) {
3732 // Legalize the type.
3733 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val);
3734
3735 // This type is legalized to a scalar type.
3736 if (!LT.second.isVector())
3737 return 0;
3738
3739 // The type may be split. For fixed-width vectors we can normalize the
3740 // index to the new type.
3741 if (LT.second.isFixedLengthVector()) {
3742 unsigned Width = LT.second.getVectorNumElements();
3743 Index = Index % Width;
3744 }
3745
3746 // The element at index zero is already inside the vector.
3747 // - For a insert-element or extract-element
3748 // instruction that extracts integers, an explicit FPR -> GPR move is
3749 // needed. So it has non-zero cost.
3750 if (Index == 0 && !Val->getScalarType()->isIntegerTy())
3751 return 0;
3752
3753 // This is recognising a LD1 single-element structure to one lane of one
3754 // register instruction. I.e., if this is an `insertelement` instruction,
3755 // and its second operand is a load, then we will generate a LD1, which
3756 // are expensive instructions.
3757 if (I && dyn_cast<LoadInst>(I->getOperand(1)))
3758 return CostKind == TTI::TCK_CodeSize
3759 ? 0
3760 : ST->getVectorInsertExtractBaseCost() + 1;
3761
3762 // i1 inserts and extract will include an extra cset or cmp of the vector
3763 // value. Increase the cost by 1 to account.
3764 if (Val->getScalarSizeInBits() == 1)
3765 return CostKind == TTI::TCK_CodeSize
3766 ? 2
3767 : ST->getVectorInsertExtractBaseCost() + 1;
3768
3769 // FIXME:
3770 // If the extract-element and insert-element instructions could be
3771 // simplified away (e.g., could be combined into users by looking at use-def
3772 // context), they have no cost. This is not done in the first place for
3773 // compile-time considerations.
3774 }
3775
3776 // In case of Neon, if there exists extractelement from lane != 0 such that
3777 // 1. extractelement does not necessitate a move from vector_reg -> GPR.
3778 // 2. extractelement result feeds into fmul.
3779 // 3. Other operand of fmul is an extractelement from lane 0 or lane
3780 // equivalent to 0.
3781 // then the extractelement can be merged with fmul in the backend and it
3782 // incurs no cost.
3783 // e.g.
3784 // define double @foo(<2 x double> %a) {
3785 // %1 = extractelement <2 x double> %a, i32 0
3786 // %2 = extractelement <2 x double> %a, i32 1
3787 // %res = fmul double %1, %2
3788 // ret double %res
3789 // }
3790 // %2 and %res can be merged in the backend to generate fmul d0, d0, v1.d[1]
3791 auto ExtractCanFuseWithFmul = [&]() {
3792 // We bail out if the extract is from lane 0.
3793 if (Index == 0)
3794 return false;
3795
3796 // Check if the scalar element type of the vector operand of ExtractElement
3797 // instruction is one of the allowed types.
3798 auto IsAllowedScalarTy = [&](const Type *T) {
3799 return T->isFloatTy() || T->isDoubleTy() ||
3800 (T->isHalfTy() && ST->hasFullFP16());
3801 };
3802
3803 // Check if the extractelement user is scalar fmul.
3804 auto IsUserFMulScalarTy = [](const Value *EEUser) {
3805 // Check if the user is scalar fmul.
3806 const auto *BO = dyn_cast<BinaryOperator>(EEUser);
3807 return BO && BO->getOpcode() == BinaryOperator::FMul &&
3808 !BO->getType()->isVectorTy();
3809 };
3810
3811 // Check if the extract index is from lane 0 or lane equivalent to 0 for a
3812 // certain scalar type and a certain vector register width.
3813 auto IsExtractLaneEquivalentToZero = [&](unsigned Idx, unsigned EltSz) {
3814 auto RegWidth =
3815 getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
3816 .getFixedValue();
3817 return Idx == 0 || (RegWidth != 0 && (Idx * EltSz) % RegWidth == 0);
3818 };
3819
3820 // Check if the type constraints on input vector type and result scalar type
3821 // of extractelement instruction are satisfied.
3822 if (!isa<FixedVectorType>(Val) || !IsAllowedScalarTy(Val->getScalarType()))
3823 return false;
3824
3825 if (Scalar) {
3826 DenseMap<User *, unsigned> UserToExtractIdx;
3827 for (auto *U : Scalar->users()) {
3828 if (!IsUserFMulScalarTy(U))
3829 return false;
3830 // Recording entry for the user is important. Index value is not
3831 // important.
3832 UserToExtractIdx[U];
3833 }
3834 if (UserToExtractIdx.empty())
3835 return false;
3836 for (auto &[S, U, L] : ScalarUserAndIdx) {
3837 for (auto *U : S->users()) {
3838 if (UserToExtractIdx.contains(U)) {
3839 auto *FMul = cast<BinaryOperator>(U);
3840 auto *Op0 = FMul->getOperand(0);
3841 auto *Op1 = FMul->getOperand(1);
3842 if ((Op0 == S && Op1 == S) || Op0 != S || Op1 != S) {
3843 UserToExtractIdx[U] = L;
3844 break;
3845 }
3846 }
3847 }
3848 }
3849 for (auto &[U, L] : UserToExtractIdx) {
3850 if (!IsExtractLaneEquivalentToZero(Index, Val->getScalarSizeInBits()) &&
3851 !IsExtractLaneEquivalentToZero(L, Val->getScalarSizeInBits()))
3852 return false;
3853 }
3854 } else {
3855 const auto *EE = cast<ExtractElementInst>(I);
3856
3857 const auto *IdxOp = dyn_cast<ConstantInt>(EE->getIndexOperand());
3858 if (!IdxOp)
3859 return false;
3860
3861 return !EE->users().empty() && all_of(EE->users(), [&](const User *U) {
3862 if (!IsUserFMulScalarTy(U))
3863 return false;
3864
3865 // Check if the other operand of extractelement is also extractelement
3866 // from lane equivalent to 0.
3867 const auto *BO = cast<BinaryOperator>(U);
3868 const auto *OtherEE = dyn_cast<ExtractElementInst>(
3869 BO->getOperand(0) == EE ? BO->getOperand(1) : BO->getOperand(0));
3870 if (OtherEE) {
3871 const auto *IdxOp = dyn_cast<ConstantInt>(OtherEE->getIndexOperand());
3872 if (!IdxOp)
3873 return false;
3874 return IsExtractLaneEquivalentToZero(
3875 cast<ConstantInt>(OtherEE->getIndexOperand())
3876 ->getValue()
3877 .getZExtValue(),
3878 OtherEE->getType()->getScalarSizeInBits());
3879 }
3880 return true;
3881 });
3882 }
3883 return true;
3884 };
3885
3886 if (Opcode == Instruction::ExtractElement && (I || Scalar) &&
3887 ExtractCanFuseWithFmul())
3888 return 0;
3889
3890 // All other insert/extracts cost this much.
3891 return CostKind == TTI::TCK_CodeSize ? 1
3892 : ST->getVectorInsertExtractBaseCost();
3893 }
3894
getVectorInstrCost(unsigned Opcode,Type * Val,TTI::TargetCostKind CostKind,unsigned Index,const Value * Op0,const Value * Op1) const3895 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
3896 TTI::TargetCostKind CostKind,
3897 unsigned Index,
3898 const Value *Op0,
3899 const Value *Op1) const {
3900 // Treat insert at lane 0 into a poison vector as having zero cost. This
3901 // ensures vector broadcasts via an insert + shuffle (and will be lowered to a
3902 // single dup) are treated as cheap.
3903 if (Opcode == Instruction::InsertElement && Index == 0 && Op0 &&
3904 isa<PoisonValue>(Op0))
3905 return 0;
3906 return getVectorInstrCostHelper(Opcode, Val, CostKind, Index);
3907 }
3908
/// Overload taking a scalar value plus the (scalar, user, lane) tuples of all
/// scalars being costed together; forwards to the shared helper with no
/// context instruction so extract/fmul fusion can be recognized without IR.
InstructionCost AArch64TTIImpl::getVectorInstrCost(
    unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
    Value *Scalar,
    ArrayRef<std::tuple<Value *, User *, int>> ScalarUserAndIdx) const {
  return getVectorInstrCostHelper(Opcode, Val, CostKind, Index, nullptr, Scalar,
                                  ScalarUserAndIdx);
}
3916
/// Overload for an existing instruction: derives the opcode from \p I and
/// passes \p I as context so, e.g., load-fed inserts can be recognized.
InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                   Type *Val,
                                                   TTI::TargetCostKind CostKind,
                                                   unsigned Index) const {
  return getVectorInstrCostHelper(I.getOpcode(), Val, CostKind, Index, &I);
}
3923
getScalarizationOverhead(VectorType * Ty,const APInt & DemandedElts,bool Insert,bool Extract,TTI::TargetCostKind CostKind,bool ForPoisonSrc,ArrayRef<Value * > VL) const3924 InstructionCost AArch64TTIImpl::getScalarizationOverhead(
3925 VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
3926 TTI::TargetCostKind CostKind, bool ForPoisonSrc,
3927 ArrayRef<Value *> VL) const {
3928 if (isa<ScalableVectorType>(Ty))
3929 return InstructionCost::getInvalid();
3930 if (Ty->getElementType()->isFloatingPointTy())
3931 return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
3932 CostKind);
3933 unsigned VecInstCost =
3934 CostKind == TTI::TCK_CodeSize ? 1 : ST->getVectorInsertExtractBaseCost();
3935 return DemandedElts.popcount() * (Insert + Extract) * VecInstCost;
3936 }
3937
/// Returns the AArch64 cost of an arithmetic instruction with opcode
/// \p Opcode on type \p Ty. Special-cases constant integer
/// division/remainder, i64 vector multiplies, and FP ops that require type
/// promotion; all other opcodes defer to the base implementation.
InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
    ArrayRef<const Value *> Args, const Instruction *CxtI) const {

  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
    if (VTy->getElementCount() == ElementCount::getScalable(1))
      return InstructionCost::getInvalid();

  // TODO: Handle more cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info, Args, CxtI);

  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  case ISD::SREM:
  case ISD::SDIV:
    /*
    Notes for sdiv/srem specific costs:
    1. This only considers the cases where the divisor is constant, uniform and
    (pow-of-2/non-pow-of-2). Other cases are not important since they either
    result in some form of (ldr + adrp), corresponding to constant vectors, or
    scalarization of the division operation.
    2. Constant divisors, either negative in whole or partially, don't result in
    significantly different codegen as compared to positive constant divisors.
    So, we don't consider negative divisors separately.
    3. If the codegen is significantly different with SVE, it has been indicated
    using comments at appropriate places.

    sdiv specific cases:
    -----------------------------------------------------------------------
    codegen                       | pow-of-2               | Type
    -----------------------------------------------------------------------
    add + cmp + csel + asr        | Y                      | i64
    add + cmp + csel + asr        | Y                      | i32
    -----------------------------------------------------------------------

    srem specific cases:
    -----------------------------------------------------------------------
    codegen                       | pow-of-2               | Type
    -----------------------------------------------------------------------
    negs + and + and + csneg      | Y                      | i64
    negs + and + and + csneg      | Y                      | i32
    -----------------------------------------------------------------------

    other sdiv/srem cases:
    -------------------------------------------------------------------------
    commom codegen            | + srem     | + sdiv        | pow-of-2  | Type
    -------------------------------------------------------------------------
    smulh + asr + add + add   | -          | -             | N         | i64
    smull + lsr + add + add   | -          | -             | N         | i32
    usra                      | and + sub  | sshr          | Y         | <2 x i64>
    2 * (scalar code)         | -          | -             | N         | <2 x i64>
    usra                      | bic + sub  | sshr + neg    | Y         | <4 x i32>
    smull2 + smull + uzp2     | mls        | -             | N         | <4 x i32>
     + sshr + usra            |            |               |           |
    -------------------------------------------------------------------------
    */
    if (Op2Info.isConstant() && Op2Info.isUniform()) {
      InstructionCost AddCost =
          getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
                                 Op1Info.getNoProps(), Op2Info.getNoProps());
      InstructionCost AsrCost =
          getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                 Op1Info.getNoProps(), Op2Info.getNoProps());
      InstructionCost MulCost =
          getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
                                 Op1Info.getNoProps(), Op2Info.getNoProps());
      // add/cmp/csel/csneg should have similar cost while asr/negs/and should
      // have similar cost.
      auto VT = TLI->getValueType(DL, Ty);
      if (VT.isScalarInteger() && VT.getSizeInBits() <= 64) {
        if (Op2Info.isPowerOf2() || Op2Info.isNegatedPowerOf2()) {
          // Neg can be folded into the asr instruction.
          return ISD == ISD::SDIV ? (3 * AddCost + AsrCost)
                                  : (3 * AsrCost + AddCost);
        } else {
          return MulCost + AsrCost + 2 * AddCost;
        }
      } else if (VT.isVector()) {
        InstructionCost UsraCost = 2 * AsrCost;
        if (Op2Info.isPowerOf2() || Op2Info.isNegatedPowerOf2()) {
          // Division with scalable types corresponds to native 'asrd'
          // instruction when SVE is available.
          // e.g. %1 = sdiv <vscale x 4 x i32> %a, splat (i32 8)

          // One more for the negation in SDIV
          InstructionCost Cost =
              (Op2Info.isNegatedPowerOf2() && ISD == ISD::SDIV) ? AsrCost : 0;
          if (Ty->isScalableTy() && ST->hasSVE())
            Cost += 2 * AsrCost;
          else {
            Cost +=
                UsraCost +
                (ISD == ISD::SDIV
                     ? (LT.second.getScalarType() == MVT::i64 ? 1 : 2) * AsrCost
                     : 2 * AddCost);
          }
          return Cost;
        } else if (LT.second == MVT::v2i64) {
          // v2i64 non-pow-2 divisors are scalarized: cost each element's
          // scalar division.
          return VT.getVectorNumElements() *
                 getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind,
                                        Op1Info.getNoProps(),
                                        Op2Info.getNoProps());
        } else {
          // When SVE is available, we get:
          // smulh + lsr + add/sub + asr + add/sub.
          if (Ty->isScalableTy() && ST->hasSVE())
            return MulCost /*smulh cost*/ + 2 * AddCost + 2 * AsrCost;
          return 2 * MulCost + AddCost /*uzp2 cost*/ + AsrCost + UsraCost;
        }
      }
    }
    if (Op2Info.isConstant() && !Op2Info.isUniform() &&
        LT.second.isFixedLengthVector()) {
      // FIXME: When the constant vector is non-uniform, this may result in
      // loading the vector from constant pool or in some cases, may also result
      // in scalarization. For now, we are approximating this with the
      // scalarization cost.
      auto ExtractCost = 2 * getVectorInstrCost(Instruction::ExtractElement, Ty,
                                                CostKind, -1, nullptr, nullptr);
      auto InsertCost = getVectorInstrCost(Instruction::InsertElement, Ty,
                                           CostKind, -1, nullptr, nullptr);
      unsigned NElts = cast<FixedVectorType>(Ty)->getNumElements();
      return ExtractCost + InsertCost +
             NElts * getArithmeticInstrCost(Opcode, Ty->getScalarType(),
                                            CostKind, Op1Info.getNoProps(),
                                            Op2Info.getNoProps());
    }
    [[fallthrough]];
  case ISD::UDIV:
  case ISD::UREM: {
    auto VT = TLI->getValueType(DL, Ty);
    if (Op2Info.isConstant()) {
      // If the operand is a power of 2 we can use the shift or and cost.
      if (ISD == ISD::UDIV && Op2Info.isPowerOf2())
        return getArithmeticInstrCost(Instruction::LShr, Ty, CostKind,
                                      Op1Info.getNoProps(),
                                      Op2Info.getNoProps());
      if (ISD == ISD::UREM && Op2Info.isPowerOf2())
        return getArithmeticInstrCost(Instruction::And, Ty, CostKind,
                                      Op1Info.getNoProps(),
                                      Op2Info.getNoProps());

      if (ISD == ISD::UDIV || ISD == ISD::UREM) {
        // Divides by a constant are expanded to MULHU + SUB + SRL + ADD + SRL.
        // The MULHU will be expanded to UMULL for the types not listed below,
        // and will become a pair of UMULL+MULL2 for 128bit vectors.
        bool HasMULH = VT == MVT::i64 || LT.second == MVT::nxv2i64 ||
                       LT.second == MVT::nxv4i32 || LT.second == MVT::nxv8i16 ||
                       LT.second == MVT::nxv16i8;
        bool Is128bit = LT.second.is128BitVector();

        InstructionCost MulCost =
            getArithmeticInstrCost(Instruction::Mul, Ty, CostKind,
                                   Op1Info.getNoProps(), Op2Info.getNoProps());
        InstructionCost AddCost =
            getArithmeticInstrCost(Instruction::Add, Ty, CostKind,
                                   Op1Info.getNoProps(), Op2Info.getNoProps());
        InstructionCost ShrCost =
            getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                   Op1Info.getNoProps(), Op2Info.getNoProps());
        InstructionCost DivCost = MulCost * (Is128bit ? 2 : 1) + // UMULL/UMULH
                                  (HasMULH ? 0 : ShrCost) +      // UMULL shift
                                  AddCost * 2 + ShrCost;
        // UREM additionally needs a multiply and subtract of the quotient.
        return DivCost + (ISD == ISD::UREM ? MulCost + AddCost : 0);
      }
    }

    // div i128's are lowered as libcalls. Pass nullptr as (u)divti3 calls are
    // emitted by the backend even when those functions are not declared in the
    // module.
    if (!VT.isVector() && VT.getSizeInBits() > 64)
      return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);

    InstructionCost Cost = BaseT::getArithmeticInstrCost(
        Opcode, Ty, CostKind, Op1Info, Op2Info);
    if (Ty->isVectorTy() && (ISD == ISD::SDIV || ISD == ISD::UDIV)) {
      if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) {
        // SDIV/UDIV operations are lowered using SVE, then we can have less
        // costs.
        if (VT.isSimple() && isa<FixedVectorType>(Ty) &&
            Ty->getPrimitiveSizeInBits().getFixedValue() < 128) {
          static const CostTblEntry DivTbl[]{
              {ISD::SDIV, MVT::v2i8, 5},  {ISD::SDIV, MVT::v4i8, 8},
              {ISD::SDIV, MVT::v8i8, 8},  {ISD::SDIV, MVT::v2i16, 5},
              {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1},
              {ISD::UDIV, MVT::v2i8, 5},  {ISD::UDIV, MVT::v4i8, 8},
              {ISD::UDIV, MVT::v8i8, 8},  {ISD::UDIV, MVT::v2i16, 5},
              {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}};

          const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT());
          if (nullptr != Entry)
            return Entry->Cost;
        }
        // For 8/16-bit elements, the cost is higher because the type
        // requires promotion and possibly splitting:
        if (LT.second.getScalarType() == MVT::i8)
          Cost *= 8;
        else if (LT.second.getScalarType() == MVT::i16)
          Cost *= 4;
        return Cost;
      } else {
        // If one of the operands is a uniform constant then the cost for each
        // element is Cost for insertion, extraction and division.
        // Insertion cost = 2, Extraction Cost = 2, Division = cost for the
        // operation with scalar type
        if ((Op1Info.isConstant() && Op1Info.isUniform()) ||
            (Op2Info.isConstant() && Op2Info.isUniform())) {
          if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) {
            InstructionCost DivCost = BaseT::getArithmeticInstrCost(
                Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info);
            return (4 + DivCost) * VTy->getNumElements();
          }
        }
        // On AArch64, without SVE, vector divisions are expanded
        // into scalar divisions of each pair of elements.
        Cost += getVectorInstrCost(Instruction::ExtractElement, Ty, CostKind,
                                   -1, nullptr, nullptr);
        Cost += getVectorInstrCost(Instruction::InsertElement, Ty, CostKind, -1,
                                   nullptr, nullptr);
      }

      // TODO: if one of the arguments is scalar, then it's not necessary to
      // double the cost of handling the vector elements.
      Cost += Cost;
    }
    return Cost;
  }
  case ISD::MUL:
    // When SVE is available, then we can lower the v2i64 operation using
    // the SVE mul instruction, which has a lower cost.
    if (LT.second == MVT::v2i64 && ST->hasSVE())
      return LT.first;

    // When SVE is not available, there is no MUL.2d instruction,
    // which means mul <2 x i64> is expensive as elements are extracted
    // from the vectors and the muls scalarized.
    // As getScalarizationOverhead is a bit too pessimistic, we
    // estimate the cost for a i64 vector directly here, which is:
    // - four 2-cost i64 extracts,
    // - two 2-cost i64 inserts, and
    // - two 1-cost muls.
    // So, for a v2i64 with LT.First = 1 the cost is 14, and for a v4i64 with
    // LT.first = 2 the cost is 28. If both operands are extensions it will not
    // need to scalarize so the cost can be cheaper (smull or umull).
    // so the cost can be cheaper (smull or umull).
    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
      return LT.first;
    return cast<VectorType>(Ty)->getElementCount().getKnownMinValue() *
           (getArithmeticInstrCost(Opcode, Ty->getScalarType(), CostKind) +
            getVectorInstrCost(Instruction::ExtractElement, Ty, CostKind, -1,
                               nullptr, nullptr) *
                2 +
            getVectorInstrCost(Instruction::InsertElement, Ty, CostKind, -1,
                               nullptr, nullptr));
  case ISD::ADD:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::SHL:
    // These nodes are marked as 'custom' for combining purposes only.
    // We know that they are legal. See LowerAdd in ISelLowering.
    return LT.first;

  case ISD::FNEG:
    // Scalar fmul(fneg) or fneg(fmul) can be converted to fnmul
    if ((Ty->isFloatTy() || Ty->isDoubleTy() ||
         (Ty->isHalfTy() && ST->hasFullFP16())) &&
        CxtI &&
        ((CxtI->hasOneUse() &&
          match(*CxtI->user_begin(), m_FMul(m_Value(), m_Value()))) ||
         match(CxtI->getOperand(0), m_FMul(m_Value(), m_Value()))))
      return 0;
    [[fallthrough]];
  case ISD::FADD:
  case ISD::FSUB:
    // Increase the cost for half and bfloat types if not architecturally
    // supported.
    if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
        (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
      return 2 * LT.first;
    if (!Ty->getScalarType()->isFP128Ty())
      return LT.first;
    [[fallthrough]];
  case ISD::FMUL:
  case ISD::FDIV:
    // These nodes are marked as 'custom' just to lower them to SVE.
    // We know said lowering will incur no additional cost.
    if (!Ty->getScalarType()->isFP128Ty())
      return 2 * LT.first;

    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  case ISD::FREM:
    // Pass nullptr as fmod/fmodf calls are emitted by the backend even when
    // those functions are not declared in the module.
    if (!Ty->isVectorTy())
      return getCallInstrCost(/*Function*/ nullptr, Ty, {Ty, Ty}, CostKind);
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  }
}
4254
4255 InstructionCost
getAddressComputationCost(Type * Ty,ScalarEvolution * SE,const SCEV * Ptr) const4256 AArch64TTIImpl::getAddressComputationCost(Type *Ty, ScalarEvolution *SE,
4257 const SCEV *Ptr) const {
4258 // Address computations in vectorized code with non-consecutive addresses will
4259 // likely result in more instructions compared to scalar code where the
4260 // computation can more often be merged into the index mode. The resulting
4261 // extra micro-ops can significantly decrease throughput.
4262 unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
4263 int MaxMergeDistance = 64;
4264
4265 if (Ty->isVectorTy() && SE &&
4266 !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
4267 return NumVectorInstToHideOverhead;
4268
4269 // In many cases the address computation is not merged into the instruction
4270 // addressing mode.
4271 return 1;
4272 }
4273
getCmpSelInstrCost(unsigned Opcode,Type * ValTy,Type * CondTy,CmpInst::Predicate VecPred,TTI::TargetCostKind CostKind,TTI::OperandValueInfo Op1Info,TTI::OperandValueInfo Op2Info,const Instruction * I) const4274 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(
4275 unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
4276 TTI::TargetCostKind CostKind, TTI::OperandValueInfo Op1Info,
4277 TTI::OperandValueInfo Op2Info, const Instruction *I) const {
4278 int ISD = TLI->InstructionOpcodeToISD(Opcode);
4279 // We don't lower some vector selects well that are wider than the register
4280 // width. TODO: Improve this with different cost kinds.
4281 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
4282 // We would need this many instructions to hide the scalarization happening.
4283 const int AmortizationCost = 20;
4284
4285 // If VecPred is not set, check if we can get a predicate from the context
4286 // instruction, if its type matches the requested ValTy.
4287 if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
4288 CmpPredicate CurrentPred;
4289 if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
4290 m_Value())))
4291 VecPred = CurrentPred;
4292 }
4293 // Check if we have a compare/select chain that can be lowered using
4294 // a (F)CMxx & BFI pair.
4295 if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
4296 VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
4297 VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
4298 VecPred == CmpInst::FCMP_UNE) {
4299 static const auto ValidMinMaxTys = {
4300 MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
4301 MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
4302 static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
4303
4304 auto LT = getTypeLegalizationCost(ValTy);
4305 if (any_of(ValidMinMaxTys, [<](MVT M) { return M == LT.second; }) ||
4306 (ST->hasFullFP16() &&
4307 any_of(ValidFP16MinMaxTys, [<](MVT M) { return M == LT.second; })))
4308 return LT.first;
4309 }
4310
4311 static const TypeConversionCostTblEntry
4312 VectorSelectTbl[] = {
4313 { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
4314 { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
4315 { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
4316 { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
4317 { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
4318 { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
4319 { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
4320 { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
4321 { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
4322 { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
4323 { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
4324 };
4325
4326 EVT SelCondTy = TLI->getValueType(DL, CondTy);
4327 EVT SelValTy = TLI->getValueType(DL, ValTy);
4328 if (SelCondTy.isSimple() && SelValTy.isSimple()) {
4329 if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
4330 SelCondTy.getSimpleVT(),
4331 SelValTy.getSimpleVT()))
4332 return Entry->Cost;
4333 }
4334 }
4335
4336 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
4337 Type *ValScalarTy = ValTy->getScalarType();
4338 if ((ValScalarTy->isHalfTy() && !ST->hasFullFP16()) ||
4339 ValScalarTy->isBFloatTy()) {
4340 auto *ValVTy = cast<FixedVectorType>(ValTy);
4341
4342 // Without dedicated instructions we promote [b]f16 compares to f32.
4343 auto *PromotedTy =
4344 VectorType::get(Type::getFloatTy(ValTy->getContext()), ValVTy);
4345
4346 InstructionCost Cost = 0;
4347 // Promote operands to float vectors.
4348 Cost += 2 * getCastInstrCost(Instruction::FPExt, PromotedTy, ValTy,
4349 TTI::CastContextHint::None, CostKind);
4350 // Compare float vectors.
4351 Cost += getCmpSelInstrCost(Opcode, PromotedTy, CondTy, VecPred, CostKind,
4352 Op1Info, Op2Info);
4353 // During codegen we'll truncate the vector result from i32 to i16.
4354 Cost +=
4355 getCastInstrCost(Instruction::Trunc, VectorType::getInteger(ValVTy),
4356 VectorType::getInteger(PromotedTy),
4357 TTI::CastContextHint::None, CostKind);
4358 return Cost;
4359 }
4360 }
4361
4362 // Treat the icmp in icmp(and, 0) or icmp(and, -1/1) when it can be folded to
4363 // icmp(and, 0) as free, as we can make use of ands, but only if the
4364 // comparison is not unsigned. FIXME: Enable for non-throughput cost kinds
4365 // providing it will not cause performance regressions.
4366 if (CostKind == TTI::TCK_RecipThroughput && ValTy->isIntegerTy() &&
4367 ISD == ISD::SETCC && I && !CmpInst::isUnsigned(VecPred) &&
4368 TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
4369 match(I->getOperand(0), m_And(m_Value(), m_Value()))) {
4370 if (match(I->getOperand(1), m_Zero()))
4371 return 0;
4372
4373 // x >= 1 / x < 1 -> x > 0 / x <= 0
4374 if (match(I->getOperand(1), m_One()) &&
4375 (VecPred == CmpInst::ICMP_SLT || VecPred == CmpInst::ICMP_SGE))
4376 return 0;
4377
4378 // x <= -1 / x > -1 -> x > 0 / x <= 0
4379 if (match(I->getOperand(1), m_AllOnes()) &&
4380 (VecPred == CmpInst::ICMP_SLE || VecPred == CmpInst::ICMP_SGT))
4381 return 0;
4382 }
4383
4384 // The base case handles scalable vectors fine for now, since it treats the
4385 // cost as 1 * legalization cost.
4386 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
4387 Op1Info, Op2Info, I);
4388 }
4389
4390 AArch64TTIImpl::TTI::MemCmpExpansionOptions
enableMemCmpExpansion(bool OptSize,bool IsZeroCmp) const4391 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
4392 TTI::MemCmpExpansionOptions Options;
4393 if (ST->requiresStrictAlign()) {
4394 // TODO: Add cost modeling for strict align. Misaligned loads expand to
4395 // a bunch of instructions when strict align is enabled.
4396 return Options;
4397 }
4398 Options.AllowOverlappingLoads = true;
4399 Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
4400 Options.NumLoadsPerBlock = Options.MaxNumLoads;
4401 // TODO: Though vector loads usually perform well on AArch64, in some targets
4402 // they may wake up the FP unit, which raises the power consumption. Perhaps
4403 // they could be used with no holds barred (-O3).
4404 Options.LoadSizes = {8, 4, 2, 1};
4405 Options.AllowedTailExpansions = {3, 5, 6};
4406 return Options;
4407 }
4408
// Prefer vectorized address computation only when the subtarget has SVE.
bool AArch64TTIImpl::prefersVectorizedAddressing() const {
  return ST->hasSVE();
}
4412
4413 InstructionCost
getMaskedMemoryOpCost(unsigned Opcode,Type * Src,Align Alignment,unsigned AddressSpace,TTI::TargetCostKind CostKind) const4414 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
4415 Align Alignment, unsigned AddressSpace,
4416 TTI::TargetCostKind CostKind) const {
4417 if (useNeonVector(Src))
4418 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace,
4419 CostKind);
4420 auto LT = getTypeLegalizationCost(Src);
4421 if (!LT.first.isValid())
4422 return InstructionCost::getInvalid();
4423
4424 // Return an invalid cost for element types that we are unable to lower.
4425 auto *VT = cast<VectorType>(Src);
4426 if (VT->getElementType()->isIntegerTy(1))
4427 return InstructionCost::getInvalid();
4428
4429 // The code-generator is currently not able to handle scalable vectors
4430 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
4431 // it. This change will be removed when code-generation for these types is
4432 // sufficiently reliable.
4433 if (VT->getElementCount() == ElementCount::getScalable(1))
4434 return InstructionCost::getInvalid();
4435
4436 return LT.first;
4437 }
4438
4439 // This function returns gather/scatter overhead either from
4440 // user-provided value or specialized values per-target from \p ST.
getSVEGatherScatterOverhead(unsigned Opcode,const AArch64Subtarget * ST)4441 static unsigned getSVEGatherScatterOverhead(unsigned Opcode,
4442 const AArch64Subtarget *ST) {
4443 assert((Opcode == Instruction::Load || Opcode == Instruction::Store) &&
4444 "Should be called on only load or stores.");
4445 switch (Opcode) {
4446 case Instruction::Load:
4447 if (SVEGatherOverhead.getNumOccurrences() > 0)
4448 return SVEGatherOverhead;
4449 return ST->getGatherOverhead();
4450 break;
4451 case Instruction::Store:
4452 if (SVEScatterOverhead.getNumOccurrences() > 0)
4453 return SVEScatterOverhead;
4454 return ST->getScatterOverhead();
4455 break;
4456 default:
4457 llvm_unreachable("Shouldn't have reached here");
4458 }
4459 }
4460
/// Cost of a gather (load) or scatter (store) of \p DataTy through \p Ptr.
/// NEON or illegal-masked types defer to the generic model; otherwise the
/// cost is per-element memory-op cost scaled by the gather/scatter overhead
/// and the (maximum) element count of the legalized type.
InstructionCost AArch64TTIImpl::getGatherScatterOpCost(
    unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask,
    Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) const {
  if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy))
    return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask,
                                         Alignment, CostKind, I);
  auto *VT = cast<VectorType>(DataTy);
  auto LT = getTypeLegalizationCost(DataTy);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // Return an invalid cost for element types that we are unable to lower.
  if (!LT.second.isVector() ||
      !isElementTypeLegalForScalableVector(VT->getElementType()) ||
      VT->getElementType()->isIntegerTy(1))
    return InstructionCost::getInvalid();

  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (VT->getElementCount() == ElementCount::getScalable(1))
    return InstructionCost::getInvalid();

  ElementCount LegalVF = LT.second.getVectorElementCount();
  // Base cost of a single scalar element memory access.
  InstructionCost MemOpCost =
      getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind,
                      {TTI::OK_AnyValue, TTI::OP_None}, I);
  // Add on an overhead cost for using gathers/scatters.
  MemOpCost *= getSVEGatherScatterOverhead(Opcode, ST);
  return LT.first * MemOpCost * getMaxNumElements(LegalVF);
}
4493
useNeonVector(const Type * Ty) const4494 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
4495 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
4496 }
4497
/// Cost of a plain (unmasked) load or store of type \p Ty.
/// Falls back to the generic implementation for struct types, and returns
/// an invalid cost for scalable types the code generator cannot handle.
InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
                                                Align Alignment,
                                                unsigned AddressSpace,
                                                TTI::TargetCostKind CostKind,
                                                TTI::OperandValueInfo OpInfo,
                                                const Instruction *I) const {
  EVT VT = TLI->getValueType(DL, Ty, true);
  // Type legalization can't handle structs
  if (VT == MVT::Other)
    return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
                                  CostKind);

  auto LT = getTypeLegalizationCost(Ty);
  if (!LT.first.isValid())
    return InstructionCost::getInvalid();

  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  // We also only support full register predicate loads and stores.
  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
    if (VTy->getElementCount() == ElementCount::getScalable(1) ||
        (VTy->getElementType()->isIntegerTy(1) &&
         !VTy->getElementCount().isKnownMultipleOf(
             ElementCount::getScalable(16))))
      return InstructionCost::getInvalid();

  // TODO: consider latency as well for TCK_SizeAndLatency.
  if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
    return LT.first;

  // For all remaining cost kinds other than reciprocal throughput, model the
  // access as a single instruction.
  if (CostKind != TTI::TCK_RecipThroughput)
    return 1;

  if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
      LT.second.is128BitVector() && Alignment < Align(16)) {
    // Unaligned stores are extremely inefficient. We don't split all
    // unaligned 128-bit stores because the negative impact that has shown in
    // practice on inlined block copy code.
    // We make such stores expensive so that we will only vectorize if there
    // are 6 other instructions getting vectorized.
    const int AmortizationCost = 6;

    return LT.first * 2 * AmortizationCost;
  }

  // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
  if (Ty->isPtrOrPtrVectorTy())
    return LT.first;

  if (useNeonVector(Ty)) {
    // Check truncating stores and extending loads.
    if (Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
      // v4i8 types are lowered to scalar a load/store and sshll/xtn.
      if (VT == MVT::v4i8)
        return 2;
      // Otherwise we need to scalarize.
      return cast<FixedVectorType>(Ty)->getNumElements() * 2;
    }
    EVT EltVT = VT.getVectorElementType();
    unsigned EltSize = EltVT.getScalarSizeInBits();
    // Only the non-power-of-2 special-casing below applies to small,
    // byte-aligned vectors with sane element sizes; everything else is just
    // the legalization count.
    if (!isPowerOf2_32(EltSize) || EltSize < 8 || EltSize > 64 ||
        VT.getVectorNumElements() >= (128 / EltSize) || Alignment != Align(1))
      return LT.first;
    // FIXME: v3i8 lowering currently is very inefficient, due to automatic
    // widening to v4i8, which produces suboptimal results.
    if (VT.getVectorNumElements() == 3 && EltVT == MVT::i8)
      return LT.first;

    // Check non-power-of-2 loads/stores for legal vector element types with
    // NEON. Non-power-of-2 memory ops will get broken down to a set of
    // operations on smaller power-of-2 ops, including ld1/st1.
    LLVMContext &C = Ty->getContext();
    InstructionCost Cost(0);
    SmallVector<EVT> TypeWorklist;
    TypeWorklist.push_back(VT);
    // Recursively split off the largest power-of-2 prefix until every piece
    // has a power-of-2 element count; each piece costs one op.
    while (!TypeWorklist.empty()) {
      EVT CurrVT = TypeWorklist.pop_back_val();
      unsigned CurrNumElements = CurrVT.getVectorNumElements();
      if (isPowerOf2_32(CurrNumElements)) {
        Cost += 1;
        continue;
      }

      unsigned PrevPow2 = NextPowerOf2(CurrNumElements) / 2;
      TypeWorklist.push_back(EVT::getVectorVT(C, EltVT, PrevPow2));
      TypeWorklist.push_back(
          EVT::getVectorVT(C, EltVT, CurrNumElements - PrevPow2));
    }
    return Cost;
  }

  return LT.first;
}
4593
/// Cost of an interleaved load/store group with the given \p Factor.
/// Legal groups map to ldN/stN instructions (cost = Factor * number of
/// ldN/stN ops); unsupported scalable configurations are invalid and
/// everything else defers to the generic (scalarizing) implementation.
InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
    unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
    Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
    bool UseMaskForCond, bool UseMaskForGaps) const {
  assert(Factor >= 2 && "Invalid interleave factor");
  auto *VecVTy = cast<VectorType>(VecTy);

  // Scalable interleaved accesses require SVE.
  if (VecTy->isScalableTy() && !ST->hasSVE())
    return InstructionCost::getInvalid();

  // Scalable VFs will emit vector.[de]interleave intrinsics, and currently we
  // only have lowering for power-of-2 factors.
  // TODO: Add lowering for vector.[de]interleave3 intrinsics and support in
  // InterleavedAccessPass for ld3/st3
  if (VecTy->isScalableTy() && !isPowerOf2_32(Factor))
    return InstructionCost::getInvalid();

  // Vectorization for masked interleaved accesses is only enabled for scalable
  // VF.
  if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps))
    return InstructionCost::getInvalid();

  if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) {
    unsigned MinElts = VecVTy->getElementCount().getKnownMinValue();
    auto *SubVecTy =
        VectorType::get(VecVTy->getElementType(),
                        VecVTy->getElementCount().divideCoefficientBy(Factor));

    // ldN/stN only support legal vector types of size 64 or 128 in bits.
    // Accesses having vector types that are a multiple of 128 bits can be
    // matched to more than one ldN/stN instruction.
    bool UseScalable;
    if (MinElts % Factor == 0 &&
        TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable))
      return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable);
  }

  // Not directly supported: fall back to the generic cost model.
  return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices,
                                           Alignment, AddressSpace, CostKind,
                                           UseMaskForCond, UseMaskForGaps);
}
4635
4636 InstructionCost
getCostOfKeepingLiveOverCall(ArrayRef<Type * > Tys) const4637 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const {
4638 InstructionCost Cost = 0;
4639 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
4640 for (auto *I : Tys) {
4641 if (!I->isVectorTy())
4642 continue;
4643 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() ==
4644 128)
4645 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) +
4646 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind);
4647 }
4648 return Cost;
4649 }
4650
/// Maximum interleave (unroll) factor for the loop vectorizer; taken
/// directly from the subtarget's tuning information. \p VF is unused here.
unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) const {
  return ST->getMaxInterleaveFactor();
}
4654
4655 // For Falkor, we want to avoid having too many strided loads in a loop since
4656 // that can exhaust the HW prefetcher resources. We adjust the unroller
4657 // MaxCount preference below to attempt to ensure unrolling doesn't create too
4658 // many strided loads.
4659 static void
getFalkorUnrollingPreferences(Loop * L,ScalarEvolution & SE,TargetTransformInfo::UnrollingPreferences & UP)4660 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE,
4661 TargetTransformInfo::UnrollingPreferences &UP) {
4662 enum { MaxStridedLoads = 7 };
4663 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) {
4664 int StridedLoads = 0;
4665 // FIXME? We could make this more precise by looking at the CFG and
4666 // e.g. not counting loads in each side of an if-then-else diamond.
4667 for (const auto BB : L->blocks()) {
4668 for (auto &I : *BB) {
4669 LoadInst *LMemI = dyn_cast<LoadInst>(&I);
4670 if (!LMemI)
4671 continue;
4672
4673 Value *PtrValue = LMemI->getPointerOperand();
4674 if (L->isLoopInvariant(PtrValue))
4675 continue;
4676
4677 const SCEV *LSCEV = SE.getSCEV(PtrValue);
4678 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
4679 if (!LSCEVAddRec || !LSCEVAddRec->isAffine())
4680 continue;
4681
4682 // FIXME? We could take pairing of unrolled load copies into account
4683 // by looking at the AddRec, but we would probably have to limit this
4684 // to loops with no stores or other memory optimization barriers.
4685 ++StridedLoads;
4686 // We've seen enough strided loads that seeing more won't make a
4687 // difference.
4688 if (StridedLoads > MaxStridedLoads / 2)
4689 return StridedLoads;
4690 }
4691 }
4692 return StridedLoads;
4693 };
4694
4695 int StridedLoads = countStridedLoads(L, SE);
4696 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads
4697 << " strided loads\n");
4698 // Pick the largest power of 2 unroll count that won't result in too many
4699 // strided loads.
4700 if (StridedLoads) {
4701 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads);
4702 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to "
4703 << UP.MaxCount << '\n');
4704 }
4705 }
4706
4707 // This function returns true if the loop:
4708 // 1. Has a valid cost, and
4709 // 2. Has a cost within the supplied budget.
4710 // Otherwise it returns false.
isLoopSizeWithinBudget(Loop * L,const AArch64TTIImpl & TTI,InstructionCost Budget,unsigned * FinalSize)4711 static bool isLoopSizeWithinBudget(Loop *L, const AArch64TTIImpl &TTI,
4712 InstructionCost Budget,
4713 unsigned *FinalSize) {
4714 // Estimate the size of the loop.
4715 InstructionCost LoopCost = 0;
4716
4717 for (auto *BB : L->getBlocks()) {
4718 for (auto &I : *BB) {
4719 SmallVector<const Value *, 4> Operands(I.operand_values());
4720 InstructionCost Cost =
4721 TTI.getInstructionCost(&I, Operands, TTI::TCK_CodeSize);
4722 // This can happen with intrinsics that don't currently have a cost model
4723 // or for some operations that require SVE.
4724 if (!Cost.isValid())
4725 return false;
4726
4727 LoopCost += Cost;
4728 if (LoopCost > Budget)
4729 return false;
4730 }
4731 }
4732
4733 if (FinalSize)
4734 *FinalSize = LoopCost.getValue();
4735 return true;
4736 }
4737
shouldUnrollMultiExitLoop(Loop * L,ScalarEvolution & SE,const AArch64TTIImpl & TTI)4738 static bool shouldUnrollMultiExitLoop(Loop *L, ScalarEvolution &SE,
4739 const AArch64TTIImpl &TTI) {
4740 // Only consider loops with unknown trip counts for which we can determine
4741 // a symbolic expression. Multi-exit loops with small known trip counts will
4742 // likely be unrolled anyway.
4743 const SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
4744 if (isa<SCEVConstant>(BTC) || isa<SCEVCouldNotCompute>(BTC))
4745 return false;
4746
4747 // It might not be worth unrolling loops with low max trip counts. Restrict
4748 // this to max trip counts > 32 for now.
4749 unsigned MaxTC = SE.getSmallConstantMaxTripCount(L);
4750 if (MaxTC > 0 && MaxTC <= 32)
4751 return false;
4752
4753 // Make sure the loop size is <= 5.
4754 if (!isLoopSizeWithinBudget(L, TTI, 5, nullptr))
4755 return false;
4756
4757 // Small search loops with multiple exits can be highly beneficial to unroll.
4758 // We only care about loops with exactly two exiting blocks, although each
4759 // block could jump to the same exit block.
4760 ArrayRef<BasicBlock *> Blocks = L->getBlocks();
4761 if (Blocks.size() != 2)
4762 return false;
4763
4764 if (any_of(Blocks, [](BasicBlock *BB) {
4765 return !isa<BranchInst>(BB->getTerminator());
4766 }))
4767 return false;
4768
4769 return true;
4770 }
4771
/// For Apple CPUs, we want to runtime-unroll loops to make better use if the
/// OOO engine's wide instruction window and various predictors.
static void
getAppleRuntimeUnrollPreferences(Loop *L, ScalarEvolution &SE,
                                 TargetTransformInfo::UnrollingPreferences &UP,
                                 const AArch64TTIImpl &TTI) {
  // Limit loops with structure that is highly likely to benefit from runtime
  // unrolling; that is we exclude outer loops and loops with many blocks (i.e.
  // likely with complex control flow). Note that the heuristics here may be
  // overly conservative and we err on the side of avoiding runtime unrolling
  // rather than unroll excessively. They are all subject to further refinement.
  if (!L->isInnermost() || L->getNumBlocks() > 8)
    return;

  // Loops with multiple exits are handled by common code.
  if (!L->getExitBlock())
    return;

  // Skip loops whose trip count is either constant/unknown or known to be
  // small (<= 32); runtime unrolling is unlikely to pay off there.
  const SCEV *BTC = SE.getSymbolicMaxBackedgeTakenCount(L);
  if (isa<SCEVConstant>(BTC) || isa<SCEVCouldNotCompute>(BTC) ||
      (SE.getSmallConstantMaxTripCount(L) > 0 &&
       SE.getSmallConstantMaxTripCount(L) <= 32))
    return;

  // Already-vectorized loops are not further unrolled here.
  if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
    return;

  // Only proceed when the exact backedge-taken count is computable, so the
  // runtime trip-count check expansion is precise.
  if (SE.getSymbolicMaxBackedgeTakenCount(L) != SE.getBackedgeTakenCount(L))
    return;

  // Limit to loops with trip counts that are cheap to expand.
  UP.SCEVExpansionBudget = 1;

  // Try to unroll small, single block loops, if they have load/store
  // dependencies, to expose more parallel memory access streams.
  BasicBlock *Header = L->getHeader();
  if (Header == L->getLoopLatch()) {
    // Estimate the size of the loop.
    unsigned Size;
    if (!isLoopSizeWithinBudget(L, TTI, 8, &Size))
      return;

    // Collect loop-varying loaded values and the loop's stores; a store of a
    // loaded value below indicates a load->store dependence worth unrolling.
    SmallPtrSet<Value *, 8> LoadedValues;
    SmallVector<StoreInst *> Stores;
    for (auto *BB : L->blocks()) {
      for (auto &I : *BB) {
        Value *Ptr = getLoadStorePointerOperand(&I);
        if (!Ptr)
          continue;
        const SCEV *PtrSCEV = SE.getSCEV(Ptr);
        if (SE.isLoopInvariant(PtrSCEV, L))
          continue;
        if (isa<LoadInst>(&I))
          LoadedValues.insert(&I);
        else
          Stores.push_back(cast<StoreInst>(&I));
      }
    }

    // Try to find an unroll count that maximizes the use of the instruction
    // window, i.e. trying to fetch as many instructions per cycle as possible.
    unsigned MaxInstsPerLine = 16;
    unsigned UC = 1;
    unsigned BestUC = 1;
    unsigned SizeWithBestUC = BestUC * Size;
    while (UC <= 8) {
      unsigned SizeWithUC = UC * Size;
      // Cap the total unrolled size at 48 instructions.
      if (SizeWithUC > 48)
        break;
      if ((SizeWithUC % MaxInstsPerLine) == 0 ||
          (SizeWithBestUC % MaxInstsPerLine) < (SizeWithUC % MaxInstsPerLine)) {
        BestUC = UC;
        SizeWithBestUC = BestUC * Size;
      }
      UC++;
    }

    // Bail out unless a profitable count was found and the loop stores at
    // least one of its own loaded values.
    if (BestUC == 1 || none_of(Stores, [&LoadedValues](StoreInst *SI) {
          return LoadedValues.contains(SI->getOperand(0));
        }))
      return;

    UP.Runtime = true;
    UP.DefaultUnrollRuntimeCount = BestUC;
    return;
  }

  // Try to runtime-unroll loops with early-continues depending on loop-varying
  // loads; this helps with branch-prediction for the early-continues.
  auto *Term = dyn_cast<BranchInst>(Header->getTerminator());
  auto *Latch = L->getLoopLatch();
  SmallVector<BasicBlock *> Preds(predecessors(Latch));
  if (!Term || !Term->isConditional() || Preds.size() == 1 ||
      !llvm::is_contained(Preds, Header) ||
      none_of(Preds, [L](BasicBlock *Pred) { return L->contains(Pred); }))
    return;

  // Walk back through the operands of I (bounded to depth 8), looking for a
  // load inside the loop that the branch condition transitively depends on.
  std::function<bool(Instruction *, unsigned)> DependsOnLoopLoad =
      [&](Instruction *I, unsigned Depth) -> bool {
    if (isa<PHINode>(I) || L->isLoopInvariant(I) || Depth > 8)
      return false;

    if (isa<LoadInst>(I))
      return true;

    return any_of(I->operands(), [&](Value *V) {
      auto *I = dyn_cast<Instruction>(V);
      return I && DependsOnLoopLoad(I, Depth + 1);
    });
  };
  CmpPredicate Pred;
  Instruction *I;
  if (match(Term, m_Br(m_ICmp(Pred, m_Instruction(I), m_Value()), m_Value(),
                       m_Value())) &&
      DependsOnLoopLoad(I, 0)) {
    UP.Runtime = true;
  }
}
4890
/// Configure loop-unrolling preferences for AArch64: common defaults first,
/// then per-CPU-family adjustments (Apple, Falkor), multi-exit search-loop
/// unrolling, and finally runtime unrolling for in-order cores.
void AArch64TTIImpl::getUnrollingPreferences(
    Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP,
    OptimizationRemarkEmitter *ORE) const {
  // Enable partial unrolling and runtime unrolling.
  BaseT::getUnrollingPreferences(L, SE, UP, ORE);

  UP.UpperBound = true;

  // For inner loop, it is more likely to be a hot one, and the runtime check
  // can be promoted out from LICM pass, so the overhead is less, let's try
  // a larger threshold to unroll more loops.
  if (L->getLoopDepth() > 1)
    UP.PartialThreshold *= 2;

  // Disable partial & runtime unrolling on -Os.
  UP.PartialOptSizeThreshold = 0;

  // No need to unroll auto-vectorized loops
  if (findStringMetadataForLoop(L, "llvm.loop.isvectorized"))
    return;

  // Scan the loop: don't unroll loops with calls as this could prevent
  // inlining.
  for (auto *BB : L->getBlocks()) {
    for (auto &I : *BB) {
      if (isa<CallBase>(I)) {
        // Calls that are expanded inline (not lowered to a real call) don't
        // block inlining, so they don't prevent unrolling either.
        if (isa<CallInst>(I) || isa<InvokeInst>(I))
          if (const Function *F = cast<CallBase>(I).getCalledFunction())
            if (!isLoweredToCall(F))
              continue;
        return;
      }
    }
  }

  // Apply subtarget-specific unrolling preferences.
  switch (ST->getProcFamily()) {
  case AArch64Subtarget::AppleA14:
  case AArch64Subtarget::AppleA15:
  case AArch64Subtarget::AppleA16:
  case AArch64Subtarget::AppleM4:
    getAppleRuntimeUnrollPreferences(L, SE, UP, *this);
    break;
  case AArch64Subtarget::Falkor:
    if (EnableFalkorHWPFUnrollFix)
      getFalkorUnrollingPreferences(L, SE, UP);
    break;
  default:
    break;
  }

  // If this is a small, multi-exit loop similar to something like std::find,
  // then there is typically a performance improvement achieved by unrolling.
  if (!L->getExitBlock() && shouldUnrollMultiExitLoop(L, SE, *this)) {
    UP.RuntimeUnrollMultiExit = true;
    UP.Runtime = true;
    // Limit unroll count.
    UP.DefaultUnrollRuntimeCount = 4;
    // Allow slightly more costly trip-count expansion to catch search loops
    // with pointer inductions.
    UP.SCEVExpansionBudget = 5;
    return;
  }

  // Enable runtime unrolling for in-order models
  // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by
  // checking for that case, we can ensure that the default behaviour is
  // unchanged
  if (ST->getProcFamily() != AArch64Subtarget::Generic &&
      !ST->getSchedModel().isOutOfOrder()) {
    UP.Runtime = true;
    UP.Partial = true;
    UP.UnrollRemainder = true;
    UP.DefaultUnrollRuntimeCount = 4;

    UP.UnrollAndJam = true;
    UP.UnrollAndJamInnerLoopThreshold = 60;
  }
}
4970
/// AArch64 applies no target-specific loop-peeling adjustments; use the
/// target-independent defaults.
void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE,
                                           TTI::PeelingPreferences &PP) const {
  BaseT::getPeelingPreferences(L, SE, PP);
}
4975
getOrCreateResultFromMemIntrinsic(IntrinsicInst * Inst,Type * ExpectedType,bool CanCreate) const4976 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst,
4977 Type *ExpectedType,
4978 bool CanCreate) const {
4979 switch (Inst->getIntrinsicID()) {
4980 default:
4981 return nullptr;
4982 case Intrinsic::aarch64_neon_st2:
4983 case Intrinsic::aarch64_neon_st3:
4984 case Intrinsic::aarch64_neon_st4: {
4985 // Create a struct type
4986 StructType *ST = dyn_cast<StructType>(ExpectedType);
4987 if (!CanCreate || !ST)
4988 return nullptr;
4989 unsigned NumElts = Inst->arg_size() - 1;
4990 if (ST->getNumElements() != NumElts)
4991 return nullptr;
4992 for (unsigned i = 0, e = NumElts; i != e; ++i) {
4993 if (Inst->getArgOperand(i)->getType() != ST->getElementType(i))
4994 return nullptr;
4995 }
4996 Value *Res = PoisonValue::get(ExpectedType);
4997 IRBuilder<> Builder(Inst);
4998 for (unsigned i = 0, e = NumElts; i != e; ++i) {
4999 Value *L = Inst->getArgOperand(i);
5000 Res = Builder.CreateInsertValue(Res, L, i);
5001 }
5002 return Res;
5003 }
5004 case Intrinsic::aarch64_neon_ld2:
5005 case Intrinsic::aarch64_neon_ld3:
5006 case Intrinsic::aarch64_neon_ld4:
5007 if (Inst->getType() == ExpectedType)
5008 return Inst;
5009 return nullptr;
5010 }
5011 }
5012
getTgtMemIntrinsic(IntrinsicInst * Inst,MemIntrinsicInfo & Info) const5013 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst,
5014 MemIntrinsicInfo &Info) const {
5015 switch (Inst->getIntrinsicID()) {
5016 default:
5017 break;
5018 case Intrinsic::aarch64_neon_ld2:
5019 case Intrinsic::aarch64_neon_ld3:
5020 case Intrinsic::aarch64_neon_ld4:
5021 Info.ReadMem = true;
5022 Info.WriteMem = false;
5023 Info.PtrVal = Inst->getArgOperand(0);
5024 break;
5025 case Intrinsic::aarch64_neon_st2:
5026 case Intrinsic::aarch64_neon_st3:
5027 case Intrinsic::aarch64_neon_st4:
5028 Info.ReadMem = false;
5029 Info.WriteMem = true;
5030 Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
5031 break;
5032 }
5033
5034 switch (Inst->getIntrinsicID()) {
5035 default:
5036 return false;
5037 case Intrinsic::aarch64_neon_ld2:
5038 case Intrinsic::aarch64_neon_st2:
5039 Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
5040 break;
5041 case Intrinsic::aarch64_neon_ld3:
5042 case Intrinsic::aarch64_neon_st3:
5043 Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
5044 break;
5045 case Intrinsic::aarch64_neon_ld4:
5046 case Intrinsic::aarch64_neon_st4:
5047 Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
5048 break;
5049 }
5050 return true;
5051 }
5052
5053 /// See if \p I should be considered for address type promotion. We check if \p
5054 /// I is a sext with right type and used in memory accesses. If it used in a
5055 /// "complex" getelementptr, we allow it to be promoted without finding other
5056 /// sext instructions that sign extended the same initial value. A getelementptr
5057 /// is considered as "complex" if it has more than 2 operands.
shouldConsiderAddressTypePromotion(const Instruction & I,bool & AllowPromotionWithoutCommonHeader) const5058 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
5059 const Instruction &I, bool &AllowPromotionWithoutCommonHeader) const {
5060 bool Considerable = false;
5061 AllowPromotionWithoutCommonHeader = false;
5062 if (!isa<SExtInst>(&I))
5063 return false;
5064 Type *ConsideredSExtType =
5065 Type::getInt64Ty(I.getParent()->getParent()->getContext());
5066 if (I.getType() != ConsideredSExtType)
5067 return false;
5068 // See if the sext is the one with the right type and used in at least one
5069 // GetElementPtrInst.
5070 for (const User *U : I.users()) {
5071 if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
5072 Considerable = true;
5073 // A getelementptr is considered as "complex" if it has more than 2
5074 // operands. We will promote a SExt used in such complex GEP as we
5075 // expect some computation to be merged if they are done on 64 bits.
5076 if (GEPInst->getNumOperands() > 2) {
5077 AllowPromotionWithoutCommonHeader = true;
5078 break;
5079 }
5080 }
5081 }
5082 return Considerable;
5083 }
5084
isLegalToVectorizeReduction(const RecurrenceDescriptor & RdxDesc,ElementCount VF) const5085 bool AArch64TTIImpl::isLegalToVectorizeReduction(
5086 const RecurrenceDescriptor &RdxDesc, ElementCount VF) const {
5087 if (!VF.isScalable())
5088 return true;
5089
5090 Type *Ty = RdxDesc.getRecurrenceType();
5091 if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty))
5092 return false;
5093
5094 switch (RdxDesc.getRecurrenceKind()) {
5095 case RecurKind::Add:
5096 case RecurKind::FAdd:
5097 case RecurKind::And:
5098 case RecurKind::Or:
5099 case RecurKind::Xor:
5100 case RecurKind::SMin:
5101 case RecurKind::SMax:
5102 case RecurKind::UMin:
5103 case RecurKind::UMax:
5104 case RecurKind::FMin:
5105 case RecurKind::FMax:
5106 case RecurKind::FMulAdd:
5107 case RecurKind::AnyOf:
5108 return true;
5109 default:
5110 return false;
5111 }
5112 }
5113
/// Cost of a min/max horizontal reduction. The final reduction is modelled
/// as a fixed cost of 2, plus one pairwise min/max per extra legalized part.
InstructionCost
AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
                                       FastMathFlags FMF,
                                       TTI::TargetCostKind CostKind) const {
  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
    if (VTy->getElementCount() == ElementCount::getScalable(1))
      return InstructionCost::getInvalid();

  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);

  // Without full fp16 support, f16 min/max reductions fall back to the
  // generic cost model.
  if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16())
    return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind);

  // When the type splits into multiple legal registers, each extra part
  // needs one pairwise min/max intrinsic to merge it before the reduction.
  InstructionCost LegalizationCost = 0;
  if (LT.first > 1) {
    Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext());
    IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF);
    LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1);
  }

  return LegalizationCost + /*Cost of horizontal reduction*/ 2;
}
5140
getArithmeticReductionCostSVE(unsigned Opcode,VectorType * ValTy,TTI::TargetCostKind CostKind) const5141 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE(
5142 unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) const {
5143 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
5144 InstructionCost LegalizationCost = 0;
5145 if (LT.first > 1) {
5146 Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext());
5147 LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind);
5148 LegalizationCost *= LT.first - 1;
5149 }
5150
5151 int ISD = TLI->InstructionOpcodeToISD(Opcode);
5152 assert(ISD && "Invalid opcode");
5153 // Add the final reduction cost for the legal horizontal reduction
5154 switch (ISD) {
5155 case ISD::ADD:
5156 case ISD::AND:
5157 case ISD::OR:
5158 case ISD::XOR:
5159 case ISD::FADD:
5160 return LegalizationCost + 2;
5161 default:
5162 return InstructionCost::getInvalid();
5163 }
5164 }
5165
/// Cost of an arithmetic horizontal reduction (add/and/or/xor/fadd, etc.).
/// Dispatches to the SVE model for scalable types, uses a per-type cost
/// table for NEON add/and/or/xor, models fadd via faddp chains, and falls
/// back to the generic implementation otherwise.
InstructionCost
AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy,
                                           std::optional<FastMathFlags> FMF,
                                           TTI::TargetCostKind CostKind) const {
  // The code-generator is currently not able to handle scalable vectors
  // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
  // it. This change will be removed when code-generation for these types is
  // sufficiently reliable.
  if (auto *VTy = dyn_cast<ScalableVectorType>(ValTy))
    if (VTy->getElementCount() == ElementCount::getScalable(1))
      return InstructionCost::getInvalid();

  if (TTI::requiresOrderedReduction(FMF)) {
    if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) {
      InstructionCost BaseCost =
          BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
      // Add on extra cost to reflect the extra overhead on some CPUs. We still
      // end up vectorizing for more computationally intensive loops.
      return BaseCost + FixedVTy->getNumElements();
    }

    // Ordered scalable reductions are only supported for fadd.
    if (Opcode != Instruction::FAdd)
      return InstructionCost::getInvalid();

    // An ordered scalable fadd reduction is one scalar fadd per (maximum
    // possible) element.
    auto *VTy = cast<ScalableVectorType>(ValTy);
    InstructionCost Cost =
        getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind);
    Cost *= getMaxNumElements(VTy->getElementCount());
    return Cost;
  }

  if (isa<ScalableVectorType>(ValTy))
    return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind);

  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy);
  MVT MTy = LT.second;
  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  assert(ISD && "Invalid opcode");

  // Horizontal adds can use the 'addv' instruction. We model the cost of these
  // instructions as twice a normal vector add, plus 1 for each legalization
  // step (LT.first). This is the only arithmetic vector reduction operation for
  // which we have an instruction.
  // OR, XOR and AND costs should match the codegen from:
  // OR: llvm/test/CodeGen/AArch64/reduce-or.ll
  // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll
  // AND: llvm/test/CodeGen/AArch64/reduce-and.ll
  static const CostTblEntry CostTblNoPairwise[]{
      {ISD::ADD, MVT::v8i8, 2},
      {ISD::ADD, MVT::v16i8, 2},
      {ISD::ADD, MVT::v4i16, 2},
      {ISD::ADD, MVT::v8i16, 2},
      {ISD::ADD, MVT::v2i32, 2},
      {ISD::ADD, MVT::v4i32, 2},
      {ISD::ADD, MVT::v2i64, 2},
      {ISD::OR, MVT::v8i8, 15},
      {ISD::OR, MVT::v16i8, 17},
      {ISD::OR, MVT::v4i16, 7},
      {ISD::OR, MVT::v8i16, 9},
      {ISD::OR, MVT::v2i32, 3},
      {ISD::OR, MVT::v4i32, 5},
      {ISD::OR, MVT::v2i64, 3},
      {ISD::XOR, MVT::v8i8, 15},
      {ISD::XOR, MVT::v16i8, 17},
      {ISD::XOR, MVT::v4i16, 7},
      {ISD::XOR, MVT::v8i16, 9},
      {ISD::XOR, MVT::v2i32, 3},
      {ISD::XOR, MVT::v4i32, 5},
      {ISD::XOR, MVT::v2i64, 3},
      {ISD::AND, MVT::v8i8, 15},
      {ISD::AND, MVT::v16i8, 17},
      {ISD::AND, MVT::v4i16, 7},
      {ISD::AND, MVT::v8i16, 9},
      {ISD::AND, MVT::v2i32, 3},
      {ISD::AND, MVT::v4i32, 5},
      {ISD::AND, MVT::v2i64, 3},
  };
  switch (ISD) {
  default:
    break;
  case ISD::FADD:
    if (Type *EltTy = ValTy->getScalarType();
        // FIXME: For half types without fullfp16 support, this could extend and
        // use a fp32 faddp reduction but current codegen unrolls.
        MTy.isVector() && (EltTy->isFloatTy() || EltTy->isDoubleTy() ||
                           (EltTy->isHalfTy() && ST->hasFullFP16()))) {
      const unsigned NElts = MTy.getVectorNumElements();
      if (ValTy->getElementCount().getFixedValue() >= 2 && NElts >= 2 &&
          isPowerOf2_32(NElts))
        // Reduction corresponding to series of fadd instructions is lowered to
        // series of faddp instructions. faddp has latency/throughput that
        // matches fadd instruction and hence, every faddp instruction can be
        // considered to have a relative cost = 1 with
        // CostKind = TCK_RecipThroughput.
        // An faddp will pairwise add vector elements, so the size of input
        // vector reduces by half every time, requiring
        // #(faddp instructions) = log2_32(NElts).
        return (LT.first - 1) + /*No of faddp instructions*/ Log2_32(NElts);
    }
    break;
  case ISD::ADD:
    if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy))
      return (LT.first - 1) + Entry->Cost;
    break;
  case ISD::XOR:
  case ISD::AND:
  case ISD::OR:
    const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy);
    if (!Entry)
      break;
    auto *ValVTy = cast<FixedVectorType>(ValTy);
    if (MTy.getVectorNumElements() <= ValVTy->getNumElements() &&
        isPowerOf2_32(ValVTy->getNumElements())) {
      InstructionCost ExtraCost = 0;
      if (LT.first != 1) {
        // Type needs to be split, so there is an extra cost of LT.first - 1
        // arithmetic ops.
        auto *Ty = FixedVectorType::get(ValTy->getElementType(),
                                        MTy.getVectorNumElements());
        ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind);
        ExtraCost *= LT.first - 1;
      }
      // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov
      auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 2 : Entry->Cost;
      return Cost + ExtraCost;
    }
    break;
  }
  return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind);
}
5296
getExtendedReductionCost(unsigned Opcode,bool IsUnsigned,Type * ResTy,VectorType * VecTy,std::optional<FastMathFlags> FMF,TTI::TargetCostKind CostKind) const5297 InstructionCost AArch64TTIImpl::getExtendedReductionCost(
5298 unsigned Opcode, bool IsUnsigned, Type *ResTy, VectorType *VecTy,
5299 std::optional<FastMathFlags> FMF, TTI::TargetCostKind CostKind) const {
5300 EVT VecVT = TLI->getValueType(DL, VecTy);
5301 EVT ResVT = TLI->getValueType(DL, ResTy);
5302
5303 if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
5304 VecVT.getSizeInBits() >= 64) {
5305 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
5306
5307 // The legal cases are:
5308 // UADDLV 8/16/32->32
5309 // UADDLP 32->64
5310 unsigned RevVTSize = ResVT.getSizeInBits();
5311 if (((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
5312 RevVTSize <= 32) ||
5313 ((LT.second == MVT::v4i16 || LT.second == MVT::v8i16) &&
5314 RevVTSize <= 32) ||
5315 ((LT.second == MVT::v2i32 || LT.second == MVT::v4i32) &&
5316 RevVTSize <= 64))
5317 return (LT.first - 1) * 2 + 2;
5318 }
5319
5320 return BaseT::getExtendedReductionCost(Opcode, IsUnsigned, ResTy, VecTy, FMF,
5321 CostKind);
5322 }
5323
5324 InstructionCost
getMulAccReductionCost(bool IsUnsigned,Type * ResTy,VectorType * VecTy,TTI::TargetCostKind CostKind) const5325 AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
5326 VectorType *VecTy,
5327 TTI::TargetCostKind CostKind) const {
5328 EVT VecVT = TLI->getValueType(DL, VecTy);
5329 EVT ResVT = TLI->getValueType(DL, ResTy);
5330
5331 if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
5332 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
5333
5334 // The legal cases with dotprod are
5335 // UDOT 8->32
5336 // Which requires an additional uaddv to sum the i32 values.
5337 if ((LT.second == MVT::v8i8 || LT.second == MVT::v16i8) &&
5338 ResVT == MVT::i32)
5339 return LT.first + 2;
5340 }
5341
5342 return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
5343 }
5344
5345 InstructionCost
getSpliceCost(VectorType * Tp,int Index,TTI::TargetCostKind CostKind) const5346 AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
5347 TTI::TargetCostKind CostKind) const {
5348 static const CostTblEntry ShuffleTbl[] = {
5349 { TTI::SK_Splice, MVT::nxv16i8, 1 },
5350 { TTI::SK_Splice, MVT::nxv8i16, 1 },
5351 { TTI::SK_Splice, MVT::nxv4i32, 1 },
5352 { TTI::SK_Splice, MVT::nxv2i64, 1 },
5353 { TTI::SK_Splice, MVT::nxv2f16, 1 },
5354 { TTI::SK_Splice, MVT::nxv4f16, 1 },
5355 { TTI::SK_Splice, MVT::nxv8f16, 1 },
5356 { TTI::SK_Splice, MVT::nxv2bf16, 1 },
5357 { TTI::SK_Splice, MVT::nxv4bf16, 1 },
5358 { TTI::SK_Splice, MVT::nxv8bf16, 1 },
5359 { TTI::SK_Splice, MVT::nxv2f32, 1 },
5360 { TTI::SK_Splice, MVT::nxv4f32, 1 },
5361 { TTI::SK_Splice, MVT::nxv2f64, 1 },
5362 };
5363
5364 // The code-generator is currently not able to handle scalable vectors
5365 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
5366 // it. This change will be removed when code-generation for these types is
5367 // sufficiently reliable.
5368 if (Tp->getElementCount() == ElementCount::getScalable(1))
5369 return InstructionCost::getInvalid();
5370
5371 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
5372 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext());
5373 EVT PromotedVT = LT.second.getScalarType() == MVT::i1
5374 ? TLI->getPromotedVTForPredicate(EVT(LT.second))
5375 : LT.second;
5376 Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext());
5377 InstructionCost LegalizationCost = 0;
5378 if (Index < 0) {
5379 LegalizationCost =
5380 getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy,
5381 CmpInst::BAD_ICMP_PREDICATE, CostKind) +
5382 getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy,
5383 CmpInst::BAD_ICMP_PREDICATE, CostKind);
5384 }
5385
5386 // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp
5387 // Cost performed on a promoted type.
5388 if (LT.second.getScalarType() == MVT::i1) {
5389 LegalizationCost +=
5390 getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy,
5391 TTI::CastContextHint::None, CostKind) +
5392 getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy,
5393 TTI::CastContextHint::None, CostKind);
5394 }
5395 const auto *Entry =
5396 CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT());
5397 assert(Entry && "Illegal Type for Splice");
5398 LegalizationCost += Entry->Cost;
5399 return LegalizationCost * LT.first;
5400 }
5401
getPartialReductionCost(unsigned Opcode,Type * InputTypeA,Type * InputTypeB,Type * AccumType,ElementCount VF,TTI::PartialReductionExtendKind OpAExtend,TTI::PartialReductionExtendKind OpBExtend,std::optional<unsigned> BinOp,TTI::TargetCostKind CostKind) const5402 InstructionCost AArch64TTIImpl::getPartialReductionCost(
5403 unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
5404 ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
5405 TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
5406 TTI::TargetCostKind CostKind) const {
5407 InstructionCost Invalid = InstructionCost::getInvalid();
5408 InstructionCost Cost(TTI::TCC_Basic);
5409
5410 if (CostKind != TTI::TCK_RecipThroughput)
5411 return Invalid;
5412
5413 // Sub opcodes currently only occur in chained cases.
5414 // Independent partial reduction subtractions are still costed as an add
5415 if ((Opcode != Instruction::Add && Opcode != Instruction::Sub) ||
5416 OpAExtend == TTI::PR_None)
5417 return Invalid;
5418
5419 // We only support multiply binary operations for now, and for muls we
5420 // require the types being extended to be the same.
5421 // NOTE: For muls AArch64 supports lowering mixed extensions to a usdot but
5422 // only if the i8mm or sve/streaming features are available.
5423 if (BinOp && (*BinOp != Instruction::Mul || InputTypeA != InputTypeB ||
5424 OpBExtend == TTI::PR_None ||
5425 (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
5426 !ST->isSVEorStreamingSVEAvailable())))
5427 return Invalid;
5428 assert((BinOp || (OpBExtend == TTI::PR_None && !InputTypeB)) &&
5429 "Unexpected values for OpBExtend or InputTypeB");
5430
5431 EVT InputEVT = EVT::getEVT(InputTypeA);
5432 EVT AccumEVT = EVT::getEVT(AccumType);
5433
5434 unsigned VFMinValue = VF.getKnownMinValue();
5435
5436 if (VF.isScalable()) {
5437 if (!ST->isSVEorStreamingSVEAvailable())
5438 return Invalid;
5439
5440 // Don't accept a partial reduction if the scaled accumulator is vscale x 1,
5441 // since we can't lower that type.
5442 unsigned Scale =
5443 AccumEVT.getScalarSizeInBits() / InputEVT.getScalarSizeInBits();
5444 if (VFMinValue == Scale)
5445 return Invalid;
5446 }
5447 if (VF.isFixed() &&
5448 (!ST->isNeonAvailable() || !ST->hasDotProd() || AccumEVT == MVT::i64))
5449 return Invalid;
5450
5451 if (InputEVT == MVT::i8) {
5452 switch (VFMinValue) {
5453 default:
5454 return Invalid;
5455 case 8:
5456 if (AccumEVT == MVT::i32)
5457 Cost *= 2;
5458 else if (AccumEVT != MVT::i64)
5459 return Invalid;
5460 break;
5461 case 16:
5462 if (AccumEVT == MVT::i64)
5463 Cost *= 2;
5464 else if (AccumEVT != MVT::i32)
5465 return Invalid;
5466 break;
5467 }
5468 } else if (InputEVT == MVT::i16) {
5469 // FIXME: Allow i32 accumulator but increase cost, as we would extend
5470 // it to i64.
5471 if (VFMinValue != 8 || AccumEVT != MVT::i64)
5472 return Invalid;
5473 } else
5474 return Invalid;
5475
5476 return Cost;
5477 }
5478
/// Compute the cost of a shuffle of kind \p Kind from \p SrcTy to \p DstTy,
/// using \p Mask if provided. Tries several AArch64-specific patterns
/// (splitting illegal masks, LD3/LD4 and ST3/ST4 shuffles, SVE segmented
/// duplicates, perfect-shuffle tables, zip/uzp/rev masks, per-type cost
/// tables) before deferring to the generic implementation.
InstructionCost
AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy,
                               VectorType *SrcTy, ArrayRef<int> Mask,
                               TTI::TargetCostKind CostKind, int Index,
                               VectorType *SubTp, ArrayRef<const Value *> Args,
                               const Instruction *CxtI) const {
  assert((Mask.empty() || DstTy->isScalableTy() ||
          Mask.size() == DstTy->getElementCount().getKnownMinValue()) &&
         "Expected the Mask to match the return size if given");
  assert(SrcTy->getScalarType() == DstTy->getScalarType() &&
         "Expected the same scalar types");
  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(SrcTy);

  // If we have a Mask, and the LT is being legalized somehow, split the Mask
  // into smaller vectors and sum the cost of each shuffle.
  if (!Mask.empty() && isa<FixedVectorType>(SrcTy) && LT.second.isVector() &&
      LT.second.getScalarSizeInBits() * Mask.size() > 128 &&
      SrcTy->getScalarSizeInBits() == LT.second.getScalarSizeInBits() &&
      Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) {
    // Check for LD3/LD4 instructions, which are represented in llvm IR as
    // deinterleaving-shuffle(load). The shuffle cost could potentially be free,
    // but we model it with a cost of LT.first so that LD3/LD4 have a higher
    // cost than just the load.
    if (Args.size() >= 1 && isa<LoadInst>(Args[0]) &&
        (ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 3) ||
         ShuffleVectorInst::isDeInterleaveMaskOfFactor(Mask, 4)))
      return std::max<InstructionCost>(1, LT.first / 4);

    // Check for ST3/ST4 instructions, which are represented in llvm IR as
    // store(interleaving-shuffle). The shuffle cost could potentially be free,
    // but we model it with a cost of LT.first so that ST3/ST4 have a higher
    // cost than just the store.
    if (CxtI && CxtI->hasOneUse() && isa<StoreInst>(*CxtI->user_begin()) &&
        (ShuffleVectorInst::isInterleaveMask(
             Mask, 4, SrcTy->getElementCount().getKnownMinValue() * 2) ||
         ShuffleVectorInst::isInterleaveMask(
             Mask, 3, SrcTy->getElementCount().getKnownMinValue() * 2)))
      return LT.first;

    // Otherwise split the mask into NumVecs chunks of LTNumElts elements each
    // and sum the per-chunk shuffle costs.
    unsigned TpNumElts = Mask.size();
    unsigned LTNumElts = LT.second.getVectorNumElements();
    unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts;
    VectorType *NTp = VectorType::get(SrcTy->getScalarType(),
                                      LT.second.getVectorElementCount());
    InstructionCost Cost;
    // Cache of already-costed (Source1, Source2, sub-mask) chunks.
    std::map<std::tuple<unsigned, unsigned, SmallVector<int>>, InstructionCost>
        PreviousCosts;
    for (unsigned N = 0; N < NumVecs; N++) {
      SmallVector<int> NMask;
      // Split the existing mask into chunks of size LTNumElts. Track the source
      // sub-vectors to ensure the result has at most 2 inputs.
      unsigned Source1 = -1U, Source2 = -1U;
      unsigned NumSources = 0;
      for (unsigned E = 0; E < LTNumElts; E++) {
        int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E]
                                                      : PoisonMaskElem;
        if (MaskElt < 0) {
          NMask.push_back(PoisonMaskElem);
          continue;
        }

        // Calculate which source from the input this comes from and whether it
        // is new to us.
        unsigned Source = MaskElt / LTNumElts;
        if (NumSources == 0) {
          Source1 = Source;
          NumSources = 1;
        } else if (NumSources == 1 && Source != Source1) {
          Source2 = Source;
          NumSources = 2;
        } else if (NumSources >= 2 && Source != Source1 && Source != Source2) {
          NumSources++;
        }

        // Add to the new mask. For the NumSources>2 case these are not correct,
        // but are only used for the modular lane number.
        if (Source == Source1)
          NMask.push_back(MaskElt % LTNumElts);
        else if (Source == Source2)
          NMask.push_back(MaskElt % LTNumElts + LTNumElts);
        else
          NMask.push_back(MaskElt % LTNumElts);
      }
      // Check if we have already generated this sub-shuffle, which means we
      // will have already generated the output. For example a <16 x i32> splat
      // will be the same sub-splat 4 times, which only needs to be generated
      // once and reused.
      auto Result =
          PreviousCosts.insert({std::make_tuple(Source1, Source2, NMask), 0});
      // Check if it was already in the map (already costed).
      if (!Result.second)
        continue;
      // If the sub-mask has at most 2 input sub-vectors then re-cost it using
      // getShuffleCost. If not then cost it using the worst case as the number
      // of element moves into a new vector.
      InstructionCost NCost =
          NumSources <= 2
              ? getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc
                                               : TTI::SK_PermuteTwoSrc,
                               NTp, NTp, NMask, CostKind, 0, nullptr, Args,
                               CxtI)
              : LTNumElts;
      Result.first->second = NCost;
      Cost += NCost;
    }
    return Cost;
  }

  // Try to refine the shuffle kind from the mask before table lookups below.
  Kind = improveShuffleKindFromMask(Kind, Mask, SrcTy, Index, SubTp);
  bool IsExtractSubvector = Kind == TTI::SK_ExtractSubvector;
  // A subvector extract can be implemented with an ext (or trivial extract, if
  // from lane 0). This currently only handles low or high extracts to prevent
  // SLP vectorizer regressions.
  if (IsExtractSubvector && LT.second.isFixedLengthVector()) {
    if (LT.second.is128BitVector() &&
        cast<FixedVectorType>(SubTp)->getNumElements() ==
            LT.second.getVectorNumElements() / 2) {
      if (Index == 0)
        return 0;
      if (Index == (int)LT.second.getVectorNumElements() / 2)
        return 1;
    }
    Kind = TTI::SK_PermuteSingleSrc;
  }
  // FIXME: This was added to keep the costs equal when adding DstTys. Update
  // the code to handle length-changing shuffles.
  if (Kind == TTI::SK_InsertSubvector) {
    LT = getTypeLegalizationCost(DstTy);
    SrcTy = DstTy;
  }

  // Segmented shuffle matching.
  if (Kind == TTI::SK_PermuteSingleSrc && isa<FixedVectorType>(SrcTy) &&
      !Mask.empty() && SrcTy->getPrimitiveSizeInBits().isNonZero() &&
      SrcTy->getPrimitiveSizeInBits().isKnownMultipleOf(
          AArch64::SVEBitsPerBlock)) {

    FixedVectorType *VTy = cast<FixedVectorType>(SrcTy);
    unsigned Segments =
        VTy->getPrimitiveSizeInBits() / AArch64::SVEBitsPerBlock;
    unsigned SegmentElts = VTy->getNumElements() / Segments;

    // dupq zd.t, zn.t[idx]
    if ((ST->hasSVE2p1() || ST->hasSME2p1()) &&
        ST->isSVEorStreamingSVEAvailable() &&
        isDUPQMask(Mask, Segments, SegmentElts))
      return LT.first;

    // mov zd.q, vn
    if (ST->isSVEorStreamingSVEAvailable() &&
        isDUPFirstSegmentMask(Mask, Segments, SegmentElts))
      return LT.first;
  }

  // Check for broadcast loads, which are supported by the LD1R instruction.
  // In terms of code-size, the shuffle vector is free when a load + dup get
  // folded into a LD1R. That's what we check and return here. For performance
  // and reciprocal throughput, a LD1R is not completely free. In this case, we
  // return the cost for the broadcast below (i.e. 1 for most/all types), so
  // that we model the load + dup sequence slightly higher because LD1R is a
  // high latency instruction.
  if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) {
    bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]);
    if (IsLoad && LT.second.isVector() &&
        isLegalBroadcastLoad(SrcTy->getElementType(),
                             LT.second.getVectorElementCount()))
      return 0;
  }

  // If we have 4 elements for the shuffle and a Mask, get the cost straight
  // from the perfect shuffle tables.
  if (Mask.size() == 4 &&
      SrcTy->getElementCount() == ElementCount::getFixed(4) &&
      (SrcTy->getScalarSizeInBits() == 16 ||
       SrcTy->getScalarSizeInBits() == 32) &&
      all_of(Mask, [](int E) { return E < 8; }))
    return getPerfectShuffleCost(Mask);

  // Check for identity masks, which we can treat as free.
  if (!Mask.empty() && LT.second.isFixedLengthVector() &&
      (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
      all_of(enumerate(Mask), [](const auto &M) {
        return M.value() < 0 || M.value() == (int)M.index();
      }))
    return 0;

  // Check for other shuffles that are not SK_ kinds but we have native
  // instructions for, for example ZIP and UZP.
  unsigned Unused;
  if (LT.second.isFixedLengthVector() &&
      LT.second.getVectorNumElements() == Mask.size() &&
      (Kind == TTI::SK_PermuteTwoSrc || Kind == TTI::SK_PermuteSingleSrc) &&
      (isZIPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
       isUZPMask(Mask, LT.second.getVectorNumElements(), Unused) ||
       isREVMask(Mask, LT.second.getScalarSizeInBits(),
                 LT.second.getVectorNumElements(), 16) ||
       isREVMask(Mask, LT.second.getScalarSizeInBits(),
                 LT.second.getVectorNumElements(), 32) ||
       isREVMask(Mask, LT.second.getScalarSizeInBits(),
                 LT.second.getVectorNumElements(), 64) ||
       // Check for non-zero lane splats
       all_of(drop_begin(Mask),
              [&Mask](int M) { return M < 0 || M == Mask[0]; })))
    return 1;

  // Per-kind, per-type cost table for the shuffle kinds handled natively.
  if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose ||
      Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc ||
      Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) {
    static const CostTblEntry ShuffleTbl[] = {
        // Broadcast shuffle kinds can be performed with 'dup'.
        {TTI::SK_Broadcast, MVT::v8i8, 1},
        {TTI::SK_Broadcast, MVT::v16i8, 1},
        {TTI::SK_Broadcast, MVT::v4i16, 1},
        {TTI::SK_Broadcast, MVT::v8i16, 1},
        {TTI::SK_Broadcast, MVT::v2i32, 1},
        {TTI::SK_Broadcast, MVT::v4i32, 1},
        {TTI::SK_Broadcast, MVT::v2i64, 1},
        {TTI::SK_Broadcast, MVT::v4f16, 1},
        {TTI::SK_Broadcast, MVT::v8f16, 1},
        {TTI::SK_Broadcast, MVT::v4bf16, 1},
        {TTI::SK_Broadcast, MVT::v8bf16, 1},
        {TTI::SK_Broadcast, MVT::v2f32, 1},
        {TTI::SK_Broadcast, MVT::v4f32, 1},
        {TTI::SK_Broadcast, MVT::v2f64, 1},
        // Transpose shuffle kinds can be performed with 'trn1/trn2' and
        // 'zip1/zip2' instructions.
        {TTI::SK_Transpose, MVT::v8i8, 1},
        {TTI::SK_Transpose, MVT::v16i8, 1},
        {TTI::SK_Transpose, MVT::v4i16, 1},
        {TTI::SK_Transpose, MVT::v8i16, 1},
        {TTI::SK_Transpose, MVT::v2i32, 1},
        {TTI::SK_Transpose, MVT::v4i32, 1},
        {TTI::SK_Transpose, MVT::v2i64, 1},
        {TTI::SK_Transpose, MVT::v4f16, 1},
        {TTI::SK_Transpose, MVT::v8f16, 1},
        {TTI::SK_Transpose, MVT::v4bf16, 1},
        {TTI::SK_Transpose, MVT::v8bf16, 1},
        {TTI::SK_Transpose, MVT::v2f32, 1},
        {TTI::SK_Transpose, MVT::v4f32, 1},
        {TTI::SK_Transpose, MVT::v2f64, 1},
        // Select shuffle kinds.
        // TODO: handle vXi8/vXi16.
        {TTI::SK_Select, MVT::v2i32, 1}, // mov.
        {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar).
        {TTI::SK_Select, MVT::v2i64, 1}, // mov.
        {TTI::SK_Select, MVT::v2f32, 1}, // mov.
        {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar).
        {TTI::SK_Select, MVT::v2f64, 1}, // mov.
        // PermuteSingleSrc shuffle kinds.
        {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov.
        {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case.
        {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov.
        {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov.
        {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case.
        {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov.
        {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case.
        {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case.
        {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same
        {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8}, // constpool + load + tbl
        {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8}, // constpool + load + tbl
        {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl
        {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8}, // constpool + load + tbl
        {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8}, // constpool + load + tbl
        // Reverse can be lowered with `rev`.
        {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64
        {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT
        {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64
        {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT
        {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v8bf16, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT
        {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64
        {TTI::SK_Reverse, MVT::v4bf16, 1}, // REV64
        {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64
        {TTI::SK_Reverse, MVT::v8i8, 1}, // REV64
        // Splice can all be lowered as `ext`.
        {TTI::SK_Splice, MVT::v2i32, 1},
        {TTI::SK_Splice, MVT::v4i32, 1},
        {TTI::SK_Splice, MVT::v2i64, 1},
        {TTI::SK_Splice, MVT::v2f32, 1},
        {TTI::SK_Splice, MVT::v4f32, 1},
        {TTI::SK_Splice, MVT::v2f64, 1},
        {TTI::SK_Splice, MVT::v8f16, 1},
        {TTI::SK_Splice, MVT::v8bf16, 1},
        {TTI::SK_Splice, MVT::v8i16, 1},
        {TTI::SK_Splice, MVT::v16i8, 1},
        {TTI::SK_Splice, MVT::v4f16, 1},
        {TTI::SK_Splice, MVT::v4bf16, 1},
        {TTI::SK_Splice, MVT::v4i16, 1},
        {TTI::SK_Splice, MVT::v8i8, 1},
        // Broadcast shuffle kinds for scalable vectors
        {TTI::SK_Broadcast, MVT::nxv16i8, 1},
        {TTI::SK_Broadcast, MVT::nxv8i16, 1},
        {TTI::SK_Broadcast, MVT::nxv4i32, 1},
        {TTI::SK_Broadcast, MVT::nxv2i64, 1},
        {TTI::SK_Broadcast, MVT::nxv2f16, 1},
        {TTI::SK_Broadcast, MVT::nxv4f16, 1},
        {TTI::SK_Broadcast, MVT::nxv8f16, 1},
        {TTI::SK_Broadcast, MVT::nxv2bf16, 1},
        {TTI::SK_Broadcast, MVT::nxv4bf16, 1},
        {TTI::SK_Broadcast, MVT::nxv8bf16, 1},
        {TTI::SK_Broadcast, MVT::nxv2f32, 1},
        {TTI::SK_Broadcast, MVT::nxv4f32, 1},
        {TTI::SK_Broadcast, MVT::nxv2f64, 1},
        {TTI::SK_Broadcast, MVT::nxv16i1, 1},
        {TTI::SK_Broadcast, MVT::nxv8i1, 1},
        {TTI::SK_Broadcast, MVT::nxv4i1, 1},
        {TTI::SK_Broadcast, MVT::nxv2i1, 1},
        // Handle the cases for vector.reverse with scalable vectors
        {TTI::SK_Reverse, MVT::nxv16i8, 1},
        {TTI::SK_Reverse, MVT::nxv8i16, 1},
        {TTI::SK_Reverse, MVT::nxv4i32, 1},
        {TTI::SK_Reverse, MVT::nxv2i64, 1},
        {TTI::SK_Reverse, MVT::nxv2f16, 1},
        {TTI::SK_Reverse, MVT::nxv4f16, 1},
        {TTI::SK_Reverse, MVT::nxv8f16, 1},
        {TTI::SK_Reverse, MVT::nxv2bf16, 1},
        {TTI::SK_Reverse, MVT::nxv4bf16, 1},
        {TTI::SK_Reverse, MVT::nxv8bf16, 1},
        {TTI::SK_Reverse, MVT::nxv2f32, 1},
        {TTI::SK_Reverse, MVT::nxv4f32, 1},
        {TTI::SK_Reverse, MVT::nxv2f64, 1},
        {TTI::SK_Reverse, MVT::nxv16i1, 1},
        {TTI::SK_Reverse, MVT::nxv8i1, 1},
        {TTI::SK_Reverse, MVT::nxv4i1, 1},
        {TTI::SK_Reverse, MVT::nxv2i1, 1},
    };
    if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second))
      return LT.first * Entry->Cost;
  }

  if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(SrcTy))
    return getSpliceCost(SrcTy, Index, CostKind);

  // Inserting a subvector can often be done with either a D, S or H register
  // move, so long as the inserted vector is "aligned".
  if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() &&
      LT.second.getSizeInBits() <= 128 && SubTp) {
    std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp);
    if (SubLT.second.isVector()) {
      int NumElts = LT.second.getVectorNumElements();
      int NumSubElts = SubLT.second.getVectorNumElements();
      if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0)
        return SubLT.first;
    }
  }

  // Restore optimal kind.
  if (IsExtractSubvector)
    Kind = TTI::SK_ExtractSubvector;
  return BaseT::getShuffleCost(Kind, DstTy, SrcTy, Mask, CostKind, Index, SubTp,
                               Args, CxtI);
}
5835
containsDecreasingPointers(Loop * TheLoop,PredicatedScalarEvolution * PSE)5836 static bool containsDecreasingPointers(Loop *TheLoop,
5837 PredicatedScalarEvolution *PSE) {
5838 const auto &Strides = DenseMap<Value *, const SCEV *>();
5839 for (BasicBlock *BB : TheLoop->blocks()) {
5840 // Scan the instructions in the block and look for addresses that are
5841 // consecutive and decreasing.
5842 for (Instruction &I : *BB) {
5843 if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) {
5844 Value *Ptr = getLoadStorePointerOperand(&I);
5845 Type *AccessTy = getLoadStoreType(&I);
5846 if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true,
5847 /*ShouldCheckWrap=*/false)
5848 .value_or(0) < 0)
5849 return true;
5850 }
5851 }
5852 }
5853 return false;
5854 }
5855
preferFixedOverScalableIfEqualCost() const5856 bool AArch64TTIImpl::preferFixedOverScalableIfEqualCost() const {
5857 if (SVEPreferFixedOverScalableIfEqualCost.getNumOccurrences())
5858 return SVEPreferFixedOverScalableIfEqualCost;
5859 return ST->useFixedOverScalableIfEqualCost();
5860 }
5861
getEpilogueVectorizationMinVF() const5862 unsigned AArch64TTIImpl::getEpilogueVectorizationMinVF() const {
5863 return ST->getEpilogueVectorizationMinVF();
5864 }
5865
preferPredicateOverEpilogue(TailFoldingInfo * TFI) const5866 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) const {
5867 if (!ST->hasSVE())
5868 return false;
5869
5870 // We don't currently support vectorisation with interleaving for SVE - with
5871 // such loops we're better off not using tail-folding. This gives us a chance
5872 // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc.
5873 if (TFI->IAI->hasGroups())
5874 return false;
5875
5876 TailFoldingOpts Required = TailFoldingOpts::Disabled;
5877 if (TFI->LVL->getReductionVars().size())
5878 Required |= TailFoldingOpts::Reductions;
5879 if (TFI->LVL->getFixedOrderRecurrences().size())
5880 Required |= TailFoldingOpts::Recurrences;
5881
5882 // We call this to discover whether any load/store pointers in the loop have
5883 // negative strides. This will require extra work to reverse the loop
5884 // predicate, which may be expensive.
5885 if (containsDecreasingPointers(TFI->LVL->getLoop(),
5886 TFI->LVL->getPredicatedScalarEvolution()))
5887 Required |= TailFoldingOpts::Reverse;
5888 if (Required == TailFoldingOpts::Disabled)
5889 Required |= TailFoldingOpts::Simple;
5890
5891 if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(),
5892 Required))
5893 return false;
5894
5895 // Don't tail-fold for tight loops where we would be better off interleaving
5896 // with an unpredicated loop.
5897 unsigned NumInsns = 0;
5898 for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) {
5899 NumInsns += BB->sizeWithoutDebug();
5900 }
5901
5902 // We expect 4 of these to be a IV PHI, IV add, IV compare and branch.
5903 return NumInsns >= SVETailFoldInsnThreshold;
5904 }
5905
5906 InstructionCost
getScalingFactorCost(Type * Ty,GlobalValue * BaseGV,StackOffset BaseOffset,bool HasBaseReg,int64_t Scale,unsigned AddrSpace) const5907 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV,
5908 StackOffset BaseOffset, bool HasBaseReg,
5909 int64_t Scale, unsigned AddrSpace) const {
5910 // Scaling factors are not free at all.
5911 // Operands | Rt Latency
5912 // -------------------------------------------
5913 // Rt, [Xn, Xm] | 4
5914 // -------------------------------------------
5915 // Rt, [Xn, Xm, lsl #imm] | Rn: 4 Rm: 5
5916 // Rt, [Xn, Wm, <extend> #imm] |
5917 TargetLoweringBase::AddrMode AM;
5918 AM.BaseGV = BaseGV;
5919 AM.BaseOffs = BaseOffset.getFixed();
5920 AM.HasBaseReg = HasBaseReg;
5921 AM.Scale = Scale;
5922 AM.ScalableOffset = BaseOffset.getScalable();
5923 if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace))
5924 // Scale represents reg2 * scale, thus account for 1 if
5925 // it is not equal to 0 or 1.
5926 return AM.Scale != 0 && AM.Scale != 1;
5927 return InstructionCost::getInvalid();
5928 }
5929
shouldTreatInstructionLikeSelect(const Instruction * I) const5930 bool AArch64TTIImpl::shouldTreatInstructionLikeSelect(
5931 const Instruction *I) const {
5932 if (EnableOrLikeSelectOpt) {
5933 // For the binary operators (e.g. or) we need to be more careful than
5934 // selects, here we only transform them if they are already at a natural
5935 // break point in the code - the end of a block with an unconditional
5936 // terminator.
5937 if (I->getOpcode() == Instruction::Or &&
5938 isa<BranchInst>(I->getNextNode()) &&
5939 cast<BranchInst>(I->getNextNode())->isUnconditional())
5940 return true;
5941
5942 if (I->getOpcode() == Instruction::Add ||
5943 I->getOpcode() == Instruction::Sub)
5944 return true;
5945 }
5946 return BaseT::shouldTreatInstructionLikeSelect(I);
5947 }
5948
isLSRCostLess(const TargetTransformInfo::LSRCost & C1,const TargetTransformInfo::LSRCost & C2) const5949 bool AArch64TTIImpl::isLSRCostLess(
5950 const TargetTransformInfo::LSRCost &C1,
5951 const TargetTransformInfo::LSRCost &C2) const {
5952 // AArch64 specific here is adding the number of instructions to the
5953 // comparison (though not as the first consideration, as some targets do)
5954 // along with changing the priority of the base additions.
5955 // TODO: Maybe a more nuanced tradeoff between instruction count
5956 // and number of registers? To be investigated at a later date.
5957 if (EnableLSRCostOpt)
5958 return std::tie(C1.NumRegs, C1.Insns, C1.NumBaseAdds, C1.AddRecCost,
5959 C1.NumIVMuls, C1.ScaleCost, C1.ImmCost, C1.SetupCost) <
5960 std::tie(C2.NumRegs, C2.Insns, C2.NumBaseAdds, C2.AddRecCost,
5961 C2.NumIVMuls, C2.ScaleCost, C2.ImmCost, C2.SetupCost);
5962
5963 return TargetTransformInfoImplBase::isLSRCostLess(C1, C2);
5964 }
5965
isSplatShuffle(Value * V)5966 static bool isSplatShuffle(Value *V) {
5967 if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V))
5968 return all_equal(Shuf->getShuffleMask());
5969 return false;
5970 }
5971
/// Check if both Op1 and Op2 are shufflevector extracts of either the lower
/// or upper half of the vector elements. With \p AllowSplat, a splat shuffle
/// on either side is accepted and exempted from the half-extract checks.
static bool areExtractShuffleVectors(Value *Op1, Value *Op2,
                                     bool AllowSplat = false) {
  // Scalable types can't be extract shuffle vectors.
  if (Op1->getType()->isScalableTy() || Op2->getType()->isScalableTy())
    return false;

  // True if FullV's type has exactly twice the bit-width of HalfV's type.
  auto areTypesHalfed = [](Value *FullV, Value *HalfV) {
    auto *FullTy = FullV->getType();
    auto *HalfTy = HalfV->getType();
    return FullTy->getPrimitiveSizeInBits().getFixedValue() ==
           2 * HalfTy->getPrimitiveSizeInBits().getFixedValue();
  };

  // True if FullV has exactly twice the element count of HalfV.
  auto extractHalf = [](Value *FullV, Value *HalfV) {
    auto *FullVT = cast<FixedVectorType>(FullV->getType());
    auto *HalfVT = cast<FixedVectorType>(HalfV->getType());
    return FullVT->getNumElements() == 2 * HalfVT->getNumElements();
  };

  // Both operands must be single-input shuffles (second input undef).
  ArrayRef<int> M1, M2;
  Value *S1Op1 = nullptr, *S2Op1 = nullptr;
  if (!match(Op1, m_Shuffle(m_Value(S1Op1), m_Undef(), m_Mask(M1))) ||
      !match(Op2, m_Shuffle(m_Value(S2Op1), m_Undef(), m_Mask(M2))))
    return false;

  // If we allow splats, set S1Op1/S2Op1 to nullptr for the relevant arg so that
  // it is not checked as an extract below.
  if (AllowSplat && isSplatShuffle(Op1))
    S1Op1 = nullptr;
  if (AllowSplat && isSplatShuffle(Op2))
    S2Op1 = nullptr;

  // Check that the operands are half as wide as the result and we extract
  // half of the elements of the input vectors.
  if ((S1Op1 && (!areTypesHalfed(S1Op1, Op1) || !extractHalf(S1Op1, Op1))) ||
      (S2Op1 && (!areTypesHalfed(S2Op1, Op2) || !extractHalf(S2Op1, Op2))))
    return false;

  // Check the mask extracts either the lower or upper half of vector
  // elements.
  int M1Start = 0;
  int M2Start = 0;
  int NumElements = cast<FixedVectorType>(Op1->getType())->getNumElements() * 2;
  if ((S1Op1 &&
       !ShuffleVectorInst::isExtractSubvectorMask(M1, NumElements, M1Start)) ||
      (S2Op1 &&
       !ShuffleVectorInst::isExtractSubvectorMask(M2, NumElements, M2Start)))
    return false;

  // Each extract must start at element 0 (low half) or NumElements/2 (high
  // half).
  if ((M1Start != 0 && M1Start != (NumElements / 2)) ||
      (M2Start != 0 && M2Start != (NumElements / 2)))
    return false;
  // When both sides are checked extracts, they must take the same half.
  if (S1Op1 && S2Op1 && M1Start != M2Start)
    return false;

  return true;
}
6031
6032 /// Check if Ext1 and Ext2 are extends of the same type, doubling the bitwidth
6033 /// of the vector elements.
areExtractExts(Value * Ext1,Value * Ext2)6034 static bool areExtractExts(Value *Ext1, Value *Ext2) {
6035 auto areExtDoubled = [](Instruction *Ext) {
6036 return Ext->getType()->getScalarSizeInBits() ==
6037 2 * Ext->getOperand(0)->getType()->getScalarSizeInBits();
6038 };
6039
6040 if (!match(Ext1, m_ZExtOrSExt(m_Value())) ||
6041 !match(Ext2, m_ZExtOrSExt(m_Value())) ||
6042 !areExtDoubled(cast<Instruction>(Ext1)) ||
6043 !areExtDoubled(cast<Instruction>(Ext2)))
6044 return false;
6045
6046 return true;
6047 }
6048
6049 /// Check if Op could be used with vmull_high_p64 intrinsic.
isOperandOfVmullHighP64(Value * Op)6050 static bool isOperandOfVmullHighP64(Value *Op) {
6051 Value *VectorOperand = nullptr;
6052 ConstantInt *ElementIndex = nullptr;
6053 return match(Op, m_ExtractElt(m_Value(VectorOperand),
6054 m_ConstantInt(ElementIndex))) &&
6055 ElementIndex->getValue() == 1 &&
6056 isa<FixedVectorType>(VectorOperand->getType()) &&
6057 cast<FixedVectorType>(VectorOperand->getType())->getNumElements() == 2;
6058 }
6059
6060 /// Check if Op1 and Op2 could be used with vmull_high_p64 intrinsic.
areOperandsOfVmullHighP64(Value * Op1,Value * Op2)6061 static bool areOperandsOfVmullHighP64(Value *Op1, Value *Op2) {
6062 return isOperandOfVmullHighP64(Op1) && isOperandOfVmullHighP64(Op2);
6063 }
6064
shouldSinkVectorOfPtrs(Value * Ptrs,SmallVectorImpl<Use * > & Ops)6065 static bool shouldSinkVectorOfPtrs(Value *Ptrs, SmallVectorImpl<Use *> &Ops) {
6066 // Restrict ourselves to the form CodeGenPrepare typically constructs.
6067 auto *GEP = dyn_cast<GetElementPtrInst>(Ptrs);
6068 if (!GEP || GEP->getNumOperands() != 2)
6069 return false;
6070
6071 Value *Base = GEP->getOperand(0);
6072 Value *Offsets = GEP->getOperand(1);
6073
6074 // We only care about scalar_base+vector_offsets.
6075 if (Base->getType()->isVectorTy() || !Offsets->getType()->isVectorTy())
6076 return false;
6077
6078 // Sink extends that would allow us to use 32-bit offset vectors.
6079 if (isa<SExtInst>(Offsets) || isa<ZExtInst>(Offsets)) {
6080 auto *OffsetsInst = cast<Instruction>(Offsets);
6081 if (OffsetsInst->getType()->getScalarSizeInBits() > 32 &&
6082 OffsetsInst->getOperand(0)->getType()->getScalarSizeInBits() <= 32)
6083 Ops.push_back(&GEP->getOperandUse(1));
6084 }
6085
6086 // Sink the GEP.
6087 return true;
6088 }
6089
6090 /// We want to sink following cases:
6091 /// (add|sub|gep) A, ((mul|shl) vscale, imm); (add|sub|gep) A, vscale;
6092 /// (add|sub|gep) A, ((mul|shl) zext(vscale), imm);
shouldSinkVScale(Value * Op,SmallVectorImpl<Use * > & Ops)6093 static bool shouldSinkVScale(Value *Op, SmallVectorImpl<Use *> &Ops) {
6094 if (match(Op, m_VScale()))
6095 return true;
6096 if (match(Op, m_Shl(m_VScale(), m_ConstantInt())) ||
6097 match(Op, m_Mul(m_VScale(), m_ConstantInt()))) {
6098 Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
6099 return true;
6100 }
6101 if (match(Op, m_Shl(m_ZExt(m_VScale()), m_ConstantInt())) ||
6102 match(Op, m_Mul(m_ZExt(m_VScale()), m_ConstantInt()))) {
6103 Value *ZExtOp = cast<Instruction>(Op)->getOperand(0);
6104 Ops.push_back(&cast<Instruction>(ZExtOp)->getOperandUse(0));
6105 Ops.push_back(&cast<Instruction>(Op)->getOperandUse(0));
6106 return true;
6107 }
6108 return false;
6109 }
6110
/// Check if sinking \p I's operands to I's basic block is profitable, because
/// the operands can be folded into a target instruction, e.g.
/// shufflevectors extracts and/or sext/zext can be folded into (u,s)subl(2).
///
/// On success the uses worth sinking are appended to \p Ops. Note that some
/// paths push into \p Ops before deciding to return false; callers are
/// expected to ignore \p Ops unless this returns true.
bool AArch64TTIImpl::isProfitableToSinkOperands(
    Instruction *I, SmallVectorImpl<Use *> &Ops) const {
  // Intrinsic calls are handled first, per intrinsic ID.
  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
    switch (II->getIntrinsicID()) {
    case Intrinsic::aarch64_neon_smull:
    case Intrinsic::aarch64_neon_umull:
      // Sink half-extract shuffles (splats allowed) so isel can select the
      // high-half multiply forms. Otherwise fall through to the splat-only
      // handling shared with fma/fmuladd below.
      if (areExtractShuffleVectors(II->getOperand(0), II->getOperand(1),
                                   /*AllowSplat=*/true)) {
        Ops.push_back(&II->getOperandUse(0));
        Ops.push_back(&II->getOperandUse(1));
        return true;
      }
      [[fallthrough]];

    case Intrinsic::fma:
    case Intrinsic::fmuladd:
      // Half-precision vector fma needs +fullfp16; without it, don't sink.
      if (isa<VectorType>(I->getType()) &&
          cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
          !ST->hasFullFP16())
        return false;
      [[fallthrough]];
    case Intrinsic::aarch64_neon_sqdmull:
    case Intrinsic::aarch64_neon_sqdmulh:
    case Intrinsic::aarch64_neon_sqrdmulh:
      // Sink splats for index lane variants
      if (isSplatShuffle(II->getOperand(0)))
        Ops.push_back(&II->getOperandUse(0));
      if (isSplatShuffle(II->getOperand(1)))
        Ops.push_back(&II->getOperandUse(1));
      return !Ops.empty();
    case Intrinsic::aarch64_neon_fmlal:
    case Intrinsic::aarch64_neon_fmlal2:
    case Intrinsic::aarch64_neon_fmlsl:
    case Intrinsic::aarch64_neon_fmlsl2:
      // Sink splats for index lane variants
      if (isSplatShuffle(II->getOperand(1)))
        Ops.push_back(&II->getOperandUse(1));
      if (isSplatShuffle(II->getOperand(2)))
        Ops.push_back(&II->getOperandUse(2));
      return !Ops.empty();
    case Intrinsic::aarch64_sve_ptest_first:
    case Intrinsic::aarch64_sve_ptest_last:
      // Sink a ptrue feeding the ptest so both stay together for isel.
      if (auto *IIOp = dyn_cast<IntrinsicInst>(II->getOperand(0)))
        if (IIOp->getIntrinsicID() == Intrinsic::aarch64_sve_ptrue)
          Ops.push_back(&II->getOperandUse(0));
      return !Ops.empty();
    case Intrinsic::aarch64_sme_write_horiz:
    case Intrinsic::aarch64_sme_write_vert:
    case Intrinsic::aarch64_sme_writeq_horiz:
    case Intrinsic::aarch64_sme_writeq_vert: {
      // For SME writes, sink the slice index (operand 1) when it is an Add.
      auto *Idx = dyn_cast<Instruction>(II->getOperand(1));
      if (!Idx || Idx->getOpcode() != Instruction::Add)
        return false;
      Ops.push_back(&II->getOperandUse(1));
      return true;
    }
    case Intrinsic::aarch64_sme_read_horiz:
    case Intrinsic::aarch64_sme_read_vert:
    case Intrinsic::aarch64_sme_readq_horiz:
    case Intrinsic::aarch64_sme_readq_vert:
    case Intrinsic::aarch64_sme_ld1b_vert:
    case Intrinsic::aarch64_sme_ld1h_vert:
    case Intrinsic::aarch64_sme_ld1w_vert:
    case Intrinsic::aarch64_sme_ld1d_vert:
    case Intrinsic::aarch64_sme_ld1q_vert:
    case Intrinsic::aarch64_sme_st1b_vert:
    case Intrinsic::aarch64_sme_st1h_vert:
    case Intrinsic::aarch64_sme_st1w_vert:
    case Intrinsic::aarch64_sme_st1d_vert:
    case Intrinsic::aarch64_sme_st1q_vert:
    case Intrinsic::aarch64_sme_ld1b_horiz:
    case Intrinsic::aarch64_sme_ld1h_horiz:
    case Intrinsic::aarch64_sme_ld1w_horiz:
    case Intrinsic::aarch64_sme_ld1d_horiz:
    case Intrinsic::aarch64_sme_ld1q_horiz:
    case Intrinsic::aarch64_sme_st1b_horiz:
    case Intrinsic::aarch64_sme_st1h_horiz:
    case Intrinsic::aarch64_sme_st1w_horiz:
    case Intrinsic::aarch64_sme_st1d_horiz:
    case Intrinsic::aarch64_sme_st1q_horiz: {
      // Same as above, but these intrinsics carry the slice index in
      // operand 3.
      auto *Idx = dyn_cast<Instruction>(II->getOperand(3));
      if (!Idx || Idx->getOpcode() != Instruction::Add)
        return false;
      Ops.push_back(&II->getOperandUse(3));
      return true;
    }
    case Intrinsic::aarch64_neon_pmull:
      // Sink matching half-extract shuffles feeding a polynomial multiply.
      if (!areExtractShuffleVectors(II->getOperand(0), II->getOperand(1)))
        return false;
      Ops.push_back(&II->getOperandUse(0));
      Ops.push_back(&II->getOperandUse(1));
      return true;
    case Intrinsic::aarch64_neon_pmull64:
      // Sink lane-1 extracts so isel can select the high-half p64 multiply.
      if (!areOperandsOfVmullHighP64(II->getArgOperand(0),
                                     II->getArgOperand(1)))
        return false;
      Ops.push_back(&II->getArgOperandUse(0));
      Ops.push_back(&II->getArgOperandUse(1));
      return true;
    case Intrinsic::masked_gather:
      // Sink the pointer-vector GEP (and possibly its offset extend; see
      // shouldSinkVectorOfPtrs).
      if (!shouldSinkVectorOfPtrs(II->getArgOperand(0), Ops))
        return false;
      Ops.push_back(&II->getArgOperandUse(0));
      return true;
    case Intrinsic::masked_scatter:
      // As masked_gather, but the pointer vector is argument 1.
      if (!shouldSinkVectorOfPtrs(II->getArgOperand(1), Ops))
        return false;
      Ops.push_back(&II->getArgOperandUse(1));
      return true;
    default:
      return false;
    }
  }

  // A condition is worth sinking when it is a vector_reduce_or over a
  // scalable vector.
  auto ShouldSinkCondition = [](Value *Cond) -> bool {
    auto *II = dyn_cast<IntrinsicInst>(Cond);
    return II && II->getIntrinsicID() == Intrinsic::vector_reduce_or &&
           isa<ScalableVectorType>(II->getOperand(0)->getType());
  };

  switch (I->getOpcode()) {
  case Instruction::GetElementPtr:
  case Instruction::Add:
  case Instruction::Sub:
    // Sink vscales closer to uses for better isel
    for (unsigned Op = 0; Op < I->getNumOperands(); ++Op) {
      if (shouldSinkVScale(I->getOperand(Op), Ops)) {
        Ops.push_back(&I->getOperandUse(Op));
        return true;
      }
    }
    break;
  case Instruction::Select: {
    // Sink a scalable reduce-or condition into the select.
    if (!ShouldSinkCondition(I->getOperand(0)))
      return false;

    Ops.push_back(&I->getOperandUse(0));
    return true;
  }
  case Instruction::Br: {
    // Likewise for a conditional branch's condition.
    if (cast<BranchInst>(I)->isUnconditional())
      return false;

    if (!ShouldSinkCondition(cast<BranchInst>(I)->getCondition()))
      return false;

    Ops.push_back(&I->getOperandUse(0));
    return true;
  }
  default:
    break;
  }

  // The remaining cases apply to vector-typed instructions only.
  if (!I->getType()->isVectorTy())
    return false;

  switch (I->getOpcode()) {
  case Instruction::Sub:
  case Instruction::Add: {
    // Sink width-doubling extends (and, if present, the half-extract
    // shuffles feeding them) so isel can form widening add/sub.
    if (!areExtractExts(I->getOperand(0), I->getOperand(1)))
      return false;

    // If the exts' operands extract either the lower or upper elements, we
    // can sink them too.
    auto Ext1 = cast<Instruction>(I->getOperand(0));
    auto Ext2 = cast<Instruction>(I->getOperand(1));
    if (areExtractShuffleVectors(Ext1->getOperand(0), Ext2->getOperand(0))) {
      Ops.push_back(&Ext1->getOperandUse(0));
      Ops.push_back(&Ext2->getOperandUse(0));
    }

    Ops.push_back(&I->getOperandUse(0));
    Ops.push_back(&I->getOperandUse(1));

    return true;
  }
  case Instruction::Or: {
    // Pattern: Or(And(MaskValue, A), And(Not(MaskValue), B)) ->
    // bitselect(MaskValue, A, B) where Not(MaskValue) = Xor(MaskValue, -1)
    if (ST->hasNEON()) {
      Instruction *OtherAnd, *IA, *IB;
      Value *MaskValue;
      // MainAnd refers to And instruction that has 'Not' as one of its operands
      if (match(I, m_c_Or(m_OneUse(m_Instruction(OtherAnd)),
                          m_OneUse(m_c_And(m_OneUse(m_Not(m_Value(MaskValue))),
                                           m_Instruction(IA)))))) {
        if (match(OtherAnd,
                  m_c_And(m_Specific(MaskValue), m_Instruction(IB)))) {
          Instruction *MainAnd = I->getOperand(0) == OtherAnd
                                     ? cast<Instruction>(I->getOperand(1))
                                     : cast<Instruction>(I->getOperand(0));

          // Both Ands should be in same basic block as Or
          if (I->getParent() != MainAnd->getParent() ||
              I->getParent() != OtherAnd->getParent())
            return false;

          // Non-mask operands of both Ands should also be in same basic block
          if (I->getParent() != IA->getParent() ||
              I->getParent() != IB->getParent())
            return false;

          // Sink the Not-side operand of MainAnd plus both Or operands.
          Ops.push_back(
              &MainAnd->getOperandUse(MainAnd->getOperand(0) == IA ? 1 : 0));
          Ops.push_back(&I->getOperandUse(0));
          Ops.push_back(&I->getOperandUse(1));

          return true;
        }
      }
    }

    return false;
  }
  case Instruction::Mul: {
    auto ShouldSinkSplatForIndexedVariant = [](Value *V) {
      auto *Ty = cast<VectorType>(V->getType());
      // For SVE the lane-indexing is within 128-bits, so we can't fold splats.
      if (Ty->isScalableTy())
        return false;

      // Indexed variants of Mul exist for i16 and i32 element types only.
      return Ty->getScalarSizeInBits() == 16 || Ty->getScalarSizeInBits() == 32;
    };

    // Count the extend kinds feeding the multiply; two matching extends
    // enable a widening multiply (see profitability check after the loop).
    int NumZExts = 0, NumSExts = 0;
    for (auto &Op : I->operands()) {
      // Make sure we are not already sinking this operand
      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
        continue;

      // Case 1: the operand is itself a s/zext — sink it, and also sink a
      // splat shuffle it extends when the indexed variant exists.
      if (match(&Op, m_ZExtOrSExt(m_Value()))) {
        auto *Ext = cast<Instruction>(Op);
        auto *ExtOp = Ext->getOperand(0);
        if (isSplatShuffle(ExtOp) && ShouldSinkSplatForIndexedVariant(ExtOp))
          Ops.push_back(&Ext->getOperandUse(0));
        Ops.push_back(&Op);

        if (isa<SExtInst>(Ext))
          NumSExts++;
        else
          NumZExts++;

        continue;
      }

      ShuffleVectorInst *Shuffle = dyn_cast<ShuffleVectorInst>(Op);
      if (!Shuffle)
        continue;

      // If the Shuffle is a splat and the operand is a zext/sext, sinking the
      // operand and the s/zext can help create indexed s/umull. This is
      // especially useful to prevent i64 mul being scalarized.
      if (isSplatShuffle(Shuffle) &&
          match(Shuffle->getOperand(0), m_ZExtOrSExt(m_Value()))) {
        Ops.push_back(&Shuffle->getOperandUse(0));
        Ops.push_back(&Op);
        if (match(Shuffle->getOperand(0), m_SExt(m_Value())))
          NumSExts++;
        else
          NumZExts++;
        continue;
      }

      // Case 3: shuffle of an insertelement of a scalar (a vector "dup" of a
      // scalar built via insert+shuffle).
      Value *ShuffleOperand = Shuffle->getOperand(0);
      InsertElementInst *Insert = dyn_cast<InsertElementInst>(ShuffleOperand);
      if (!Insert)
        continue;

      Instruction *OperandInstr = dyn_cast<Instruction>(Insert->getOperand(1));
      if (!OperandInstr)
        continue;

      ConstantInt *ElementConstant =
          dyn_cast<ConstantInt>(Insert->getOperand(2));
      // Check that the insertelement is inserting into element 0
      if (!ElementConstant || !ElementConstant->isZero())
        continue;

      unsigned Opcode = OperandInstr->getOpcode();
      if (Opcode == Instruction::SExt)
        NumSExts++;
      else if (Opcode == Instruction::ZExt)
        NumZExts++;
      else {
        // If we find that the top bits are known 0, then we can sink and allow
        // the backend to generate a umull.
        unsigned Bitwidth = I->getType()->getScalarSizeInBits();
        APInt UpperMask = APInt::getHighBitsSet(Bitwidth, Bitwidth / 2);
        if (!MaskedValueIsZero(OperandInstr, UpperMask, DL))
          continue;
        NumZExts++;
      }

      // And(Load) is excluded to prevent CGP getting stuck in a loop of sinking
      // the And, just to hoist it again back to the load.
      if (!match(OperandInstr, m_And(m_Load(m_Value()), m_Value())))
        Ops.push_back(&Insert->getOperandUse(1));
      Ops.push_back(&Shuffle->getOperandUse(0));
      Ops.push_back(&Op);
    }

    // It is profitable to sink if we found two of the same type of extends.
    if (!Ops.empty() && (NumSExts == 2 || NumZExts == 2))
      return true;

    // Otherwise, see if we should sink splats for indexed variants.
    if (!ShouldSinkSplatForIndexedVariant(I))
      return false;

    // Discard anything collected above; only splat operands are sunk here.
    Ops.clear();
    if (isSplatShuffle(I->getOperand(0)))
      Ops.push_back(&I->getOperandUse(0));
    if (isSplatShuffle(I->getOperand(1)))
      Ops.push_back(&I->getOperandUse(1));

    return !Ops.empty();
  }
  case Instruction::FMul: {
    // For SVE the lane-indexing is within 128-bits, so we can't fold splats.
    if (I->getType()->isScalableTy())
      return false;

    // Half-precision vectors need +fullfp16 for the indexed variant.
    if (cast<VectorType>(I->getType())->getElementType()->isHalfTy() &&
        !ST->hasFullFP16())
      return false;

    // Sink splats for index lane variants
    if (isSplatShuffle(I->getOperand(0)))
      Ops.push_back(&I->getOperandUse(0));
    if (isSplatShuffle(I->getOperand(1)))
      Ops.push_back(&I->getOperandUse(1));
    return !Ops.empty();
  }
  default:
    return false;
  }
  return false;
}
6453