1 //===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "AArch64TargetTransformInfo.h" 10 #include "AArch64ExpandImm.h" 11 #include "AArch64PerfectShuffle.h" 12 #include "MCTargetDesc/AArch64AddressingModes.h" 13 #include "llvm/Analysis/IVDescriptors.h" 14 #include "llvm/Analysis/LoopInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/CodeGen/BasicTTIImpl.h" 17 #include "llvm/CodeGen/CostTable.h" 18 #include "llvm/CodeGen/TargetLowering.h" 19 #include "llvm/IR/IntrinsicInst.h" 20 #include "llvm/IR/Intrinsics.h" 21 #include "llvm/IR/IntrinsicsAArch64.h" 22 #include "llvm/IR/PatternMatch.h" 23 #include "llvm/Support/Debug.h" 24 #include "llvm/Transforms/InstCombine/InstCombiner.h" 25 #include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h" 26 #include <algorithm> 27 #include <optional> 28 using namespace llvm; 29 using namespace llvm::PatternMatch; 30 31 #define DEBUG_TYPE "aarch64tti" 32 33 static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix", 34 cl::init(true), cl::Hidden); 35 36 static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10), 37 cl::Hidden); 38 39 static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead", 40 cl::init(10), cl::Hidden); 41 42 static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold", 43 cl::init(15), cl::Hidden); 44 45 static cl::opt<unsigned> 46 NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10), 47 cl::Hidden); 48 49 static cl::opt<unsigned> CallPenaltyChangeSM( 50 "call-penalty-sm-change", cl::init(5), cl::Hidden, 51 cl::desc( 52 "Penalty of calling a function that requires a change to PSTATE.SM")); 53 54 static cl::opt<unsigned> InlineCallPenaltyChangeSM( 55 "inline-call-penalty-sm-change", cl::init(10), cl::Hidden, 56 cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM")); 57 58 namespace { 59 class TailFoldingOption { 60 // These bitfields will only ever be set to something non-zero in operator=, 61 // when setting the -sve-tail-folding option. This option should always be of 62 // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here 63 // InitialBits is one of (disabled|all|simple). EnableBits represents 64 // additional flags we're enabling, and DisableBits for those flags we're 65 // disabling. The default flag is tracked in the variable NeedsDefault, since 66 // at the time of setting the option we may not know what the default value 67 // for the CPU is. 68 TailFoldingOpts InitialBits = TailFoldingOpts::Disabled; 69 TailFoldingOpts EnableBits = TailFoldingOpts::Disabled; 70 TailFoldingOpts DisableBits = TailFoldingOpts::Disabled; 71 72 // This value needs to be initialised to true in case the user does not 73 // explicitly set the -sve-tail-folding option. 
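  // For example (illustrative): "-sve-tail-folding=all+noreverse" clears
  // NeedsDefault and sets InitialBits=All with DisableBits=Reverse, whereas
  // "-sve-tail-folding=default+reductions" keeps NeedsDefault set and adds
  // Reductions to EnableBits.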
74 bool NeedsDefault = true; 75 76 void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; } 77 78 void setNeedsDefault(bool V) { NeedsDefault = V; } 79 80 void setEnableBit(TailFoldingOpts Bit) { 81 EnableBits |= Bit; 82 DisableBits &= ~Bit; 83 } 84 85 void setDisableBit(TailFoldingOpts Bit) { 86 EnableBits &= ~Bit; 87 DisableBits |= Bit; 88 } 89 90 TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const { 91 TailFoldingOpts Bits = TailFoldingOpts::Disabled; 92 93 assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) && 94 "Initial bits should only include one of " 95 "(disabled|all|simple|default)"); 96 Bits = NeedsDefault ? DefaultBits : InitialBits; 97 Bits |= EnableBits; 98 Bits &= ~DisableBits; 99 100 return Bits; 101 } 102 103 void reportError(std::string Opt) { 104 errs() << "invalid argument '" << Opt 105 << "' to -sve-tail-folding=; the option should be of the form\n" 106 " (disabled|all|default|simple)[+(reductions|recurrences" 107 "|reverse|noreductions|norecurrences|noreverse)]\n"; 108 report_fatal_error("Unrecognised tail-folding option"); 109 } 110 111 public: 112 113 void operator=(const std::string &Val) { 114 // If the user explicitly sets -sve-tail-folding= then treat as an error. 115 if (Val.empty()) { 116 reportError(""); 117 return; 118 } 119 120 // Since the user is explicitly setting the option we don't automatically 121 // need the default unless they require it. 122 setNeedsDefault(false); 123 124 SmallVector<StringRef, 4> TailFoldTypes; 125 StringRef(Val).split(TailFoldTypes, '+', -1, false); 126 127 unsigned StartIdx = 1; 128 if (TailFoldTypes[0] == "disabled") 129 setInitialBits(TailFoldingOpts::Disabled); 130 else if (TailFoldTypes[0] == "all") 131 setInitialBits(TailFoldingOpts::All); 132 else if (TailFoldTypes[0] == "default") 133 setNeedsDefault(true); 134 else if (TailFoldTypes[0] == "simple") 135 setInitialBits(TailFoldingOpts::Simple); 136 else { 137 StartIdx = 0; 138 setInitialBits(TailFoldingOpts::Disabled); 139 } 140 141 for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) { 142 if (TailFoldTypes[I] == "reductions") 143 setEnableBit(TailFoldingOpts::Reductions); 144 else if (TailFoldTypes[I] == "recurrences") 145 setEnableBit(TailFoldingOpts::Recurrences); 146 else if (TailFoldTypes[I] == "reverse") 147 setEnableBit(TailFoldingOpts::Reverse); 148 else if (TailFoldTypes[I] == "noreductions") 149 setDisableBit(TailFoldingOpts::Reductions); 150 else if (TailFoldTypes[I] == "norecurrences") 151 setDisableBit(TailFoldingOpts::Recurrences); 152 else if (TailFoldTypes[I] == "noreverse") 153 setDisableBit(TailFoldingOpts::Reverse); 154 else 155 reportError(Val); 156 } 157 } 158 159 bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const { 160 return (getBits(DefaultBits) & Required) == Required; 161 } 162 }; 163 } // namespace 164 165 TailFoldingOption TailFoldingOptionLoc; 166 167 cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding( 168 "sve-tail-folding", 169 cl::desc( 170 "Control the use of vectorisation using tail-folding for SVE where the" 171 " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:" 172 "\ndisabled (Initial) No loop types will vectorize using " 173 "tail-folding" 174 "\ndefault (Initial) Uses the default tail-folding settings for " 175 "the target CPU" 176 "\nall (Initial) All legal loop types will vectorize using " 177 "tail-folding" 178 "\nsimple (Initial) Use tail-folding for simple loops (not " 179 "reductions or recurrences)" 180 "\nreductions Use 
tail-folding for loops containing reductions"
        "\nnoreductions Inverse of above"
        "\nrecurrences Use tail-folding for loops containing fixed order "
        "recurrences"
        "\nnorecurrences Inverse of above"
        "\nreverse Use tail-folding for loops requiring reversed "
        "predicates"
        "\nnoreverse Inverse of above"),
    cl::location(TailFoldingOptionLoc));

// Experimental option that will only be fully functional when the
// code-generator is changed to use SVE instead of NEON for all fixed-width
// operations.
static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
    "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

// Experimental option that will only be fully functional when the cost-model
// and code-generator have been changed to avoid using scalable vector
// instructions that are not legal in streaming SVE mode.
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
    "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

static bool isSMEABIRoutineCall(const CallInst &CI) {
  const auto *F = CI.getCalledFunction();
  return F && StringSwitch<bool>(F->getName())
                  .Case("__arm_sme_state", true)
                  .Case("__arm_tpidr2_save", true)
                  .Case("__arm_tpidr2_restore", true)
                  .Case("__arm_za_disable", true)
                  .Default(false);
}

/// Returns true if the function has explicit operations that can only be
/// lowered using incompatible instructions for the selected mode. This also
/// returns true if the function F may use or modify ZA state.
static bool hasPossibleIncompatibleOps(const Function *F) {
  for (const BasicBlock &BB : *F) {
    for (const Instruction &I : BB) {
      // Be conservative for now and assume that any call to inline asm or to
      // intrinsics could result in non-streaming ops (e.g. calls to
      // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
      // all native LLVM instructions can be lowered to compatible instructions.
      if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
          (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
           isSMEABIRoutineCall(cast<CallInst>(I))))
        return true;
    }
  }
  return false;
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                         const Function *Callee) const {
  SMEAttrs CallerAttrs(*Caller);
  SMEAttrs CalleeAttrs(*Callee);
  if (CalleeAttrs.hasNewZABody())
    return false;

  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
      CallerAttrs.requiresSMChange(CalleeAttrs,
                                   /*BodyOverridesInterface=*/true)) {
    if (hasPossibleIncompatibleOps(Callee))
      return false;
  }

  const TargetMachine &TM = getTLI()->getTargetMachine();

  const FeatureBitset &CallerBits =
      TM.getSubtargetImpl(*Caller)->getFeatureBits();
  const FeatureBitset &CalleeBits =
      TM.getSubtargetImpl(*Callee)->getFeatureBits();

  // Inline a callee if its target-features are a subset of the caller's
  // target-features.
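  // For example (hypothetical feature sets): Caller = {neon, sve, fullfp16}
  // and Callee = {neon, sve} satisfies the subset check below, whereas a
  // callee that additionally requires sve2 would not.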
254 return (CallerBits & CalleeBits) == CalleeBits; 255 } 256 257 bool AArch64TTIImpl::areTypesABICompatible( 258 const Function *Caller, const Function *Callee, 259 const ArrayRef<Type *> &Types) const { 260 if (!BaseT::areTypesABICompatible(Caller, Callee, Types)) 261 return false; 262 263 // We need to ensure that argument promotion does not attempt to promote 264 // pointers to fixed-length vector types larger than 128 bits like 265 // <8 x float> (and pointers to aggregate types which have such fixed-length 266 // vector type members) into the values of the pointees. Such vector types 267 // are used for SVE VLS but there is no ABI for SVE VLS arguments and the 268 // backend cannot lower such value arguments. The 128-bit fixed-length SVE 269 // types can be safely treated as 128-bit NEON types and they cannot be 270 // distinguished in IR. 271 if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) { 272 auto FVTy = dyn_cast<FixedVectorType>(Ty); 273 return FVTy && 274 FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128; 275 })) 276 return false; 277 278 return true; 279 } 280 281 unsigned 282 AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call, 283 unsigned DefaultCallPenalty) const { 284 // This function calculates a penalty for executing Call in F. 285 // 286 // There are two ways this function can be called: 287 // (1) F: 288 // call from F -> G (the call here is Call) 289 // 290 // For (1), Call.getCaller() == F, so it will always return a high cost if 291 // a streaming-mode change is required (thus promoting the need to inline the 292 // function) 293 // 294 // (2) F: 295 // call from F -> G (the call here is not Call) 296 // G: 297 // call from G -> H (the call here is Call) 298 // 299 // For (2), if after inlining the body of G into F the call to H requires a 300 // streaming-mode change, and the call to G from F would also require a 301 // streaming-mode change, then there is benefit to do the streaming-mode 302 // change only once and avoid inlining of G into F. 303 SMEAttrs FAttrs(*F); 304 SMEAttrs CalleeAttrs(Call); 305 if (FAttrs.requiresSMChange(CalleeAttrs)) { 306 if (F == Call.getCaller()) // (1) 307 return CallPenaltyChangeSM * DefaultCallPenalty; 308 if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2) 309 return InlineCallPenaltyChangeSM * DefaultCallPenalty; 310 } 311 312 return DefaultCallPenalty; 313 } 314 315 bool AArch64TTIImpl::shouldMaximizeVectorBandwidth( 316 TargetTransformInfo::RegisterKind K) const { 317 assert(K != TargetTransformInfo::RGK_Scalar); 318 return (K == TargetTransformInfo::RGK_FixedWidthVector && 319 ST->isNeonAvailable()); 320 } 321 322 /// Calculate the cost of materializing a 64-bit value. This helper 323 /// method might only calculate a fraction of a larger immediate. Therefore it 324 /// is valid to return a cost of ZERO. 325 InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) { 326 // Check if the immediate can be encoded within an instruction. 327 if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64)) 328 return 0; 329 330 if (Val < 0) 331 Val = ~Val; 332 333 // Calculate how many moves we will need to materialize this constant. 334 SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn; 335 AArch64_IMM::expandMOVImm(Val, 64, Insn); 336 return Insn.size(); 337 } 338 339 /// Calculate the cost of materializing the given constant. 
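/// For example, a 128-bit immediate is costed as the sum of materializing each
/// of its two 64-bit chunks, with a minimum overall cost of one instruction.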
InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
                                              TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  if (BitSize == 0)
    return ~0U;

  // Sign-extend all constants to a multiple of 64-bit.
  APInt ImmVal = Imm;
  if (BitSize & 0x3f)
    ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);

  // Split the constant into 64-bit chunks and calculate the cost for each
  // chunk.
  InstructionCost Cost = 0;
  for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
    APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
    int64_t Val = Tmp.getSExtValue();
    Cost += getIntImmCost(Val);
  }
  // We need at least one instruction to materialize the constant.
  return std::max<InstructionCost>(1, Cost);
}

InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    ImmIdx = 0;
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}

InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
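  // Note: the range check below assumes the AArch64 intrinsics occupy a
  // contiguous block of Intrinsic::ID values between aarch64_addg and
  // aarch64_udiv.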
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}

TargetTransformInfo::PopcntSupportKind
AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
  assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
  if (TyWidth == 32 || TyWidth == 64)
    return TTI::PSK_FastHardware;
  // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
  return TTI::PSK_Software;
}

InstructionCost
AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                      TTI::TargetCostKind CostKind) {
  auto *RetTy = ICA.getReturnType();
  switch (ICA.getID()) {
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax: {
    static const auto ValidMinMaxTys = {MVT::v8i8,    MVT::v16i8,  MVT::v4i16,
                                        MVT::v8i16,   MVT::v2i32,  MVT::v4i32,
                                        MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
                                        MVT::nxv2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    // v2i64 types get converted to cmp+bif hence the cost of 2
    if (LT.second == MVT::v2i64)
      return LT.first * 2;
    if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat: {
    static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
    // need to extend the type, as it uses shr(qadd(shl, shl)).
    unsigned Instrs =
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
    if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first * Instrs;
    break;
  }
  case Intrinsic::abs: {
    static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::bswap: {
    static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
                                     MVT::v4i32, MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
      return LT.first;
    break;
  }
  case Intrinsic::experimental_stepvector: {
    InstructionCost Cost = 1; // Cost of the `index' instruction
    auto LT = getTypeLegalizationCost(RetTy);
    // Legalisation of illegal vectors involves an `index' instruction plus
    // (LT.first - 1) vector adds.
    if (LT.first > 1) {
      Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
      InstructionCost AddCost =
          getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
      Cost += AddCost * (LT.first - 1);
    }
    return Cost;
  }
  case Intrinsic::bitreverse: {
    static const CostTblEntry BitreverseTbl[] = {
        {Intrinsic::bitreverse, MVT::i32, 1},
        {Intrinsic::bitreverse, MVT::i64, 1},
        {Intrinsic::bitreverse, MVT::v8i8, 1},
        {Intrinsic::bitreverse, MVT::v16i8, 1},
        {Intrinsic::bitreverse, MVT::v4i16, 2},
        {Intrinsic::bitreverse, MVT::v8i16, 2},
        {Intrinsic::bitreverse, MVT::v2i32, 2},
        {Intrinsic::bitreverse, MVT::v4i32, 2},
        {Intrinsic::bitreverse, MVT::v1i64, 2},
        {Intrinsic::bitreverse, MVT::v2i64, 2},
    };
    const auto LegalisationCost = getTypeLegalizationCost(RetTy);
    const auto *Entry =
        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
    if (Entry) {
      // The cost model uses the legal type (i32) that i8 and i16 are promoted
      // to; add 1 so that we match the actual lowering cost.
      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
          TLI->getValueType(DL, RetTy, true) == MVT::i16)
        return LegalisationCost.first * Entry->Cost + 1;

      return LegalisationCost.first * Entry->Cost;
    }
    break;
  }
  case Intrinsic::ctpop: {
    if (!ST->hasNEON()) {
      // 32-bit or 64-bit ctpop without NEON is 12 instructions.
      return getTypeLegalizationCost(RetTy).first * 12;
    }
    static const CostTblEntry CtpopCostTbl[] = {
        {ISD::CTPOP, MVT::v2i64, 4},
        {ISD::CTPOP, MVT::v4i32, 3},
        {ISD::CTPOP, MVT::v8i16, 2},
        {ISD::CTPOP, MVT::v16i8, 1},
        {ISD::CTPOP, MVT::i64, 4},
        {ISD::CTPOP, MVT::v2i32, 3},
        {ISD::CTPOP, MVT::v4i16, 2},
        {ISD::CTPOP, MVT::v8i8, 1},
        {ISD::CTPOP, MVT::i32, 5},
    };
    auto LT = getTypeLegalizationCost(RetTy);
    MVT MTy = LT.second;
    if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
      // Extra cost of +1 when illegal vector types are legalized by promoting
      // the integer type.
      int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
                                            RetTy->getScalarSizeInBits()
                          ?
1 613 : 0; 614 return LT.first * Entry->Cost + ExtraCost; 615 } 616 break; 617 } 618 case Intrinsic::sadd_with_overflow: 619 case Intrinsic::uadd_with_overflow: 620 case Intrinsic::ssub_with_overflow: 621 case Intrinsic::usub_with_overflow: 622 case Intrinsic::smul_with_overflow: 623 case Intrinsic::umul_with_overflow: { 624 static const CostTblEntry WithOverflowCostTbl[] = { 625 {Intrinsic::sadd_with_overflow, MVT::i8, 3}, 626 {Intrinsic::uadd_with_overflow, MVT::i8, 3}, 627 {Intrinsic::sadd_with_overflow, MVT::i16, 3}, 628 {Intrinsic::uadd_with_overflow, MVT::i16, 3}, 629 {Intrinsic::sadd_with_overflow, MVT::i32, 1}, 630 {Intrinsic::uadd_with_overflow, MVT::i32, 1}, 631 {Intrinsic::sadd_with_overflow, MVT::i64, 1}, 632 {Intrinsic::uadd_with_overflow, MVT::i64, 1}, 633 {Intrinsic::ssub_with_overflow, MVT::i8, 3}, 634 {Intrinsic::usub_with_overflow, MVT::i8, 3}, 635 {Intrinsic::ssub_with_overflow, MVT::i16, 3}, 636 {Intrinsic::usub_with_overflow, MVT::i16, 3}, 637 {Intrinsic::ssub_with_overflow, MVT::i32, 1}, 638 {Intrinsic::usub_with_overflow, MVT::i32, 1}, 639 {Intrinsic::ssub_with_overflow, MVT::i64, 1}, 640 {Intrinsic::usub_with_overflow, MVT::i64, 1}, 641 {Intrinsic::smul_with_overflow, MVT::i8, 5}, 642 {Intrinsic::umul_with_overflow, MVT::i8, 4}, 643 {Intrinsic::smul_with_overflow, MVT::i16, 5}, 644 {Intrinsic::umul_with_overflow, MVT::i16, 4}, 645 {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst 646 {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw 647 {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp 648 {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr 649 }; 650 EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true); 651 if (MTy.isSimple()) 652 if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(), 653 MTy.getSimpleVT())) 654 return Entry->Cost; 655 break; 656 } 657 case Intrinsic::fptosi_sat: 658 case Intrinsic::fptoui_sat: { 659 if (ICA.getArgTypes().empty()) 660 break; 661 bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat; 662 auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]); 663 EVT MTy = TLI->getValueType(DL, RetTy); 664 // Check for the legal types, which are where the size of the input and the 665 // output are the same, or we are using cvt f64->i32 or f32->i64. 666 if ((LT.second == MVT::f32 || LT.second == MVT::f64 || 667 LT.second == MVT::v2f32 || LT.second == MVT::v4f32 || 668 LT.second == MVT::v2f64) && 669 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() || 670 (LT.second == MVT::f64 && MTy == MVT::i32) || 671 (LT.second == MVT::f32 && MTy == MVT::i64))) 672 return LT.first; 673 // Similarly for fp16 sizes 674 if (ST->hasFullFP16() && 675 ((LT.second == MVT::f16 && MTy == MVT::i32) || 676 ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) && 677 (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits())))) 678 return LT.first; 679 680 // Otherwise we use a legal convert followed by a min+max 681 if ((LT.second.getScalarType() == MVT::f32 || 682 LT.second.getScalarType() == MVT::f64 || 683 (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) && 684 LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) { 685 Type *LegalTy = 686 Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits()); 687 if (LT.second.isVector()) 688 LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount()); 689 InstructionCost Cost = 1; 690 IntrinsicCostAttributes Attrs1(IsSigned ? 
                                         Intrinsic::smin : Intrinsic::umin,
                                     LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs1, CostKind);
      IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax
                                              : Intrinsic::umax,
                                     LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs2, CostKind);
      return LT.first * Cost;
    }
    break;
  }
  case Intrinsic::fshl:
  case Intrinsic::fshr: {
    if (ICA.getArgs().empty())
      break;

    // TODO: Add handling for fshl where third argument is not a constant.
    const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
    if (!OpInfoZ.isConstant())
      break;

    const auto LegalisationCost = getTypeLegalizationCost(RetTy);
    if (OpInfoZ.isUniform()) {
      // FIXME: The costs could be lower if the codegen is better.
      static const CostTblEntry FshlTbl[] = {
          {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
          {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
          {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
          {Intrinsic::fshl, MVT::v8i8, 4},  {Intrinsic::fshl, MVT::v4i16, 4}};
      // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
      // to avoid having to duplicate the costs.
      const auto *Entry =
          CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
      if (Entry)
        return LegalisationCost.first * Entry->Cost;
    }

    auto TyL = getTypeLegalizationCost(RetTy);
    if (!RetTy->isIntegerTy())
      break;

    // Estimate cost manually, as types like i8 and i16 will get promoted to
    // i32 and CostTableLookup will ignore the extra conversion cost.
    bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
                       RetTy->getScalarSizeInBits() < 64) ||
                      (RetTy->getScalarSizeInBits() % 64 != 0);
    unsigned ExtraCost = HigherCost ? 1 : 0;
    if (RetTy->getScalarSizeInBits() == 32 ||
        RetTy->getScalarSizeInBits() == 64)
      ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
                     // extr instruction.
    else if (HigherCost)
      ExtraCost = 1;
    else
      break;
    return TyL.first + ExtraCost;
  }
  default:
    break;
  }
  return BaseT::getIntrinsicInstrCost(ICA, CostKind);
}

/// Remove redundant svbool reinterpret casts in the presence of control flow,
/// i.e. when the operand of the cast is a PHI of reinterpreted values.
static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  SmallVector<Instruction *, 32> Worklist;
  auto RequiredType = II.getType();

  auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
  assert(PN && "Expected Phi Node!");

  // Don't create a new Phi unless we can remove the old one.
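  // (For example, a PHI merging to_svbool(%a) and to_svbool(%b) that is only
  // consumed by this from_svbool can simply become a PHI of %a and %b.)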
763 if (!PN->hasOneUse()) 764 return std::nullopt; 765 766 for (Value *IncValPhi : PN->incoming_values()) { 767 auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi); 768 if (!Reinterpret || 769 Reinterpret->getIntrinsicID() != 770 Intrinsic::aarch64_sve_convert_to_svbool || 771 RequiredType != Reinterpret->getArgOperand(0)->getType()) 772 return std::nullopt; 773 } 774 775 // Create the new Phi 776 IC.Builder.SetInsertPoint(PN); 777 PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues()); 778 Worklist.push_back(PN); 779 780 for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) { 781 auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I)); 782 NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I)); 783 Worklist.push_back(Reinterpret); 784 } 785 786 // Cleanup Phi Node and reinterprets 787 return IC.replaceInstUsesWith(II, NPN); 788 } 789 790 // (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _)))) 791 // => (binop (pred) (from_svbool _) (from_svbool _)) 792 // 793 // The above transformation eliminates a `to_svbool` in the predicate 794 // operand of bitwise operation `binop` by narrowing the vector width of 795 // the operation. For example, it would convert a `<vscale x 16 x i1> 796 // and` into a `<vscale x 4 x i1> and`. This is profitable because 797 // to_svbool must zero the new lanes during widening, whereas 798 // from_svbool is free. 799 static std::optional<Instruction *> 800 tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) { 801 auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0)); 802 if (!BinOp) 803 return std::nullopt; 804 805 auto IntrinsicID = BinOp->getIntrinsicID(); 806 switch (IntrinsicID) { 807 case Intrinsic::aarch64_sve_and_z: 808 case Intrinsic::aarch64_sve_bic_z: 809 case Intrinsic::aarch64_sve_eor_z: 810 case Intrinsic::aarch64_sve_nand_z: 811 case Intrinsic::aarch64_sve_nor_z: 812 case Intrinsic::aarch64_sve_orn_z: 813 case Intrinsic::aarch64_sve_orr_z: 814 break; 815 default: 816 return std::nullopt; 817 } 818 819 auto BinOpPred = BinOp->getOperand(0); 820 auto BinOpOp1 = BinOp->getOperand(1); 821 auto BinOpOp2 = BinOp->getOperand(2); 822 823 auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred); 824 if (!PredIntr || 825 PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool) 826 return std::nullopt; 827 828 auto PredOp = PredIntr->getOperand(0); 829 auto PredOpTy = cast<VectorType>(PredOp->getType()); 830 if (PredOpTy != II.getType()) 831 return std::nullopt; 832 833 SmallVector<Value *> NarrowedBinOpArgs = {PredOp}; 834 auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic( 835 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1}); 836 NarrowedBinOpArgs.push_back(NarrowBinOpOp1); 837 if (BinOpOp1 == BinOpOp2) 838 NarrowedBinOpArgs.push_back(NarrowBinOpOp1); 839 else 840 NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic( 841 Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2})); 842 843 auto NarrowedBinOp = 844 IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs); 845 return IC.replaceInstUsesWith(II, NarrowedBinOp); 846 } 847 848 static std::optional<Instruction *> 849 instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) { 850 // If the reinterpret instruction operand is a PHI Node 851 if (isa<PHINode>(II.getArgOperand(0))) 852 return processPhiNode(IC, II); 853 854 if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II)) 855 return BinOpCombine; 856 857 // Ignore converts to/from svcount_t. 
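  // (svcount_t is represented as a target extension type rather than a
  // scalable vector type, so the lane-count reasoning below does not apply.)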
858 if (isa<TargetExtType>(II.getArgOperand(0)->getType()) || 859 isa<TargetExtType>(II.getType())) 860 return std::nullopt; 861 862 SmallVector<Instruction *, 32> CandidatesForRemoval; 863 Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr; 864 865 const auto *IVTy = cast<VectorType>(II.getType()); 866 867 // Walk the chain of conversions. 868 while (Cursor) { 869 // If the type of the cursor has fewer lanes than the final result, zeroing 870 // must take place, which breaks the equivalence chain. 871 const auto *CursorVTy = cast<VectorType>(Cursor->getType()); 872 if (CursorVTy->getElementCount().getKnownMinValue() < 873 IVTy->getElementCount().getKnownMinValue()) 874 break; 875 876 // If the cursor has the same type as I, it is a viable replacement. 877 if (Cursor->getType() == IVTy) 878 EarliestReplacement = Cursor; 879 880 auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor); 881 882 // If this is not an SVE conversion intrinsic, this is the end of the chain. 883 if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() == 884 Intrinsic::aarch64_sve_convert_to_svbool || 885 IntrinsicCursor->getIntrinsicID() == 886 Intrinsic::aarch64_sve_convert_from_svbool)) 887 break; 888 889 CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor); 890 Cursor = IntrinsicCursor->getOperand(0); 891 } 892 893 // If no viable replacement in the conversion chain was found, there is 894 // nothing to do. 895 if (!EarliestReplacement) 896 return std::nullopt; 897 898 return IC.replaceInstUsesWith(II, EarliestReplacement); 899 } 900 901 static bool isAllActivePredicate(Value *Pred) { 902 // Look through convert.from.svbool(convert.to.svbool(...) chain. 903 Value *UncastedPred; 904 if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>( 905 m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>( 906 m_Value(UncastedPred))))) 907 // If the predicate has the same or less lanes than the uncasted 908 // predicate then we know the casting has no effect. 909 if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <= 910 cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements()) 911 Pred = UncastedPred; 912 913 return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>( 914 m_ConstantInt<AArch64SVEPredPattern::all>())); 915 } 916 917 static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC, 918 IntrinsicInst &II) { 919 // svsel(ptrue, x, y) => x 920 auto *OpPredicate = II.getOperand(0); 921 if (isAllActivePredicate(OpPredicate)) 922 return IC.replaceInstUsesWith(II, II.getOperand(1)); 923 924 auto Select = 925 IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2)); 926 return IC.replaceInstUsesWith(II, Select); 927 } 928 929 static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC, 930 IntrinsicInst &II) { 931 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1)); 932 if (!Pg) 933 return std::nullopt; 934 935 if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue) 936 return std::nullopt; 937 938 const auto PTruePattern = 939 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue(); 940 if (PTruePattern != AArch64SVEPredPattern::vl1) 941 return std::nullopt; 942 943 // The intrinsic is inserting into lane zero so use an insert instead. 
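  // e.g. (illustrative IR) sve.dup(%passthru, ptrue(vl1), %x)
  //        --> insertelement %passthru, %x, i64 0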
944 auto *IdxTy = Type::getInt64Ty(II.getContext()); 945 auto *Insert = InsertElementInst::Create( 946 II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0)); 947 Insert->insertBefore(&II); 948 Insert->takeName(&II); 949 950 return IC.replaceInstUsesWith(II, Insert); 951 } 952 953 static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC, 954 IntrinsicInst &II) { 955 // Replace DupX with a regular IR splat. 956 auto *RetTy = cast<ScalableVectorType>(II.getType()); 957 Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(), 958 II.getArgOperand(0)); 959 Splat->takeName(&II); 960 return IC.replaceInstUsesWith(II, Splat); 961 } 962 963 static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC, 964 IntrinsicInst &II) { 965 LLVMContext &Ctx = II.getContext(); 966 967 // Check that the predicate is all active 968 auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0)); 969 if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue) 970 return std::nullopt; 971 972 const auto PTruePattern = 973 cast<ConstantInt>(Pg->getOperand(0))->getZExtValue(); 974 if (PTruePattern != AArch64SVEPredPattern::all) 975 return std::nullopt; 976 977 // Check that we have a compare of zero.. 978 auto *SplatValue = 979 dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2))); 980 if (!SplatValue || !SplatValue->isZero()) 981 return std::nullopt; 982 983 // ..against a dupq 984 auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1)); 985 if (!DupQLane || 986 DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane) 987 return std::nullopt; 988 989 // Where the dupq is a lane 0 replicate of a vector insert 990 if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero()) 991 return std::nullopt; 992 993 auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0)); 994 if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert) 995 return std::nullopt; 996 997 // Where the vector insert is a fixed constant vector insert into undef at 998 // index zero 999 if (!isa<UndefValue>(VecIns->getArgOperand(0))) 1000 return std::nullopt; 1001 1002 if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero()) 1003 return std::nullopt; 1004 1005 auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1)); 1006 if (!ConstVec) 1007 return std::nullopt; 1008 1009 auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType()); 1010 auto *OutTy = dyn_cast<ScalableVectorType>(II.getType()); 1011 if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements()) 1012 return std::nullopt; 1013 1014 unsigned NumElts = VecTy->getNumElements(); 1015 unsigned PredicateBits = 0; 1016 1017 // Expand intrinsic operands to a 16-bit byte level predicate 1018 for (unsigned I = 0; I < NumElts; ++I) { 1019 auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I)); 1020 if (!Arg) 1021 return std::nullopt; 1022 if (!Arg->isZero()) 1023 PredicateBits |= 1 << (I * (16 / NumElts)); 1024 } 1025 1026 // If all bits are zero bail early with an empty predicate 1027 if (PredicateBits == 0) { 1028 auto *PFalse = Constant::getNullValue(II.getType()); 1029 PFalse->takeName(&II); 1030 return IC.replaceInstUsesWith(II, PFalse); 1031 } 1032 1033 // Calculate largest predicate type used (where byte predicate is largest) 1034 unsigned Mask = 8; 1035 for (unsigned I = 0; I < 16; ++I) 1036 if ((PredicateBits & (1 << I)) != 0) 1037 Mask |= (I % 8); 1038 1039 unsigned PredSize = Mask & -Mask; 1040 auto *PredType = ScalableVectorType::get( 1041 
Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8)); 1042 1043 // Ensure all relevant bits are set 1044 for (unsigned I = 0; I < 16; I += PredSize) 1045 if ((PredicateBits & (1 << I)) == 0) 1046 return std::nullopt; 1047 1048 auto *PTruePat = 1049 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all); 1050 auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, 1051 {PredType}, {PTruePat}); 1052 auto *ConvertToSVBool = IC.Builder.CreateIntrinsic( 1053 Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue}); 1054 auto *ConvertFromSVBool = 1055 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool, 1056 {II.getType()}, {ConvertToSVBool}); 1057 1058 ConvertFromSVBool->takeName(&II); 1059 return IC.replaceInstUsesWith(II, ConvertFromSVBool); 1060 } 1061 1062 static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC, 1063 IntrinsicInst &II) { 1064 Value *Pg = II.getArgOperand(0); 1065 Value *Vec = II.getArgOperand(1); 1066 auto IntrinsicID = II.getIntrinsicID(); 1067 bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta; 1068 1069 // lastX(splat(X)) --> X 1070 if (auto *SplatVal = getSplatValue(Vec)) 1071 return IC.replaceInstUsesWith(II, SplatVal); 1072 1073 // If x and/or y is a splat value then: 1074 // lastX (binop (x, y)) --> binop(lastX(x), lastX(y)) 1075 Value *LHS, *RHS; 1076 if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) { 1077 if (isSplatValue(LHS) || isSplatValue(RHS)) { 1078 auto *OldBinOp = cast<BinaryOperator>(Vec); 1079 auto OpC = OldBinOp->getOpcode(); 1080 auto *NewLHS = 1081 IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS}); 1082 auto *NewRHS = 1083 IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS}); 1084 auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags( 1085 OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II); 1086 return IC.replaceInstUsesWith(II, NewBinOp); 1087 } 1088 } 1089 1090 auto *C = dyn_cast<Constant>(Pg); 1091 if (IsAfter && C && C->isNullValue()) { 1092 // The intrinsic is extracting lane 0 so use an extract instead. 1093 auto *IdxTy = Type::getInt64Ty(II.getContext()); 1094 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0)); 1095 Extract->insertBefore(&II); 1096 Extract->takeName(&II); 1097 return IC.replaceInstUsesWith(II, Extract); 1098 } 1099 1100 auto *IntrPG = dyn_cast<IntrinsicInst>(Pg); 1101 if (!IntrPG) 1102 return std::nullopt; 1103 1104 if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue) 1105 return std::nullopt; 1106 1107 const auto PTruePattern = 1108 cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue(); 1109 1110 // Can the intrinsic's predicate be converted to a known constant index? 1111 unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern); 1112 if (!MinNumElts) 1113 return std::nullopt; 1114 1115 unsigned Idx = MinNumElts - 1; 1116 // Increment the index if extracting the element after the last active 1117 // predicate element. 1118 if (IsAfter) 1119 ++Idx; 1120 1121 // Ignore extracts whose index is larger than the known minimum vector 1122 // length. NOTE: This is an artificial constraint where we prefer to 1123 // maintain what the user asked for until an alternative is proven faster. 1124 auto *PgVTy = cast<ScalableVectorType>(Pg->getType()); 1125 if (Idx >= PgVTy->getMinNumElements()) 1126 return std::nullopt; 1127 1128 // The intrinsic is extracting a fixed lane so use an extract instead. 
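  // e.g. (illustrative IR) sve.lastb(ptrue(vl4), %v)
  //        --> extractelement %v, i64 3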
1129 auto *IdxTy = Type::getInt64Ty(II.getContext()); 1130 auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx)); 1131 Extract->insertBefore(&II); 1132 Extract->takeName(&II); 1133 return IC.replaceInstUsesWith(II, Extract); 1134 } 1135 1136 static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC, 1137 IntrinsicInst &II) { 1138 // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar 1139 // integer variant across a variety of micro-architectures. Replace scalar 1140 // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple 1141 // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more 1142 // depending on the micro-architecture, but has been observed as generally 1143 // being faster, particularly when the CLAST[AB] op is a loop-carried 1144 // dependency. 1145 Value *Pg = II.getArgOperand(0); 1146 Value *Fallback = II.getArgOperand(1); 1147 Value *Vec = II.getArgOperand(2); 1148 Type *Ty = II.getType(); 1149 1150 if (!Ty->isIntegerTy()) 1151 return std::nullopt; 1152 1153 Type *FPTy; 1154 switch (cast<IntegerType>(Ty)->getBitWidth()) { 1155 default: 1156 return std::nullopt; 1157 case 16: 1158 FPTy = IC.Builder.getHalfTy(); 1159 break; 1160 case 32: 1161 FPTy = IC.Builder.getFloatTy(); 1162 break; 1163 case 64: 1164 FPTy = IC.Builder.getDoubleTy(); 1165 break; 1166 } 1167 1168 Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy); 1169 auto *FPVTy = VectorType::get( 1170 FPTy, cast<VectorType>(Vec->getType())->getElementCount()); 1171 Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy); 1172 auto *FPII = IC.Builder.CreateIntrinsic( 1173 II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec}); 1174 Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType()); 1175 return IC.replaceInstUsesWith(II, FPIItoInt); 1176 } 1177 1178 static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC, 1179 IntrinsicInst &II) { 1180 LLVMContext &Ctx = II.getContext(); 1181 // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr 1182 // can work with RDFFR_PP for ptest elimination. 1183 auto *AllPat = 1184 ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all); 1185 auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue, 1186 {II.getType()}, {AllPat}); 1187 auto *RDFFR = 1188 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue}); 1189 RDFFR->takeName(&II); 1190 return IC.replaceInstUsesWith(II, RDFFR); 1191 } 1192 1193 static std::optional<Instruction *> 1194 instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) { 1195 const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue(); 1196 1197 if (Pattern == AArch64SVEPredPattern::all) { 1198 Constant *StepVal = ConstantInt::get(II.getType(), NumElts); 1199 auto *VScale = IC.Builder.CreateVScale(StepVal); 1200 VScale->takeName(&II); 1201 return IC.replaceInstUsesWith(II, VScale); 1202 } 1203 1204 unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern); 1205 1206 return MinNumElts && NumElts >= MinNumElts 1207 ? std::optional<Instruction *>(IC.replaceInstUsesWith( 1208 II, ConstantInt::get(II.getType(), MinNumElts))) 1209 : std::nullopt; 1210 } 1211 1212 static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC, 1213 IntrinsicInst &II) { 1214 Value *PgVal = II.getArgOperand(0); 1215 Value *OpVal = II.getArgOperand(1); 1216 1217 // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X). 
1218 // Later optimizations prefer this form. 1219 if (PgVal == OpVal && 1220 (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first || 1221 II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) { 1222 Value *Ops[] = {PgVal, OpVal}; 1223 Type *Tys[] = {PgVal->getType()}; 1224 1225 auto *PTest = 1226 IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops); 1227 PTest->takeName(&II); 1228 1229 return IC.replaceInstUsesWith(II, PTest); 1230 } 1231 1232 IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal); 1233 IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal); 1234 1235 if (!Pg || !Op) 1236 return std::nullopt; 1237 1238 Intrinsic::ID OpIID = Op->getIntrinsicID(); 1239 1240 if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool && 1241 OpIID == Intrinsic::aarch64_sve_convert_to_svbool && 1242 Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) { 1243 Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)}; 1244 Type *Tys[] = {Pg->getArgOperand(0)->getType()}; 1245 1246 auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops); 1247 1248 PTest->takeName(&II); 1249 return IC.replaceInstUsesWith(II, PTest); 1250 } 1251 1252 // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X)). 1253 // Later optimizations may rewrite sequence to use the flag-setting variant 1254 // of instruction X to remove PTEST. 1255 if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) && 1256 ((OpIID == Intrinsic::aarch64_sve_brka_z) || 1257 (OpIID == Intrinsic::aarch64_sve_brkb_z) || 1258 (OpIID == Intrinsic::aarch64_sve_brkpa_z) || 1259 (OpIID == Intrinsic::aarch64_sve_brkpb_z) || 1260 (OpIID == Intrinsic::aarch64_sve_rdffr_z) || 1261 (OpIID == Intrinsic::aarch64_sve_and_z) || 1262 (OpIID == Intrinsic::aarch64_sve_bic_z) || 1263 (OpIID == Intrinsic::aarch64_sve_eor_z) || 1264 (OpIID == Intrinsic::aarch64_sve_nand_z) || 1265 (OpIID == Intrinsic::aarch64_sve_nor_z) || 1266 (OpIID == Intrinsic::aarch64_sve_orn_z) || 1267 (OpIID == Intrinsic::aarch64_sve_orr_z))) { 1268 Value *Ops[] = {Pg->getArgOperand(0), Pg}; 1269 Type *Tys[] = {Pg->getType()}; 1270 1271 auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops); 1272 PTest->takeName(&II); 1273 1274 return IC.replaceInstUsesWith(II, PTest); 1275 } 1276 1277 return std::nullopt; 1278 } 1279 1280 template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc> 1281 static std::optional<Instruction *> 1282 instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II, 1283 bool MergeIntoAddendOp) { 1284 Value *P = II.getOperand(0); 1285 Value *MulOp0, *MulOp1, *AddendOp, *Mul; 1286 if (MergeIntoAddendOp) { 1287 AddendOp = II.getOperand(1); 1288 Mul = II.getOperand(2); 1289 } else { 1290 AddendOp = II.getOperand(2); 1291 Mul = II.getOperand(1); 1292 } 1293 1294 if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0), 1295 m_Value(MulOp1)))) 1296 return std::nullopt; 1297 1298 if (!Mul->hasOneUse()) 1299 return std::nullopt; 1300 1301 Instruction *FMFSource = nullptr; 1302 if (II.getType()->isFPOrFPVectorTy()) { 1303 llvm::FastMathFlags FAddFlags = II.getFastMathFlags(); 1304 // Stop the combine when the flags on the inputs differ in case dropping 1305 // flags would lead to us missing out on more beneficial optimizations. 
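    // The allowContract() check below is also required: without the contract
    // flag it would not be sound to fuse the multiply and the add/sub into a
    // single fma-style operation.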
1306 if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags()) 1307 return std::nullopt; 1308 if (!FAddFlags.allowContract()) 1309 return std::nullopt; 1310 FMFSource = &II; 1311 } 1312 1313 CallInst *Res; 1314 if (MergeIntoAddendOp) 1315 Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()}, 1316 {P, AddendOp, MulOp0, MulOp1}, FMFSource); 1317 else 1318 Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()}, 1319 {P, MulOp0, MulOp1, AddendOp}, FMFSource); 1320 1321 return IC.replaceInstUsesWith(II, Res); 1322 } 1323 1324 static std::optional<Instruction *> 1325 instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) { 1326 Value *Pred = II.getOperand(0); 1327 Value *PtrOp = II.getOperand(1); 1328 Type *VecTy = II.getType(); 1329 1330 if (isAllActivePredicate(Pred)) { 1331 LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp); 1332 Load->copyMetadata(II); 1333 return IC.replaceInstUsesWith(II, Load); 1334 } 1335 1336 CallInst *MaskedLoad = 1337 IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL), 1338 Pred, ConstantAggregateZero::get(VecTy)); 1339 MaskedLoad->copyMetadata(II); 1340 return IC.replaceInstUsesWith(II, MaskedLoad); 1341 } 1342 1343 static std::optional<Instruction *> 1344 instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) { 1345 Value *VecOp = II.getOperand(0); 1346 Value *Pred = II.getOperand(1); 1347 Value *PtrOp = II.getOperand(2); 1348 1349 if (isAllActivePredicate(Pred)) { 1350 StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp); 1351 Store->copyMetadata(II); 1352 return IC.eraseInstFromFunction(II); 1353 } 1354 1355 CallInst *MaskedStore = IC.Builder.CreateMaskedStore( 1356 VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred); 1357 MaskedStore->copyMetadata(II); 1358 return IC.eraseInstFromFunction(II); 1359 } 1360 1361 static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) { 1362 switch (Intrinsic) { 1363 case Intrinsic::aarch64_sve_fmul_u: 1364 return Instruction::BinaryOps::FMul; 1365 case Intrinsic::aarch64_sve_fadd_u: 1366 return Instruction::BinaryOps::FAdd; 1367 case Intrinsic::aarch64_sve_fsub_u: 1368 return Instruction::BinaryOps::FSub; 1369 default: 1370 return Instruction::BinaryOpsEnd; 1371 } 1372 } 1373 1374 static std::optional<Instruction *> 1375 instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) { 1376 // Bail due to missing support for ISD::STRICT_ scalable vector operations. 1377 if (II.isStrictFP()) 1378 return std::nullopt; 1379 1380 auto *OpPredicate = II.getOperand(0); 1381 auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID()); 1382 if (BinOpCode == Instruction::BinaryOpsEnd || 1383 !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>( 1384 m_ConstantInt<AArch64SVEPredPattern::all>()))) 1385 return std::nullopt; 1386 IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder); 1387 IC.Builder.setFastMathFlags(II.getFastMathFlags()); 1388 auto BinOp = 1389 IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2)); 1390 return IC.replaceInstUsesWith(II, BinOp); 1391 } 1392 1393 // Canonicalise operations that take an all active predicate (e.g. sve.add -> 1394 // sve.add_u). 
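// For example (illustrative IR):
//   @llvm.aarch64.sve.add.nxv4i32(ptrue(all), %a, %b)
//     --> @llvm.aarch64.sve.add.u.nxv4i32(ptrue(all), %a, %b)
// The _u ("inactive lanes are undefined") form gives later combines more
// freedom.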
1395 static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II, 1396 Intrinsic::ID IID) { 1397 auto *OpPredicate = II.getOperand(0); 1398 if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>( 1399 m_ConstantInt<AArch64SVEPredPattern::all>()))) 1400 return std::nullopt; 1401 1402 auto *Mod = II.getModule(); 1403 auto *NewDecl = Intrinsic::getDeclaration(Mod, IID, {II.getType()}); 1404 II.setCalledFunction(NewDecl); 1405 1406 return &II; 1407 } 1408 1409 static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC, 1410 IntrinsicInst &II) { 1411 if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_add_u)) 1412 return II_U; 1413 if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul, 1414 Intrinsic::aarch64_sve_mla>( 1415 IC, II, true)) 1416 return MLA; 1417 if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul, 1418 Intrinsic::aarch64_sve_mad>( 1419 IC, II, false)) 1420 return MAD; 1421 return std::nullopt; 1422 } 1423 1424 static std::optional<Instruction *> 1425 instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) { 1426 if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fadd_u)) 1427 return II_U; 1428 if (auto FMLA = 1429 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1430 Intrinsic::aarch64_sve_fmla>(IC, II, 1431 true)) 1432 return FMLA; 1433 if (auto FMAD = 1434 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1435 Intrinsic::aarch64_sve_fmad>(IC, II, 1436 false)) 1437 return FMAD; 1438 if (auto FMLA = 1439 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u, 1440 Intrinsic::aarch64_sve_fmla>(IC, II, 1441 true)) 1442 return FMLA; 1443 return std::nullopt; 1444 } 1445 1446 static std::optional<Instruction *> 1447 instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) { 1448 if (auto FMLA = 1449 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1450 Intrinsic::aarch64_sve_fmla>(IC, II, 1451 true)) 1452 return FMLA; 1453 if (auto FMAD = 1454 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1455 Intrinsic::aarch64_sve_fmad>(IC, II, 1456 false)) 1457 return FMAD; 1458 if (auto FMLA_U = 1459 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u, 1460 Intrinsic::aarch64_sve_fmla_u>( 1461 IC, II, true)) 1462 return FMLA_U; 1463 return instCombineSVEVectorBinOp(IC, II); 1464 } 1465 1466 static std::optional<Instruction *> 1467 instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) { 1468 if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fsub_u)) 1469 return II_U; 1470 if (auto FMLS = 1471 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1472 Intrinsic::aarch64_sve_fmls>(IC, II, 1473 true)) 1474 return FMLS; 1475 if (auto FMSB = 1476 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1477 Intrinsic::aarch64_sve_fnmsb>( 1478 IC, II, false)) 1479 return FMSB; 1480 if (auto FMLS = 1481 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u, 1482 Intrinsic::aarch64_sve_fmls>(IC, II, 1483 true)) 1484 return FMLS; 1485 return std::nullopt; 1486 } 1487 1488 static std::optional<Instruction *> 1489 instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) { 1490 if (auto FMLS = 1491 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1492 Intrinsic::aarch64_sve_fmls>(IC, II, 1493 true)) 1494 return FMLS; 1495 if (auto FMSB = 1496 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul, 1497 
Intrinsic::aarch64_sve_fnmsb>( 1498 IC, II, false)) 1499 return FMSB; 1500 if (auto FMLS_U = 1501 instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u, 1502 Intrinsic::aarch64_sve_fmls_u>( 1503 IC, II, true)) 1504 return FMLS_U; 1505 return instCombineSVEVectorBinOp(IC, II); 1506 } 1507 1508 static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC, 1509 IntrinsicInst &II) { 1510 if (auto II_U = instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sub_u)) 1511 return II_U; 1512 if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul, 1513 Intrinsic::aarch64_sve_mls>( 1514 IC, II, true)) 1515 return MLS; 1516 return std::nullopt; 1517 } 1518 1519 static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC, 1520 IntrinsicInst &II, 1521 Intrinsic::ID IID) { 1522 auto *OpPredicate = II.getOperand(0); 1523 auto *OpMultiplicand = II.getOperand(1); 1524 auto *OpMultiplier = II.getOperand(2); 1525 1526 // Canonicalise a non _u intrinsic only. 1527 if (II.getIntrinsicID() != IID) 1528 if (auto II_U = instCombineSVEAllActive(II, IID)) 1529 return II_U; 1530 1531 // Return true if a given instruction is a unit splat value, false otherwise. 1532 auto IsUnitSplat = [](auto *I) { 1533 auto *SplatValue = getSplatValue(I); 1534 if (!SplatValue) 1535 return false; 1536 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One()); 1537 }; 1538 1539 // Return true if a given instruction is an aarch64_sve_dup intrinsic call 1540 // with a unit splat value, false otherwise. 1541 auto IsUnitDup = [](auto *I) { 1542 auto *IntrI = dyn_cast<IntrinsicInst>(I); 1543 if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup) 1544 return false; 1545 1546 auto *SplatValue = IntrI->getOperand(2); 1547 return match(SplatValue, m_FPOne()) || match(SplatValue, m_One()); 1548 }; 1549 1550 if (IsUnitSplat(OpMultiplier)) { 1551 // [f]mul pg %n, (dupx 1) => %n 1552 OpMultiplicand->takeName(&II); 1553 return IC.replaceInstUsesWith(II, OpMultiplicand); 1554 } else if (IsUnitDup(OpMultiplier)) { 1555 // [f]mul pg %n, (dup pg 1) => %n 1556 auto *DupInst = cast<IntrinsicInst>(OpMultiplier); 1557 auto *DupPg = DupInst->getOperand(1); 1558 // TODO: this is naive. The optimization is still valid if DupPg 1559 // 'encompasses' OpPredicate, not only if they're the same predicate. 
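    // (For example, a DupPg of ptrue(all) would encompass any OpPredicate,
    // but only the exact-match case is handled for now.)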
1560 if (OpPredicate == DupPg) { 1561 OpMultiplicand->takeName(&II); 1562 return IC.replaceInstUsesWith(II, OpMultiplicand); 1563 } 1564 } 1565 1566 return instCombineSVEVectorBinOp(IC, II); 1567 } 1568 1569 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC, 1570 IntrinsicInst &II) { 1571 Value *UnpackArg = II.getArgOperand(0); 1572 auto *RetTy = cast<ScalableVectorType>(II.getType()); 1573 bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi || 1574 II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo; 1575 1576 // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X)) 1577 // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X)) 1578 if (auto *ScalarArg = getSplatValue(UnpackArg)) { 1579 ScalarArg = 1580 IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned); 1581 Value *NewVal = 1582 IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg); 1583 NewVal->takeName(&II); 1584 return IC.replaceInstUsesWith(II, NewVal); 1585 } 1586 1587 return std::nullopt; 1588 } 1589 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC, 1590 IntrinsicInst &II) { 1591 auto *OpVal = II.getOperand(0); 1592 auto *OpIndices = II.getOperand(1); 1593 VectorType *VTy = cast<VectorType>(II.getType()); 1594 1595 // Check whether OpIndices is a constant splat value < minimal element count 1596 // of result. 1597 auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices)); 1598 if (!SplatValue || 1599 SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue())) 1600 return std::nullopt; 1601 1602 // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to 1603 // splat_vector(extractelement(OpVal, SplatValue)) for further optimization. 1604 auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue); 1605 auto *VectorSplat = 1606 IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract); 1607 1608 VectorSplat->takeName(&II); 1609 return IC.replaceInstUsesWith(II, VectorSplat); 1610 } 1611 1612 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC, 1613 IntrinsicInst &II) { 1614 // zip1(uzp1(A, B), uzp2(A, B)) --> A 1615 // zip2(uzp1(A, B), uzp2(A, B)) --> B 1616 Value *A, *B; 1617 if (match(II.getArgOperand(0), 1618 m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) && 1619 match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>( 1620 m_Specific(A), m_Specific(B)))) 1621 return IC.replaceInstUsesWith( 1622 II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B)); 1623 1624 return std::nullopt; 1625 } 1626 1627 static std::optional<Instruction *> 1628 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) { 1629 Value *Mask = II.getOperand(0); 1630 Value *BasePtr = II.getOperand(1); 1631 Value *Index = II.getOperand(2); 1632 Type *Ty = II.getType(); 1633 Value *PassThru = ConstantAggregateZero::get(Ty); 1634 1635 // Contiguous gather => masked load. 
1636 // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1)) 1637 // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer) 1638 Value *IndexBase; 1639 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>( 1640 m_Value(IndexBase), m_SpecificInt(1)))) { 1641 Align Alignment = 1642 BasePtr->getPointerAlignment(II.getModule()->getDataLayout()); 1643 1644 Type *VecPtrTy = PointerType::getUnqual(Ty); 1645 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(), 1646 BasePtr, IndexBase); 1647 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy); 1648 CallInst *MaskedLoad = 1649 IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru); 1650 MaskedLoad->takeName(&II); 1651 return IC.replaceInstUsesWith(II, MaskedLoad); 1652 } 1653 1654 return std::nullopt; 1655 } 1656 1657 static std::optional<Instruction *> 1658 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) { 1659 Value *Val = II.getOperand(0); 1660 Value *Mask = II.getOperand(1); 1661 Value *BasePtr = II.getOperand(2); 1662 Value *Index = II.getOperand(3); 1663 Type *Ty = Val->getType(); 1664 1665 // Contiguous scatter => masked store. 1666 // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1)) 1667 // => (masked.store Value (gep BasePtr IndexBase) Align Mask) 1668 Value *IndexBase; 1669 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>( 1670 m_Value(IndexBase), m_SpecificInt(1)))) { 1671 Align Alignment = 1672 BasePtr->getPointerAlignment(II.getModule()->getDataLayout()); 1673 1674 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(), 1675 BasePtr, IndexBase); 1676 Type *VecPtrTy = PointerType::getUnqual(Ty); 1677 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy); 1678 1679 (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask); 1680 1681 return IC.eraseInstFromFunction(II); 1682 } 1683 1684 return std::nullopt; 1685 } 1686 1687 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC, 1688 IntrinsicInst &II) { 1689 Type *Int32Ty = IC.Builder.getInt32Ty(); 1690 Value *Pred = II.getOperand(0); 1691 Value *Vec = II.getOperand(1); 1692 Value *DivVec = II.getOperand(2); 1693 1694 Value *SplatValue = getSplatValue(DivVec); 1695 ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue); 1696 if (!SplatConstantInt) 1697 return std::nullopt; 1698 APInt Divisor = SplatConstantInt->getValue(); 1699 1700 if (Divisor.isPowerOf2()) { 1701 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2()); 1702 auto ASRD = IC.Builder.CreateIntrinsic( 1703 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2}); 1704 return IC.replaceInstUsesWith(II, ASRD); 1705 } 1706 if (Divisor.isNegatedPowerOf2()) { 1707 Divisor.negate(); 1708 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2()); 1709 auto ASRD = IC.Builder.CreateIntrinsic( 1710 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2}); 1711 auto NEG = IC.Builder.CreateIntrinsic( 1712 Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD}); 1713 return IC.replaceInstUsesWith(II, NEG); 1714 } 1715 1716 return std::nullopt; 1717 } 1718 1719 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) { 1720 size_t VecSize = Vec.size(); 1721 if (VecSize == 1) 1722 return true; 1723 if (!isPowerOf2_64(VecSize)) 1724 return false; 1725 size_t HalfVecSize = VecSize / 2; 1726 1727 for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize; 1728 RHS != Vec.end(); LHS++, RHS++) { 1729 if (*LHS != nullptr && 
*RHS != nullptr) { 1730 if (*LHS == *RHS) 1731 continue; 1732 else 1733 return false; 1734 } 1735 if (!AllowPoison) 1736 return false; 1737 if (*LHS == nullptr && *RHS != nullptr) 1738 *LHS = *RHS; 1739 } 1740 1741 Vec.resize(HalfVecSize); 1742 SimplifyValuePattern(Vec, AllowPoison); 1743 return true; 1744 } 1745 1746 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B) 1747 // to dupqlane(f64(C)) where C is A concatenated with B 1748 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC, 1749 IntrinsicInst &II) { 1750 Value *CurrentInsertElt = nullptr, *Default = nullptr; 1751 if (!match(II.getOperand(0), 1752 m_Intrinsic<Intrinsic::vector_insert>( 1753 m_Value(Default), m_Value(CurrentInsertElt), m_Value())) || 1754 !isa<FixedVectorType>(CurrentInsertElt->getType())) 1755 return std::nullopt; 1756 auto IIScalableTy = cast<ScalableVectorType>(II.getType()); 1757 1758 // Insert the scalars into a container ordered by InsertElement index 1759 SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr); 1760 while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) { 1761 auto Idx = cast<ConstantInt>(InsertElt->getOperand(2)); 1762 Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1); 1763 CurrentInsertElt = InsertElt->getOperand(0); 1764 } 1765 1766 bool AllowPoison = 1767 isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default); 1768 if (!SimplifyValuePattern(Elts, AllowPoison)) 1769 return std::nullopt; 1770 1771 // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b) 1772 Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType()); 1773 for (size_t I = 0; I < Elts.size(); I++) { 1774 if (Elts[I] == nullptr) 1775 continue; 1776 InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I], 1777 IC.Builder.getInt64(I)); 1778 } 1779 if (InsertEltChain == nullptr) 1780 return std::nullopt; 1781 1782 // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64 1783 // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector 1784 // be bitcast to a type wide enough to fit the sequence, be splatted, and then 1785 // be narrowed back to the original type. 
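  // Worked example (illustrative): for a <vscale x 8 x half> dupqlane whose
  // insert chain simplified to two elements (a, b), PatternWidth is 16 * 2 =
  // 32 bits and PatternElementCount is 16 * 8 / 32 = 4, so the (a, b) pair is
  // treated as a single i32 lane, splatted as <vscale x 4 x i32>, and bitcast
  // back to the original type.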
1786 unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size(); 1787 unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() * 1788 IIScalableTy->getMinNumElements() / 1789 PatternWidth; 1790 1791 IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth); 1792 auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount); 1793 auto *WideShuffleMaskTy = 1794 ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount); 1795 1796 auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0)); 1797 auto InsertSubvector = IC.Builder.CreateInsertVector( 1798 II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx); 1799 auto WideBitcast = 1800 IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy); 1801 auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy); 1802 auto WideShuffle = IC.Builder.CreateShuffleVector( 1803 WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask); 1804 auto NarrowBitcast = 1805 IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType()); 1806 1807 return IC.replaceInstUsesWith(II, NarrowBitcast); 1808 } 1809 1810 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC, 1811 IntrinsicInst &II) { 1812 Value *A = II.getArgOperand(0); 1813 Value *B = II.getArgOperand(1); 1814 if (A == B) 1815 return IC.replaceInstUsesWith(II, A); 1816 1817 return std::nullopt; 1818 } 1819 1820 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC, 1821 IntrinsicInst &II) { 1822 Value *Pred = II.getOperand(0); 1823 Value *Vec = II.getOperand(1); 1824 Value *Shift = II.getOperand(2); 1825 1826 // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic. 1827 Value *AbsPred, *MergedValue; 1828 if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>( 1829 m_Value(MergedValue), m_Value(AbsPred), m_Value())) && 1830 !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>( 1831 m_Value(MergedValue), m_Value(AbsPred), m_Value()))) 1832 1833 return std::nullopt; 1834 1835 // Transform is valid if any of the following are true: 1836 // * The ABS merge value is an undef or non-negative 1837 // * The ABS predicate is all active 1838 // * The ABS predicate and the SRSHL predicates are the same 1839 if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) && 1840 AbsPred != Pred && !isAllActivePredicate(AbsPred)) 1841 return std::nullopt; 1842 1843 // Only valid when the shift amount is non-negative, otherwise the rounding 1844 // behaviour of SRSHL cannot be ignored. 
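  // Illustrative sketch (pseudo-IR, operand order as matched above):
  //   %v = sve.abs(%m, %pg, %x)
  //   %r = sve.srshl(%pg, %v, splat(2))
  // is rewritten below to %r = sve.lsl(%pg, %v, splat(2)).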
1845 if (!match(Shift, m_NonNegative())) 1846 return std::nullopt; 1847 1848 auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl, 1849 {II.getType()}, {Pred, Vec, Shift}); 1850 1851 return IC.replaceInstUsesWith(II, LSL); 1852 } 1853 1854 std::optional<Instruction *> 1855 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC, 1856 IntrinsicInst &II) const { 1857 Intrinsic::ID IID = II.getIntrinsicID(); 1858 switch (IID) { 1859 default: 1860 break; 1861 case Intrinsic::aarch64_neon_fmaxnm: 1862 case Intrinsic::aarch64_neon_fminnm: 1863 return instCombineMaxMinNM(IC, II); 1864 case Intrinsic::aarch64_sve_convert_from_svbool: 1865 return instCombineConvertFromSVBool(IC, II); 1866 case Intrinsic::aarch64_sve_dup: 1867 return instCombineSVEDup(IC, II); 1868 case Intrinsic::aarch64_sve_dup_x: 1869 return instCombineSVEDupX(IC, II); 1870 case Intrinsic::aarch64_sve_cmpne: 1871 case Intrinsic::aarch64_sve_cmpne_wide: 1872 return instCombineSVECmpNE(IC, II); 1873 case Intrinsic::aarch64_sve_rdffr: 1874 return instCombineRDFFR(IC, II); 1875 case Intrinsic::aarch64_sve_lasta: 1876 case Intrinsic::aarch64_sve_lastb: 1877 return instCombineSVELast(IC, II); 1878 case Intrinsic::aarch64_sve_clasta_n: 1879 case Intrinsic::aarch64_sve_clastb_n: 1880 return instCombineSVECondLast(IC, II); 1881 case Intrinsic::aarch64_sve_cntd: 1882 return instCombineSVECntElts(IC, II, 2); 1883 case Intrinsic::aarch64_sve_cntw: 1884 return instCombineSVECntElts(IC, II, 4); 1885 case Intrinsic::aarch64_sve_cnth: 1886 return instCombineSVECntElts(IC, II, 8); 1887 case Intrinsic::aarch64_sve_cntb: 1888 return instCombineSVECntElts(IC, II, 16); 1889 case Intrinsic::aarch64_sve_ptest_any: 1890 case Intrinsic::aarch64_sve_ptest_first: 1891 case Intrinsic::aarch64_sve_ptest_last: 1892 return instCombineSVEPTest(IC, II); 1893 case Intrinsic::aarch64_sve_fabd: 1894 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fabd_u); 1895 case Intrinsic::aarch64_sve_fadd: 1896 return instCombineSVEVectorFAdd(IC, II); 1897 case Intrinsic::aarch64_sve_fadd_u: 1898 return instCombineSVEVectorFAddU(IC, II); 1899 case Intrinsic::aarch64_sve_fdiv: 1900 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fdiv_u); 1901 case Intrinsic::aarch64_sve_fmax: 1902 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmax_u); 1903 case Intrinsic::aarch64_sve_fmaxnm: 1904 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmaxnm_u); 1905 case Intrinsic::aarch64_sve_fmin: 1906 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmin_u); 1907 case Intrinsic::aarch64_sve_fminnm: 1908 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fminnm_u); 1909 case Intrinsic::aarch64_sve_fmla: 1910 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmla_u); 1911 case Intrinsic::aarch64_sve_fmls: 1912 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmls_u); 1913 case Intrinsic::aarch64_sve_fmul: 1914 case Intrinsic::aarch64_sve_fmul_u: 1915 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u); 1916 case Intrinsic::aarch64_sve_fmulx: 1917 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fmulx_u); 1918 case Intrinsic::aarch64_sve_fnmla: 1919 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmla_u); 1920 case Intrinsic::aarch64_sve_fnmls: 1921 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_fnmls_u); 1922 case Intrinsic::aarch64_sve_fsub: 1923 return instCombineSVEVectorFSub(IC, II); 1924 case Intrinsic::aarch64_sve_fsub_u: 1925 return instCombineSVEVectorFSubU(IC, 
II); 1926 case Intrinsic::aarch64_sve_add: 1927 return instCombineSVEVectorAdd(IC, II); 1928 case Intrinsic::aarch64_sve_add_u: 1929 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u, 1930 Intrinsic::aarch64_sve_mla_u>( 1931 IC, II, true); 1932 case Intrinsic::aarch64_sve_mla: 1933 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mla_u); 1934 case Intrinsic::aarch64_sve_mls: 1935 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_mls_u); 1936 case Intrinsic::aarch64_sve_mul: 1937 case Intrinsic::aarch64_sve_mul_u: 1938 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u); 1939 case Intrinsic::aarch64_sve_sabd: 1940 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sabd_u); 1941 case Intrinsic::aarch64_sve_smax: 1942 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smax_u); 1943 case Intrinsic::aarch64_sve_smin: 1944 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smin_u); 1945 case Intrinsic::aarch64_sve_smulh: 1946 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_smulh_u); 1947 case Intrinsic::aarch64_sve_sub: 1948 return instCombineSVEVectorSub(IC, II); 1949 case Intrinsic::aarch64_sve_sub_u: 1950 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u, 1951 Intrinsic::aarch64_sve_mls_u>( 1952 IC, II, true); 1953 case Intrinsic::aarch64_sve_uabd: 1954 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uabd_u); 1955 case Intrinsic::aarch64_sve_umax: 1956 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umax_u); 1957 case Intrinsic::aarch64_sve_umin: 1958 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umin_u); 1959 case Intrinsic::aarch64_sve_umulh: 1960 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_umulh_u); 1961 case Intrinsic::aarch64_sve_asr: 1962 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_asr_u); 1963 case Intrinsic::aarch64_sve_lsl: 1964 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsl_u); 1965 case Intrinsic::aarch64_sve_lsr: 1966 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_lsr_u); 1967 case Intrinsic::aarch64_sve_and: 1968 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_and_u); 1969 case Intrinsic::aarch64_sve_bic: 1970 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_bic_u); 1971 case Intrinsic::aarch64_sve_eor: 1972 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_eor_u); 1973 case Intrinsic::aarch64_sve_orr: 1974 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_orr_u); 1975 case Intrinsic::aarch64_sve_sqsub: 1976 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_sqsub_u); 1977 case Intrinsic::aarch64_sve_uqsub: 1978 return instCombineSVEAllActive(II, Intrinsic::aarch64_sve_uqsub_u); 1979 case Intrinsic::aarch64_sve_tbl: 1980 return instCombineSVETBL(IC, II); 1981 case Intrinsic::aarch64_sve_uunpkhi: 1982 case Intrinsic::aarch64_sve_uunpklo: 1983 case Intrinsic::aarch64_sve_sunpkhi: 1984 case Intrinsic::aarch64_sve_sunpklo: 1985 return instCombineSVEUnpack(IC, II); 1986 case Intrinsic::aarch64_sve_zip1: 1987 case Intrinsic::aarch64_sve_zip2: 1988 return instCombineSVEZip(IC, II); 1989 case Intrinsic::aarch64_sve_ld1_gather_index: 1990 return instCombineLD1GatherIndex(IC, II); 1991 case Intrinsic::aarch64_sve_st1_scatter_index: 1992 return instCombineST1ScatterIndex(IC, II); 1993 case Intrinsic::aarch64_sve_ld1: 1994 return instCombineSVELD1(IC, II, DL); 1995 case Intrinsic::aarch64_sve_st1: 1996 return instCombineSVEST1(IC, II, DL); 1997 case 
Intrinsic::aarch64_sve_sdiv: 1998 return instCombineSVESDIV(IC, II); 1999 case Intrinsic::aarch64_sve_sel: 2000 return instCombineSVESel(IC, II); 2001 case Intrinsic::aarch64_sve_srshl: 2002 return instCombineSVESrshl(IC, II); 2003 case Intrinsic::aarch64_sve_dupq_lane: 2004 return instCombineSVEDupqLane(IC, II); 2005 } 2006 2007 return std::nullopt; 2008 } 2009 2010 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic( 2011 InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts, 2012 APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3, 2013 std::function<void(Instruction *, unsigned, APInt, APInt &)> 2014 SimplifyAndSetOp) const { 2015 switch (II.getIntrinsicID()) { 2016 default: 2017 break; 2018 case Intrinsic::aarch64_neon_fcvtxn: 2019 case Intrinsic::aarch64_neon_rshrn: 2020 case Intrinsic::aarch64_neon_sqrshrn: 2021 case Intrinsic::aarch64_neon_sqrshrun: 2022 case Intrinsic::aarch64_neon_sqshrn: 2023 case Intrinsic::aarch64_neon_sqshrun: 2024 case Intrinsic::aarch64_neon_sqxtn: 2025 case Intrinsic::aarch64_neon_sqxtun: 2026 case Intrinsic::aarch64_neon_uqrshrn: 2027 case Intrinsic::aarch64_neon_uqshrn: 2028 case Intrinsic::aarch64_neon_uqxtn: 2029 SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts); 2030 break; 2031 } 2032 2033 return std::nullopt; 2034 } 2035 2036 TypeSize 2037 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const { 2038 switch (K) { 2039 case TargetTransformInfo::RGK_Scalar: 2040 return TypeSize::getFixed(64); 2041 case TargetTransformInfo::RGK_FixedWidthVector: 2042 if (!ST->isNeonAvailable() && !EnableFixedwidthAutovecInStreamingMode) 2043 return TypeSize::getFixed(0); 2044 2045 if (ST->hasSVE()) 2046 return TypeSize::getFixed( 2047 std::max(ST->getMinSVEVectorSizeInBits(), 128u)); 2048 2049 return TypeSize::getFixed(ST->hasNEON() ? 128 : 0); 2050 case TargetTransformInfo::RGK_ScalableVector: 2051 if (!ST->isSVEAvailable() && !EnableScalableAutovecInStreamingMode) 2052 return TypeSize::getScalable(0); 2053 2054 return TypeSize::getScalable(ST->hasSVE() ? 128 : 0); 2055 } 2056 llvm_unreachable("Unsupported register kind"); 2057 } 2058 2059 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, 2060 ArrayRef<const Value *> Args, 2061 Type *SrcOverrideTy) { 2062 // A helper that returns a vector type from the given type. The number of 2063 // elements in type Ty determines the vector width. 2064 auto toVectorTy = [&](Type *ArgTy) { 2065 return VectorType::get(ArgTy->getScalarType(), 2066 cast<VectorType>(DstTy)->getElementCount()); 2067 }; 2068 2069 // Exit early if DstTy is not a vector type whose elements are one of [i16, 2070 // i32, i64]. SVE doesn't generally have the same set of instructions to 2071 // perform an extend with the add/sub/mul. There are SMULLB style 2072 // instructions, but they operate on top/bottom, requiring some sort of lane 2073 // interleaving to be used with zext/sext. 2074 unsigned DstEltSize = DstTy->getScalarSizeInBits(); 2075 if (!useNeonVector(DstTy) || Args.size() != 2 || 2076 (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64)) 2077 return false; 2078 2079 // Determine if the operation has a widening variant. We consider both the 2080 // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the 2081 // instructions. 2082 // 2083 // TODO: Add additional widening operations (e.g., shl, etc.) once we 2084 // verify that their extending operands are eliminated during code 2085 // generation. 
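  // Example of the pattern being detected (illustrative):
  //   %e = zext <8 x i8> %a to <8 x i16>
  //   %r = add <8 x i16> %b, %e
  // can be selected as a single uaddw, so the extend is considered free and
  // this routine returns true.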
2086 Type *SrcTy = SrcOverrideTy; 2087 switch (Opcode) { 2088 case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2). 2089 case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2). 2090 // The second operand needs to be an extend 2091 if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) { 2092 if (!SrcTy) 2093 SrcTy = 2094 toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType()); 2095 } else 2096 return false; 2097 break; 2098 case Instruction::Mul: { // SMULL(2), UMULL(2) 2099 // Both operands need to be extends of the same type. 2100 if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) || 2101 (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) { 2102 if (!SrcTy) 2103 SrcTy = 2104 toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType()); 2105 } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) { 2106 // If one of the operands is a Zext and the other has enough zero bits to 2107 // be treated as unsigned, we can still general a umull, meaning the zext 2108 // is free. 2109 KnownBits Known = 2110 computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL); 2111 if (Args[0]->getType()->getScalarSizeInBits() - 2112 Known.Zero.countLeadingOnes() > 2113 DstTy->getScalarSizeInBits() / 2) 2114 return false; 2115 if (!SrcTy) 2116 SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(), 2117 DstTy->getScalarSizeInBits() / 2)); 2118 } else 2119 return false; 2120 break; 2121 } 2122 default: 2123 return false; 2124 } 2125 2126 // Legalize the destination type and ensure it can be used in a widening 2127 // operation. 2128 auto DstTyL = getTypeLegalizationCost(DstTy); 2129 if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits()) 2130 return false; 2131 2132 // Legalize the source type and ensure it can be used in a widening 2133 // operation. 2134 assert(SrcTy && "Expected some SrcTy"); 2135 auto SrcTyL = getTypeLegalizationCost(SrcTy); 2136 unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits(); 2137 if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits()) 2138 return false; 2139 2140 // Get the total number of vector elements in the legalized types. 2141 InstructionCost NumDstEls = 2142 DstTyL.first * DstTyL.second.getVectorMinNumElements(); 2143 InstructionCost NumSrcEls = 2144 SrcTyL.first * SrcTyL.second.getVectorMinNumElements(); 2145 2146 // Return true if the legalized types have the same number of vector elements 2147 // and the destination element type size is twice that of the source type. 2148 return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize; 2149 } 2150 2151 // s/urhadd instructions implement the following pattern, making the 2152 // extends free: 2153 // %x = add ((zext i8 -> i16), 1) 2154 // %y = (zext i8 -> i16) 2155 // trunc i16 (lshr (add %x, %y), 1) -> i8 2156 // 2157 bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst, 2158 Type *Src) { 2159 // The source should be a legal vector type. 2160 if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) || 2161 (Src->isScalableTy() && !ST->hasSVE2())) 2162 return false; 2163 2164 if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse()) 2165 return false; 2166 2167 // Look for trunc/shl/add before trying to match the pattern. 
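  // The walk below follows the single-user chain ext -> add (-> add) -> lshr
  // -> trunc, mirroring the s/urhadd expansion sketched in the comment above
  // this function.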
2168 const Instruction *Add = ExtUser; 2169 auto *AddUser = 2170 dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser()); 2171 if (AddUser && AddUser->getOpcode() == Instruction::Add) 2172 Add = AddUser; 2173 2174 auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser()); 2175 if (!Shr || Shr->getOpcode() != Instruction::LShr) 2176 return false; 2177 2178 auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser()); 2179 if (!Trunc || Trunc->getOpcode() != Instruction::Trunc || 2180 Src->getScalarSizeInBits() != 2181 cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits()) 2182 return false; 2183 2184 // Try to match the whole pattern. Ext could be either the first or second 2185 // m_ZExtOrSExt matched. 2186 Instruction *Ex1, *Ex2; 2187 if (!(match(Add, m_c_Add(m_Instruction(Ex1), 2188 m_c_Add(m_Instruction(Ex2), m_SpecificInt(1)))))) 2189 return false; 2190 2191 // Ensure both extends are of the same type 2192 if (match(Ex1, m_ZExtOrSExt(m_Value())) && 2193 Ex1->getOpcode() == Ex2->getOpcode()) 2194 return true; 2195 2196 return false; 2197 } 2198 2199 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, 2200 Type *Src, 2201 TTI::CastContextHint CCH, 2202 TTI::TargetCostKind CostKind, 2203 const Instruction *I) { 2204 int ISD = TLI->InstructionOpcodeToISD(Opcode); 2205 assert(ISD && "Invalid opcode"); 2206 // If the cast is observable, and it is used by a widening instruction (e.g., 2207 // uaddl, saddw, etc.), it may be free. 2208 if (I && I->hasOneUser()) { 2209 auto *SingleUser = cast<Instruction>(*I->user_begin()); 2210 SmallVector<const Value *, 4> Operands(SingleUser->operand_values()); 2211 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) { 2212 // For adds only count the second operand as free if both operands are 2213 // extends but not the same operation. (i.e both operands are not free in 2214 // add(sext, zext)). 2215 if (SingleUser->getOpcode() == Instruction::Add) { 2216 if (I == SingleUser->getOperand(1) || 2217 (isa<CastInst>(SingleUser->getOperand(1)) && 2218 cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode)) 2219 return 0; 2220 } else // Others are free so long as isWideningInstruction returned true. 2221 return 0; 2222 } 2223 2224 // The cast will be free for the s/urhadd instructions 2225 if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) && 2226 isExtPartOfAvgExpr(SingleUser, Dst, Src)) 2227 return 0; 2228 } 2229 2230 // TODO: Allow non-throughput costs that aren't binary. 2231 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost { 2232 if (CostKind != TTI::TCK_RecipThroughput) 2233 return Cost == 0 ? 
0 : 1; 2234 return Cost; 2235 }; 2236 2237 EVT SrcTy = TLI->getValueType(DL, Src); 2238 EVT DstTy = TLI->getValueType(DL, Dst); 2239 2240 if (!SrcTy.isSimple() || !DstTy.isSimple()) 2241 return AdjustCost( 2242 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); 2243 2244 static const TypeConversionCostTblEntry 2245 ConversionTbl[] = { 2246 { ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn 2247 { ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn 2248 { ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1}, // xtn 2249 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1}, // xtn 2250 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3}, // 2 xtn + 1 uzp1 2251 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1}, // xtn 2252 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2}, // 1 uzp1 + 1 xtn 2253 { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1}, // 1 uzp1 2254 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1}, // 1 xtn 2255 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2}, // 1 uzp1 + 1 xtn 2256 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4}, // 3 x uzp1 + xtn 2257 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1}, // 1 uzp1 2258 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3}, // 3 x uzp1 2259 { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2}, // 2 x uzp1 2260 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1}, // uzp1 2261 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3}, // (2 + 1) x uzp1 2262 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7}, // (4 + 2 + 1) x uzp1 2263 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1 2264 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1 2265 { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1 2266 2267 // Truncations on nxvmiN 2268 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 }, 2269 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 }, 2270 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 }, 2271 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 }, 2272 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 }, 2273 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 }, 2274 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 }, 2275 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 }, 2276 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 }, 2277 { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 }, 2278 { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 }, 2279 { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 }, 2280 { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 }, 2281 { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 }, 2282 { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 }, 2283 { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 }, 2284 2285 // The number of shll instructions for the extension. 
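    // For instance (illustrative): v4i16 -> v4i64 uses one shll to reach
    // v4i32 followed by a shll/shll2 pair to produce the two v2i64 halves,
    // i.e. 3 instructions, matching the first entries below.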
2286 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 }, 2287 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 }, 2288 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2 }, 2289 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2 }, 2290 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3 }, 2291 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3 }, 2292 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2 }, 2293 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2 }, 2294 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7 }, 2295 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7 }, 2296 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6 }, 2297 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6 }, 2298 { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 }, 2299 { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 }, 2300 { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 }, 2301 { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 }, 2302 2303 // LowerVectorINT_TO_FP: 2304 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 }, 2305 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 }, 2306 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 }, 2307 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 }, 2308 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 }, 2309 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 }, 2310 2311 // Complex: to v2f32 2312 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 }, 2313 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 }, 2314 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 }, 2315 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 }, 2316 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 }, 2317 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 }, 2318 2319 // Complex: to v4f32 2320 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4 }, 2321 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 }, 2322 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3 }, 2323 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 }, 2324 2325 // Complex: to v8f32 2326 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 }, 2327 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 }, 2328 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 }, 2329 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 }, 2330 2331 // Complex: to v16f32 2332 { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 }, 2333 { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 }, 2334 2335 // Complex: to v2f64 2336 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 }, 2337 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 }, 2338 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 }, 2339 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 }, 2340 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 }, 2341 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 }, 2342 2343 // Complex: to v4f64 2344 { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 }, 2345 { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 }, 2346 2347 // LowerVectorFP_TO_INT 2348 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 }, 2349 { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 }, 2350 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 }, 2351 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 }, 2352 { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 }, 2353 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 }, 2354 2355 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext). 
2356 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 }, 2357 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 }, 2358 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1 }, 2359 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 }, 2360 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 }, 2361 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1 }, 2362 2363 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2 2364 { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 }, 2365 { ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2 }, 2366 { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 }, 2367 { ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2 }, 2368 2369 // Complex, from nxv2f32. 2370 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 }, 2371 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 }, 2372 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 }, 2373 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1 }, 2374 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 }, 2375 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 }, 2376 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 }, 2377 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1 }, 2378 2379 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2. 2380 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 }, 2381 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 }, 2382 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2 }, 2383 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 }, 2384 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 }, 2385 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2 }, 2386 2387 // Complex, from nxv2f64. 2388 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 }, 2389 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 }, 2390 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 }, 2391 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1 }, 2392 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 }, 2393 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 }, 2394 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 }, 2395 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1 }, 2396 2397 // Complex, from nxv4f32. 2398 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 }, 2399 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 }, 2400 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 }, 2401 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1 }, 2402 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 }, 2403 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 }, 2404 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 }, 2405 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1 }, 2406 2407 // Complex, from nxv8f64. Illegal -> illegal conversions not required. 2408 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 }, 2409 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7 }, 2410 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 }, 2411 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7 }, 2412 2413 // Complex, from nxv4f64. Illegal -> illegal conversions not required. 2414 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 }, 2415 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 }, 2416 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3 }, 2417 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 }, 2418 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 }, 2419 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3 }, 2420 2421 // Complex, from nxv8f32. Illegal -> illegal conversions not required. 2422 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 }, 2423 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3 }, 2424 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 }, 2425 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3 }, 2426 2427 // Complex, from nxv8f16. 
2428 { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 }, 2429 { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 }, 2430 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 }, 2431 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1 }, 2432 { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 }, 2433 { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 }, 2434 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 }, 2435 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1 }, 2436 2437 // Complex, from nxv4f16. 2438 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 }, 2439 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 }, 2440 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 }, 2441 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1 }, 2442 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 }, 2443 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 }, 2444 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 }, 2445 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1 }, 2446 2447 // Complex, from nxv2f16. 2448 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 }, 2449 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 }, 2450 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 }, 2451 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1 }, 2452 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 }, 2453 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 }, 2454 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 }, 2455 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1 }, 2456 2457 // Truncate from nxvmf32 to nxvmf16. 2458 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 }, 2459 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 }, 2460 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 }, 2461 2462 // Truncate from nxvmf64 to nxvmf16. 2463 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 }, 2464 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 }, 2465 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 }, 2466 2467 // Truncate from nxvmf64 to nxvmf32. 2468 { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 }, 2469 { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 }, 2470 { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 }, 2471 2472 // Extend from nxvmf16 to nxvmf32. 2473 { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1}, 2474 { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1}, 2475 { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2}, 2476 2477 // Extend from nxvmf16 to nxvmf64. 2478 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1}, 2479 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2}, 2480 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4}, 2481 2482 // Extend from nxvmf32 to nxvmf64. 2483 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1}, 2484 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2}, 2485 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6}, 2486 2487 // Bitcasts from float to integer 2488 { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 }, 2489 { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 }, 2490 { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 }, 2491 2492 // Bitcasts from integer to float 2493 { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 }, 2494 { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 }, 2495 { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 }, 2496 2497 // Add cost for extending to illegal -too wide- scalable vectors. 2498 // zero/sign extend are implemented by multiple unpack operations, 2499 // where each operation has a cost of 1. 
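    // For instance (illustrative): nxv16i8 -> nxv16i32 needs uunpklo/uunpkhi
    // to reach nxv16i16 (2 ops) and four more unpacks to reach nxv16i32,
    // giving the cost of 6 in the corresponding entry below.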
2500 { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2}, 2501 { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6}, 2502 { ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14}, 2503 { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2}, 2504 { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6}, 2505 { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2}, 2506 2507 { ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2}, 2508 { ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6}, 2509 { ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14}, 2510 { ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2}, 2511 { ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6}, 2512 { ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2}, 2513 }; 2514 2515 // We have to estimate a cost of fixed length operation upon 2516 // SVE registers(operations) with the number of registers required 2517 // for a fixed type to be represented upon SVE registers. 2518 EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy; 2519 if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() && 2520 SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() && 2521 ST->useSVEForFixedLengthVectors(WiderTy)) { 2522 std::pair<InstructionCost, MVT> LT = 2523 getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext())); 2524 unsigned NumElements = AArch64::SVEBitsPerBlock / 2525 LT.second.getVectorElementType().getSizeInBits(); 2526 return AdjustCost( 2527 LT.first * 2528 getCastInstrCost( 2529 Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements), 2530 ScalableVectorType::get(Src->getScalarType(), NumElements), CCH, 2531 CostKind, I)); 2532 } 2533 2534 if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD, 2535 DstTy.getSimpleVT(), 2536 SrcTy.getSimpleVT())) 2537 return AdjustCost(Entry->Cost); 2538 2539 static const TypeConversionCostTblEntry FP16Tbl[] = { 2540 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs 2541 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1}, 2542 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs 2543 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1}, 2544 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs 2545 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2}, 2546 {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn 2547 {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2}, 2548 {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs 2549 {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1}, 2550 {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs 2551 {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4}, 2552 {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn 2553 {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3}, 2554 {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs 2555 {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2}, 2556 {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs 2557 {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8}, 2558 {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // ushll + ucvtf 2559 {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // sshll + scvtf 2560 {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf 2561 {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf 2562 }; 2563 2564 if (ST->hasFullFP16()) 2565 if (const auto *Entry = ConvertCostTableLookup( 2566 FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) 2567 return AdjustCost(Entry->Cost); 2568 2569 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) && 2570 CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() && 2571 
TLI->getTypeAction(Src->getContext(), SrcTy) == 2572 TargetLowering::TypePromoteInteger && 2573 TLI->getTypeAction(Dst->getContext(), DstTy) == 2574 TargetLowering::TypeSplitVector) { 2575 // The standard behaviour in the backend for these cases is to split the 2576 // extend up into two parts: 2577 // 1. Perform an extending load or masked load up to the legal type. 2578 // 2. Extend the loaded data to the final type. 2579 std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src); 2580 Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext()); 2581 InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost( 2582 Opcode, LegalTy, Src, CCH, CostKind, I); 2583 InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost( 2584 Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I); 2585 return Part1 + Part2; 2586 } 2587 2588 // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal, 2589 // but we also want to include the TTI::CastContextHint::Masked case too. 2590 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) && 2591 CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() && 2592 TLI->isTypeLegal(DstTy)) 2593 CCH = TTI::CastContextHint::Normal; 2594 2595 return AdjustCost( 2596 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); 2597 } 2598 2599 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, 2600 Type *Dst, 2601 VectorType *VecTy, 2602 unsigned Index) { 2603 2604 // Make sure we were given a valid extend opcode. 2605 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) && 2606 "Invalid opcode"); 2607 2608 // We are extending an element we extract from a vector, so the source type 2609 // of the extend is the element type of the vector. 2610 auto *Src = VecTy->getElementType(); 2611 2612 // Sign- and zero-extends are for integer types only. 2613 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type"); 2614 2615 // Get the cost for the extract. We compute the cost (if any) for the extend 2616 // below. 2617 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 2618 InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, 2619 CostKind, Index, nullptr, nullptr); 2620 2621 // Legalize the types. 2622 auto VecLT = getTypeLegalizationCost(VecTy); 2623 auto DstVT = TLI->getValueType(DL, Dst); 2624 auto SrcVT = TLI->getValueType(DL, Src); 2625 2626 // If the resulting type is still a vector and the destination type is legal, 2627 // we may get the extension for free. If not, get the default cost for the 2628 // extend. 2629 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) 2630 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2631 CostKind); 2632 2633 // The destination type should be larger than the element type. If not, get 2634 // the default cost for the extend. 2635 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits()) 2636 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2637 CostKind); 2638 2639 switch (Opcode) { 2640 default: 2641 llvm_unreachable("Opcode should be either SExt or ZExt"); 2642 2643 // For sign-extends, we only need a smov, which performs the extension 2644 // automatically. 2645 case Instruction::SExt: 2646 return Cost; 2647 2648 // For zero-extends, the extend is performed automatically by a umov unless 2649 // the destination type is i64 and the element type is i8 or i16. 
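  // For example (illustrative): zero-extending an extracted i32 element to
  // i64 is covered by the umov itself, whereas an i8 or i16 element extracted
  // into an i64 still pays the extra extend cost computed below.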
2650 case Instruction::ZExt: 2651 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u) 2652 return Cost; 2653 } 2654 2655 // If we are unable to perform the extend for free, get the default cost. 2656 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2657 CostKind); 2658 } 2659 2660 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode, 2661 TTI::TargetCostKind CostKind, 2662 const Instruction *I) { 2663 if (CostKind != TTI::TCK_RecipThroughput) 2664 return Opcode == Instruction::PHI ? 0 : 1; 2665 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind"); 2666 // Branches are assumed to be predicted. 2667 return 0; 2668 } 2669 2670 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I, 2671 Type *Val, 2672 unsigned Index, 2673 bool HasRealUse) { 2674 assert(Val->isVectorTy() && "This must be a vector type"); 2675 2676 if (Index != -1U) { 2677 // Legalize the type. 2678 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val); 2679 2680 // This type is legalized to a scalar type. 2681 if (!LT.second.isVector()) 2682 return 0; 2683 2684 // The type may be split. For fixed-width vectors we can normalize the 2685 // index to the new type. 2686 if (LT.second.isFixedLengthVector()) { 2687 unsigned Width = LT.second.getVectorNumElements(); 2688 Index = Index % Width; 2689 } 2690 2691 // The element at index zero is already inside the vector. 2692 // - For a physical (HasRealUse==true) insert-element or extract-element 2693 // instruction that extracts integers, an explicit FPR -> GPR move is 2694 // needed. So it has non-zero cost. 2695 // - For the rest of cases (virtual instruction or element type is float), 2696 // consider the instruction free. 2697 if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy())) 2698 return 0; 2699 2700 // This is recognising a LD1 single-element structure to one lane of one 2701 // register instruction. I.e., if this is an `insertelement` instruction, 2702 // and its second operand is a load, then we will generate a LD1, which 2703 // are expensive instructions. 2704 if (I && dyn_cast<LoadInst>(I->getOperand(1))) 2705 return ST->getVectorInsertExtractBaseCost() + 1; 2706 2707 // i1 inserts and extract will include an extra cset or cmp of the vector 2708 // value. Increase the cost by 1 to account. 2709 if (Val->getScalarSizeInBits() == 1) 2710 return ST->getVectorInsertExtractBaseCost() + 1; 2711 2712 // FIXME: 2713 // If the extract-element and insert-element instructions could be 2714 // simplified away (e.g., could be combined into users by looking at use-def 2715 // context), they have no cost. This is not done in the first place for 2716 // compile-time considerations. 2717 } 2718 2719 // All other insert/extracts cost this much. 
  return ST->getVectorInsertExtractBaseCost();
}

InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val,
                                                   TTI::TargetCostKind CostKind,
                                                   unsigned Index, Value *Op0,
                                                   Value *Op1) {
  bool HasRealUse =
      Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0);
  return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse);
}

InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I,
                                                   Type *Val,
                                                   TTI::TargetCostKind CostKind,
                                                   unsigned Index) {
  return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */);
}

InstructionCost AArch64TTIImpl::getScalarizationOverhead(
    VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
    TTI::TargetCostKind CostKind) {
  if (isa<ScalableVectorType>(Ty))
    return InstructionCost::getInvalid();
  if (Ty->getElementType()->isFloatingPointTy())
    return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract,
                                           CostKind);
  return DemandedElts.popcount() * (Insert + Extract) *
         ST->getVectorInsertExtractBaseCost();
}

InstructionCost AArch64TTIImpl::getArithmeticInstrCost(
    unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
    TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
    ArrayRef<const Value *> Args,
    const Instruction *CxtI) {

  // TODO: Handle more cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info, Args, CxtI);

  // Legalize the type.
  std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty);
  int ISD = TLI->InstructionOpcodeToISD(Opcode);

  switch (ISD) {
  default:
    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  case ISD::SDIV:
    if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) {
      // On AArch64, scalar signed division by a power-of-two constant is
      // normally expanded to the sequence ADD + CMP + SELECT + SRA.
      // The OperandValue properties may not be the same as those of the
      // previous operation; conservatively assume OP_None.
      InstructionCost Cost = getArithmeticInstrCost(
          Instruction::Add, Ty, CostKind,
          Op1Info.getNoProps(), Op2Info.getNoProps());
      Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind,
                                     Op1Info.getNoProps(), Op2Info.getNoProps());
      Cost += getArithmeticInstrCost(
          Instruction::Select, Ty, CostKind,
          Op1Info.getNoProps(), Op2Info.getNoProps());
      Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind,
                                     Op1Info.getNoProps(), Op2Info.getNoProps());
      return Cost;
    }
    [[fallthrough]];
  case ISD::UDIV: {
    if (Op2Info.isConstant() && Op2Info.isUniform()) {
      auto VT = TLI->getValueType(DL, Ty);
      if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) {
        // Vector signed division by a constant is expanded to the sequence
        // MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division to
        // MULHU + SUB + SRL + ADD + SRL.
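        // The code below approximates that expansion as
        // MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1 rather than costing
        // each instruction in the sequence individually.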
2796 InstructionCost MulCost = getArithmeticInstrCost( 2797 Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2798 InstructionCost AddCost = getArithmeticInstrCost( 2799 Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2800 InstructionCost ShrCost = getArithmeticInstrCost( 2801 Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2802 return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1; 2803 } 2804 } 2805 2806 InstructionCost Cost = BaseT::getArithmeticInstrCost( 2807 Opcode, Ty, CostKind, Op1Info, Op2Info); 2808 if (Ty->isVectorTy()) { 2809 if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) { 2810 // SDIV/UDIV operations are lowered using SVE, then we can have less 2811 // costs. 2812 if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty) 2813 ->getPrimitiveSizeInBits() 2814 .getFixedValue() < 128) { 2815 EVT VT = TLI->getValueType(DL, Ty); 2816 static const CostTblEntry DivTbl[]{ 2817 {ISD::SDIV, MVT::v2i8, 5}, {ISD::SDIV, MVT::v4i8, 8}, 2818 {ISD::SDIV, MVT::v8i8, 8}, {ISD::SDIV, MVT::v2i16, 5}, 2819 {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1}, 2820 {ISD::UDIV, MVT::v2i8, 5}, {ISD::UDIV, MVT::v4i8, 8}, 2821 {ISD::UDIV, MVT::v8i8, 8}, {ISD::UDIV, MVT::v2i16, 5}, 2822 {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}}; 2823 2824 const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT()); 2825 if (nullptr != Entry) 2826 return Entry->Cost; 2827 } 2828 // For 8/16-bit elements, the cost is higher because the type 2829 // requires promotion and possibly splitting: 2830 if (LT.second.getScalarType() == MVT::i8) 2831 Cost *= 8; 2832 else if (LT.second.getScalarType() == MVT::i16) 2833 Cost *= 4; 2834 return Cost; 2835 } else { 2836 // If one of the operands is a uniform constant then the cost for each 2837 // element is Cost for insertion, extraction and division. 2838 // Insertion cost = 2, Extraction Cost = 2, Division = cost for the 2839 // operation with scalar type 2840 if ((Op1Info.isConstant() && Op1Info.isUniform()) || 2841 (Op2Info.isConstant() && Op2Info.isUniform())) { 2842 if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) { 2843 InstructionCost DivCost = BaseT::getArithmeticInstrCost( 2844 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info); 2845 return (4 + DivCost) * VTy->getNumElements(); 2846 } 2847 } 2848 // On AArch64, without SVE, vector divisions are expanded 2849 // into scalar divisions of each pair of elements. 2850 Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, 2851 CostKind, Op1Info, Op2Info); 2852 Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind, 2853 Op1Info, Op2Info); 2854 } 2855 2856 // TODO: if one of the arguments is scalar, then it's not necessary to 2857 // double the cost of handling the vector elements. 2858 Cost += Cost; 2859 } 2860 return Cost; 2861 } 2862 case ISD::MUL: 2863 // When SVE is available, then we can lower the v2i64 operation using 2864 // the SVE mul instruction, which has a lower cost. 2865 if (LT.second == MVT::v2i64 && ST->hasSVE()) 2866 return LT.first; 2867 2868 // When SVE is not available, there is no MUL.2d instruction, 2869 // which means mul <2 x i64> is expensive as elements are extracted 2870 // from the vectors and the muls scalarized. 2871 // As getScalarizationOverhead is a bit too pessimistic, we 2872 // estimate the cost for a i64 vector directly here, which is: 2873 // - four 2-cost i64 extracts, 2874 // - two 2-cost i64 inserts, and 2875 // - two 1-cost muls. 
    // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
    // LT.first = 2 the cost is 28. If both operands are extensions it will not
    // need to scalarize so the cost can be cheaper (smull or umull).
    if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
      return LT.first;
    return LT.first * 14;
  case ISD::ADD:
  case ISD::XOR:
  case ISD::OR:
  case ISD::AND:
  case ISD::SRL:
  case ISD::SRA:
  case ISD::SHL:
    // These nodes are marked as 'custom' for combining purposes only.
    // We know that they are legal. See LowerAdd in ISelLowering.
    return LT.first;

  case ISD::FNEG:
  case ISD::FADD:
  case ISD::FSUB:
    // Increase the cost for half and bfloat types if not architecturally
    // supported.
    if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
        (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
      return 2 * LT.first;
    if (!Ty->getScalarType()->isFP128Ty())
      return LT.first;
    [[fallthrough]];
  case ISD::FMUL:
  case ISD::FDIV:
    // These nodes are marked as 'custom' just to lower them to SVE.
    // We know said lowering will incur no additional cost.
    if (!Ty->getScalarType()->isFP128Ty())
      return 2 * LT.first;

    return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
                                         Op2Info);
  }
}

InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
                                                          ScalarEvolution *SE,
                                                          const SCEV *Ptr) {
  // Address computations in vectorized code with non-consecutive addresses
  // will likely result in more instructions compared to scalar code where the
  // computation can more often be merged into the index mode. The resulting
  // extra micro-ops can significantly decrease throughput.
  unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
  int MaxMergeDistance = 64;

  if (Ty->isVectorTy() && SE &&
      !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
    return NumVectorInstToHideOverhead;

  // In many cases the address computation is not merged into the instruction
  // addressing mode.
  return 1;
}

InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
                                                   Type *CondTy,
                                                   CmpInst::Predicate VecPred,
                                                   TTI::TargetCostKind CostKind,
                                                   const Instruction *I) {
  // TODO: Handle other cost kinds.
  if (CostKind != TTI::TCK_RecipThroughput)
    return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
                                     I);

  int ISD = TLI->InstructionOpcodeToISD(Opcode);
  // We don't lower some vector selects well when they are wider than the
  // register width.
  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
    // We would need this many instructions to hide the scalarization
    // happening.
    const int AmortizationCost = 20;

    // If VecPred is not set, check if we can get a predicate from the context
    // instruction, if its type matches the requested ValTy.
    if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
      CmpInst::Predicate CurrentPred;
      if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
                            m_Value())))
        VecPred = CurrentPred;
    }
    // Check if we have a compare/select chain that can be lowered using
    // a (F)CMxx & BFI pair.
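    // For example (illustrative only), a vector select fed by an OGT fcmp can
    // be lowered as an FCMGT feeding a BFI, so for the legal types listed
    // below the cost returned is just LT.first.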
    if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
        VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
        VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
        VecPred == CmpInst::FCMP_UNE) {
      static const auto ValidMinMaxTys = {
          MVT::v8i8,  MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
          MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
      static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};

      auto LT = getTypeLegalizationCost(ValTy);
      if (any_of(ValidMinMaxTys, [&](MVT M) { return M == LT.second; }) ||
          (ST->hasFullFP16() &&
           any_of(ValidFP16MinMaxTys, [&](MVT M) { return M == LT.second; })))
        return LT.first;
    }

    static const TypeConversionCostTblEntry
    VectorSelectTbl[] = {
      { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
      { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
      { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
      { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
      { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
      { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
      { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
      { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
      { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
      { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
      { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
    };

    EVT SelCondTy = TLI->getValueType(DL, CondTy);
    EVT SelValTy = TLI->getValueType(DL, ValTy);
    if (SelCondTy.isSimple() && SelValTy.isSimple()) {
      if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
                                                     SelCondTy.getSimpleVT(),
                                                     SelValTy.getSimpleVT()))
        return Entry->Cost;
    }
  }

  if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
    auto LT = getTypeLegalizationCost(ValTy);
    // Cost a v4f16 FCmp without FP16 support by converting to v4f32 and back.
    if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
      return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
  }

  // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
  // FIXME: This can apply to more conditions and add/sub if it can be shown to
  // be profitable.
  if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
      ICmpInst::isEquality(VecPred) &&
      TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
      match(I->getOperand(1), m_Zero()) &&
      match(I->getOperand(0), m_And(m_Value(), m_Value())))
    return 0;

  // The base case handles scalable vectors fine for now, since it treats the
  // cost as 1 * legalization cost.
  return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
}

AArch64TTIImpl::TTI::MemCmpExpansionOptions
AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
  TTI::MemCmpExpansionOptions Options;
  if (ST->requiresStrictAlign()) {
    // TODO: Add cost modeling for strict align. Misaligned loads expand to
    // a bunch of instructions when strict align is enabled.
    return Options;
  }
  Options.AllowOverlappingLoads = true;
  Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
  Options.NumLoadsPerBlock = Options.MaxNumLoads;
  // TODO: Though vector loads usually perform well on AArch64, in some targets
  // they may wake up the FP unit, which raises the power consumption. Perhaps
  // they could be used with no holds barred (-O3).
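  // With the settings below, a 16-byte compare can, for example, be expanded
  // into two 8-byte loads per argument (overlapping loads are allowed), and
  // odd-sized tails of 3, 5 or 6 bytes are also permitted to be expanded.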
3040 Options.LoadSizes = {8, 4, 2, 1}; 3041 Options.AllowedTailExpansions = {3, 5, 6}; 3042 return Options; 3043 } 3044 3045 bool AArch64TTIImpl::prefersVectorizedAddressing() const { 3046 return ST->hasSVE(); 3047 } 3048 3049 InstructionCost 3050 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, 3051 Align Alignment, unsigned AddressSpace, 3052 TTI::TargetCostKind CostKind) { 3053 if (useNeonVector(Src)) 3054 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace, 3055 CostKind); 3056 auto LT = getTypeLegalizationCost(Src); 3057 if (!LT.first.isValid()) 3058 return InstructionCost::getInvalid(); 3059 3060 // The code-generator is currently not able to handle scalable vectors 3061 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3062 // it. This change will be removed when code-generation for these types is 3063 // sufficiently reliable. 3064 if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1)) 3065 return InstructionCost::getInvalid(); 3066 3067 return LT.first; 3068 } 3069 3070 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) { 3071 return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead; 3072 } 3073 3074 InstructionCost AArch64TTIImpl::getGatherScatterOpCost( 3075 unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, 3076 Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) { 3077 if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy)) 3078 return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, 3079 Alignment, CostKind, I); 3080 auto *VT = cast<VectorType>(DataTy); 3081 auto LT = getTypeLegalizationCost(DataTy); 3082 if (!LT.first.isValid()) 3083 return InstructionCost::getInvalid(); 3084 3085 if (!LT.second.isVector() || 3086 !isElementTypeLegalForScalableVector(VT->getElementType())) 3087 return InstructionCost::getInvalid(); 3088 3089 // The code-generator is currently not able to handle scalable vectors 3090 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3091 // it. This change will be removed when code-generation for these types is 3092 // sufficiently reliable. 3093 if (cast<VectorType>(DataTy)->getElementCount() == 3094 ElementCount::getScalable(1)) 3095 return InstructionCost::getInvalid(); 3096 3097 ElementCount LegalVF = LT.second.getVectorElementCount(); 3098 InstructionCost MemOpCost = 3099 getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, 3100 {TTI::OK_AnyValue, TTI::OP_None}, I); 3101 // Add on an overhead cost for using gathers/scatters. 3102 // TODO: At the moment this is applied unilaterally for all CPUs, but at some 3103 // point we may want a per-CPU overhead. 
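  // Roughly speaking, the estimate returned below is
  //   LT.first * (per-element memory-op cost * gather/scatter overhead)
  //            * maximum number of lanes in the legalized vector,
  // i.e. the gather/scatter is modelled as one overhead-weighted scalar
  // access per lane.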
3104 MemOpCost *= getSVEGatherScatterOverhead(Opcode); 3105 return LT.first * MemOpCost * getMaxNumElements(LegalVF); 3106 } 3107 3108 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const { 3109 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors(); 3110 } 3111 3112 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty, 3113 MaybeAlign Alignment, 3114 unsigned AddressSpace, 3115 TTI::TargetCostKind CostKind, 3116 TTI::OperandValueInfo OpInfo, 3117 const Instruction *I) { 3118 EVT VT = TLI->getValueType(DL, Ty, true); 3119 // Type legalization can't handle structs 3120 if (VT == MVT::Other) 3121 return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, 3122 CostKind); 3123 3124 auto LT = getTypeLegalizationCost(Ty); 3125 if (!LT.first.isValid()) 3126 return InstructionCost::getInvalid(); 3127 3128 // The code-generator is currently not able to handle scalable vectors 3129 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3130 // it. This change will be removed when code-generation for these types is 3131 // sufficiently reliable. 3132 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty)) 3133 if (VTy->getElementCount() == ElementCount::getScalable(1)) 3134 return InstructionCost::getInvalid(); 3135 3136 // TODO: consider latency as well for TCK_SizeAndLatency. 3137 if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency) 3138 return LT.first; 3139 3140 if (CostKind != TTI::TCK_RecipThroughput) 3141 return 1; 3142 3143 if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store && 3144 LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) { 3145 // Unaligned stores are extremely inefficient. We don't split all 3146 // unaligned 128-bit stores because the negative impact that has shown in 3147 // practice on inlined block copy code. 3148 // We make such stores expensive so that we will only vectorize if there 3149 // are 6 other instructions getting vectorized. 3150 const int AmortizationCost = 6; 3151 3152 return LT.first * 2 * AmortizationCost; 3153 } 3154 3155 // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs. 3156 if (Ty->isPtrOrPtrVectorTy()) 3157 return LT.first; 3158 3159 // Check truncating stores and extending loads. 3160 if (useNeonVector(Ty) && 3161 Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) { 3162 // v4i8 types are lowered to scalar a load/store and sshll/xtn. 3163 if (VT == MVT::v4i8) 3164 return 2; 3165 // Otherwise we need to scalarize. 3166 return cast<FixedVectorType>(Ty)->getNumElements() * 2; 3167 } 3168 3169 return LT.first; 3170 } 3171 3172 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost( 3173 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, 3174 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, 3175 bool UseMaskForCond, bool UseMaskForGaps) { 3176 assert(Factor >= 2 && "Invalid interleave factor"); 3177 auto *VecVTy = cast<VectorType>(VecTy); 3178 3179 if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2)) 3180 return InstructionCost::getInvalid(); 3181 3182 // Vectorization for masked interleaved accesses is only enabled for scalable 3183 // VF. 
3184 if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps)) 3185 return InstructionCost::getInvalid(); 3186 3187 if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) { 3188 unsigned MinElts = VecVTy->getElementCount().getKnownMinValue(); 3189 auto *SubVecTy = 3190 VectorType::get(VecVTy->getElementType(), 3191 VecVTy->getElementCount().divideCoefficientBy(Factor)); 3192 3193 // ldN/stN only support legal vector types of size 64 or 128 in bits. 3194 // Accesses having vector types that are a multiple of 128 bits can be 3195 // matched to more than one ldN/stN instruction. 3196 bool UseScalable; 3197 if (MinElts % Factor == 0 && 3198 TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable)) 3199 return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable); 3200 } 3201 3202 return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, 3203 Alignment, AddressSpace, CostKind, 3204 UseMaskForCond, UseMaskForGaps); 3205 } 3206 3207 InstructionCost 3208 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { 3209 InstructionCost Cost = 0; 3210 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 3211 for (auto *I : Tys) { 3212 if (!I->isVectorTy()) 3213 continue; 3214 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() == 3215 128) 3216 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) + 3217 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind); 3218 } 3219 return Cost; 3220 } 3221 3222 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) { 3223 return ST->getMaxInterleaveFactor(); 3224 } 3225 3226 // For Falkor, we want to avoid having too many strided loads in a loop since 3227 // that can exhaust the HW prefetcher resources. We adjust the unroller 3228 // MaxCount preference below to attempt to ensure unrolling doesn't create too 3229 // many strided loads. 3230 static void 3231 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE, 3232 TargetTransformInfo::UnrollingPreferences &UP) { 3233 enum { MaxStridedLoads = 7 }; 3234 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) { 3235 int StridedLoads = 0; 3236 // FIXME? We could make this more precise by looking at the CFG and 3237 // e.g. not counting loads in each side of an if-then-else diamond. 3238 for (const auto BB : L->blocks()) { 3239 for (auto &I : *BB) { 3240 LoadInst *LMemI = dyn_cast<LoadInst>(&I); 3241 if (!LMemI) 3242 continue; 3243 3244 Value *PtrValue = LMemI->getPointerOperand(); 3245 if (L->isLoopInvariant(PtrValue)) 3246 continue; 3247 3248 const SCEV *LSCEV = SE.getSCEV(PtrValue); 3249 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV); 3250 if (!LSCEVAddRec || !LSCEVAddRec->isAffine()) 3251 continue; 3252 3253 // FIXME? We could take pairing of unrolled load copies into account 3254 // by looking at the AddRec, but we would probably have to limit this 3255 // to loops with no stores or other memory optimization barriers. 3256 ++StridedLoads; 3257 // We've seen enough strided loads that seeing more won't make a 3258 // difference. 3259 if (StridedLoads > MaxStridedLoads / 2) 3260 return StridedLoads; 3261 } 3262 } 3263 return StridedLoads; 3264 }; 3265 3266 int StridedLoads = countStridedLoads(L, SE); 3267 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads 3268 << " strided loads\n"); 3269 // Pick the largest power of 2 unroll count that won't result in too many 3270 // strided loads. 
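  // For example, with MaxStridedLoads = 7 and three strided loads detected,
  // MaxCount becomes 1 << Log2_32(7 / 3) == 2.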
3271 if (StridedLoads) { 3272 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads); 3273 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to " 3274 << UP.MaxCount << '\n'); 3275 } 3276 } 3277 3278 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE, 3279 TTI::UnrollingPreferences &UP, 3280 OptimizationRemarkEmitter *ORE) { 3281 // Enable partial unrolling and runtime unrolling. 3282 BaseT::getUnrollingPreferences(L, SE, UP, ORE); 3283 3284 UP.UpperBound = true; 3285 3286 // For inner loop, it is more likely to be a hot one, and the runtime check 3287 // can be promoted out from LICM pass, so the overhead is less, let's try 3288 // a larger threshold to unroll more loops. 3289 if (L->getLoopDepth() > 1) 3290 UP.PartialThreshold *= 2; 3291 3292 // Disable partial & runtime unrolling on -Os. 3293 UP.PartialOptSizeThreshold = 0; 3294 3295 if (ST->getProcFamily() == AArch64Subtarget::Falkor && 3296 EnableFalkorHWPFUnrollFix) 3297 getFalkorUnrollingPreferences(L, SE, UP); 3298 3299 // Scan the loop: don't unroll loops with calls as this could prevent 3300 // inlining. Don't unroll vector loops either, as they don't benefit much from 3301 // unrolling. 3302 for (auto *BB : L->getBlocks()) { 3303 for (auto &I : *BB) { 3304 // Don't unroll vectorised loop. 3305 if (I.getType()->isVectorTy()) 3306 return; 3307 3308 if (isa<CallInst>(I) || isa<InvokeInst>(I)) { 3309 if (const Function *F = cast<CallBase>(I).getCalledFunction()) { 3310 if (!isLoweredToCall(F)) 3311 continue; 3312 } 3313 return; 3314 } 3315 } 3316 } 3317 3318 // Enable runtime unrolling for in-order models 3319 // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by 3320 // checking for that case, we can ensure that the default behaviour is 3321 // unchanged 3322 if (ST->getProcFamily() != AArch64Subtarget::Others && 3323 !ST->getSchedModel().isOutOfOrder()) { 3324 UP.Runtime = true; 3325 UP.Partial = true; 3326 UP.UnrollRemainder = true; 3327 UP.DefaultUnrollRuntimeCount = 4; 3328 3329 UP.UnrollAndJam = true; 3330 UP.UnrollAndJamInnerLoopThreshold = 60; 3331 } 3332 } 3333 3334 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE, 3335 TTI::PeelingPreferences &PP) { 3336 BaseT::getPeelingPreferences(L, SE, PP); 3337 } 3338 3339 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, 3340 Type *ExpectedType) { 3341 switch (Inst->getIntrinsicID()) { 3342 default: 3343 return nullptr; 3344 case Intrinsic::aarch64_neon_st2: 3345 case Intrinsic::aarch64_neon_st3: 3346 case Intrinsic::aarch64_neon_st4: { 3347 // Create a struct type 3348 StructType *ST = dyn_cast<StructType>(ExpectedType); 3349 if (!ST) 3350 return nullptr; 3351 unsigned NumElts = Inst->arg_size() - 1; 3352 if (ST->getNumElements() != NumElts) 3353 return nullptr; 3354 for (unsigned i = 0, e = NumElts; i != e; ++i) { 3355 if (Inst->getArgOperand(i)->getType() != ST->getElementType(i)) 3356 return nullptr; 3357 } 3358 Value *Res = PoisonValue::get(ExpectedType); 3359 IRBuilder<> Builder(Inst); 3360 for (unsigned i = 0, e = NumElts; i != e; ++i) { 3361 Value *L = Inst->getArgOperand(i); 3362 Res = Builder.CreateInsertValue(Res, L, i); 3363 } 3364 return Res; 3365 } 3366 case Intrinsic::aarch64_neon_ld2: 3367 case Intrinsic::aarch64_neon_ld3: 3368 case Intrinsic::aarch64_neon_ld4: 3369 if (Inst->getType() == ExpectedType) 3370 return Inst; 3371 return nullptr; 3372 } 3373 } 3374 3375 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, 3376 MemIntrinsicInfo &Info) { 
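  // Describe the NEON structured load/store intrinsics as memory operations:
  // ld2/ld3/ld4 read through their pointer argument (operand 0), while
  // st2/st3/st4 write through their last operand; each intrinsic is given the
  // matching VECTOR_LDST_*_ELEMENTS id below.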
  switch (Inst->getIntrinsicID()) {
  default:
    break;
  case Intrinsic::aarch64_neon_ld2:
  case Intrinsic::aarch64_neon_ld3:
  case Intrinsic::aarch64_neon_ld4:
    Info.ReadMem = true;
    Info.WriteMem = false;
    Info.PtrVal = Inst->getArgOperand(0);
    break;
  case Intrinsic::aarch64_neon_st2:
  case Intrinsic::aarch64_neon_st3:
  case Intrinsic::aarch64_neon_st4:
    Info.ReadMem = false;
    Info.WriteMem = true;
    Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
    break;
  }

  switch (Inst->getIntrinsicID()) {
  default:
    return false;
  case Intrinsic::aarch64_neon_ld2:
  case Intrinsic::aarch64_neon_st2:
    Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
    break;
  case Intrinsic::aarch64_neon_ld3:
  case Intrinsic::aarch64_neon_st3:
    Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
    break;
  case Intrinsic::aarch64_neon_ld4:
  case Intrinsic::aarch64_neon_st4:
    Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
    break;
  }
  return true;
}

/// See if \p I should be considered for address type promotion. We check if
/// \p I is a sext with the right type that is used in memory accesses. If it
/// is used in a "complex" getelementptr, we allow it to be promoted without
/// finding other sext instructions that sign-extended the same initial value.
/// A getelementptr is considered "complex" if it has more than 2 operands.
bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
    const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
  bool Considerable = false;
  AllowPromotionWithoutCommonHeader = false;
  if (!isa<SExtInst>(&I))
    return false;
  Type *ConsideredSExtType =
      Type::getInt64Ty(I.getParent()->getParent()->getContext());
  if (I.getType() != ConsideredSExtType)
    return false;
  // See if the sext is the one with the right type and used in at least one
  // GetElementPtrInst.
  for (const User *U : I.users()) {
    if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
      Considerable = true;
      // A getelementptr is considered as "complex" if it has more than 2
      // operands. We will promote a SExt used in such complex GEP as we
      // expect some computation to be merged if they are done on 64 bits.
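      // For instance (illustrative IR only), in
      //   %idx = sext i32 %i to i64
      //   %p = getelementptr [16 x i32], ptr %base, i64 %idx, i64 %j
      // the GEP has three operands, so the sext may be promoted without
      // requiring a common header.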
3438 if (GEPInst->getNumOperands() > 2) { 3439 AllowPromotionWithoutCommonHeader = true; 3440 break; 3441 } 3442 } 3443 } 3444 return Considerable; 3445 } 3446 3447 bool AArch64TTIImpl::isLegalToVectorizeReduction( 3448 const RecurrenceDescriptor &RdxDesc, ElementCount VF) const { 3449 if (!VF.isScalable()) 3450 return true; 3451 3452 Type *Ty = RdxDesc.getRecurrenceType(); 3453 if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty)) 3454 return false; 3455 3456 switch (RdxDesc.getRecurrenceKind()) { 3457 case RecurKind::Add: 3458 case RecurKind::FAdd: 3459 case RecurKind::And: 3460 case RecurKind::Or: 3461 case RecurKind::Xor: 3462 case RecurKind::SMin: 3463 case RecurKind::SMax: 3464 case RecurKind::UMin: 3465 case RecurKind::UMax: 3466 case RecurKind::FMin: 3467 case RecurKind::FMax: 3468 case RecurKind::FMulAdd: 3469 case RecurKind::IAnyOf: 3470 case RecurKind::FAnyOf: 3471 return true; 3472 default: 3473 return false; 3474 } 3475 } 3476 3477 InstructionCost 3478 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, 3479 FastMathFlags FMF, 3480 TTI::TargetCostKind CostKind) { 3481 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty); 3482 3483 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16()) 3484 return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind); 3485 3486 InstructionCost LegalizationCost = 0; 3487 if (LT.first > 1) { 3488 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext()); 3489 IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF); 3490 LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1); 3491 } 3492 3493 return LegalizationCost + /*Cost of horizontal reduction*/ 2; 3494 } 3495 3496 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE( 3497 unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) { 3498 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy); 3499 InstructionCost LegalizationCost = 0; 3500 if (LT.first > 1) { 3501 Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext()); 3502 LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind); 3503 LegalizationCost *= LT.first - 1; 3504 } 3505 3506 int ISD = TLI->InstructionOpcodeToISD(Opcode); 3507 assert(ISD && "Invalid opcode"); 3508 // Add the final reduction cost for the legal horizontal reduction 3509 switch (ISD) { 3510 case ISD::ADD: 3511 case ISD::AND: 3512 case ISD::OR: 3513 case ISD::XOR: 3514 case ISD::FADD: 3515 return LegalizationCost + 2; 3516 default: 3517 return InstructionCost::getInvalid(); 3518 } 3519 } 3520 3521 InstructionCost 3522 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy, 3523 std::optional<FastMathFlags> FMF, 3524 TTI::TargetCostKind CostKind) { 3525 if (TTI::requiresOrderedReduction(FMF)) { 3526 if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) { 3527 InstructionCost BaseCost = 3528 BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind); 3529 // Add on extra cost to reflect the extra overhead on some CPUs. We still 3530 // end up vectorizing for more computationally intensive loops. 
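      // For example, an ordered (in-order) fadd reduction of an 8-element
      // fixed vector is costed as the base strict-reduction cost plus 8.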
3531 return BaseCost + FixedVTy->getNumElements(); 3532 } 3533 3534 if (Opcode != Instruction::FAdd) 3535 return InstructionCost::getInvalid(); 3536 3537 auto *VTy = cast<ScalableVectorType>(ValTy); 3538 InstructionCost Cost = 3539 getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind); 3540 Cost *= getMaxNumElements(VTy->getElementCount()); 3541 return Cost; 3542 } 3543 3544 if (isa<ScalableVectorType>(ValTy)) 3545 return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind); 3546 3547 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy); 3548 MVT MTy = LT.second; 3549 int ISD = TLI->InstructionOpcodeToISD(Opcode); 3550 assert(ISD && "Invalid opcode"); 3551 3552 // Horizontal adds can use the 'addv' instruction. We model the cost of these 3553 // instructions as twice a normal vector add, plus 1 for each legalization 3554 // step (LT.first). This is the only arithmetic vector reduction operation for 3555 // which we have an instruction. 3556 // OR, XOR and AND costs should match the codegen from: 3557 // OR: llvm/test/CodeGen/AArch64/reduce-or.ll 3558 // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll 3559 // AND: llvm/test/CodeGen/AArch64/reduce-and.ll 3560 static const CostTblEntry CostTblNoPairwise[]{ 3561 {ISD::ADD, MVT::v8i8, 2}, 3562 {ISD::ADD, MVT::v16i8, 2}, 3563 {ISD::ADD, MVT::v4i16, 2}, 3564 {ISD::ADD, MVT::v8i16, 2}, 3565 {ISD::ADD, MVT::v4i32, 2}, 3566 {ISD::ADD, MVT::v2i64, 2}, 3567 {ISD::OR, MVT::v8i8, 15}, 3568 {ISD::OR, MVT::v16i8, 17}, 3569 {ISD::OR, MVT::v4i16, 7}, 3570 {ISD::OR, MVT::v8i16, 9}, 3571 {ISD::OR, MVT::v2i32, 3}, 3572 {ISD::OR, MVT::v4i32, 5}, 3573 {ISD::OR, MVT::v2i64, 3}, 3574 {ISD::XOR, MVT::v8i8, 15}, 3575 {ISD::XOR, MVT::v16i8, 17}, 3576 {ISD::XOR, MVT::v4i16, 7}, 3577 {ISD::XOR, MVT::v8i16, 9}, 3578 {ISD::XOR, MVT::v2i32, 3}, 3579 {ISD::XOR, MVT::v4i32, 5}, 3580 {ISD::XOR, MVT::v2i64, 3}, 3581 {ISD::AND, MVT::v8i8, 15}, 3582 {ISD::AND, MVT::v16i8, 17}, 3583 {ISD::AND, MVT::v4i16, 7}, 3584 {ISD::AND, MVT::v8i16, 9}, 3585 {ISD::AND, MVT::v2i32, 3}, 3586 {ISD::AND, MVT::v4i32, 5}, 3587 {ISD::AND, MVT::v2i64, 3}, 3588 }; 3589 switch (ISD) { 3590 default: 3591 break; 3592 case ISD::ADD: 3593 if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy)) 3594 return (LT.first - 1) + Entry->Cost; 3595 break; 3596 case ISD::XOR: 3597 case ISD::AND: 3598 case ISD::OR: 3599 const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy); 3600 if (!Entry) 3601 break; 3602 auto *ValVTy = cast<FixedVectorType>(ValTy); 3603 if (MTy.getVectorNumElements() <= ValVTy->getNumElements() && 3604 isPowerOf2_32(ValVTy->getNumElements())) { 3605 InstructionCost ExtraCost = 0; 3606 if (LT.first != 1) { 3607 // Type needs to be split, so there is an extra cost of LT.first - 1 3608 // arithmetic ops. 3609 auto *Ty = FixedVectorType::get(ValTy->getElementType(), 3610 MTy.getVectorNumElements()); 3611 ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind); 3612 ExtraCost *= LT.first - 1; 3613 } 3614 // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov 3615 auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 
2 : Entry->Cost; 3616 return Cost + ExtraCost; 3617 } 3618 break; 3619 } 3620 return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind); 3621 } 3622 3623 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) { 3624 static const CostTblEntry ShuffleTbl[] = { 3625 { TTI::SK_Splice, MVT::nxv16i8, 1 }, 3626 { TTI::SK_Splice, MVT::nxv8i16, 1 }, 3627 { TTI::SK_Splice, MVT::nxv4i32, 1 }, 3628 { TTI::SK_Splice, MVT::nxv2i64, 1 }, 3629 { TTI::SK_Splice, MVT::nxv2f16, 1 }, 3630 { TTI::SK_Splice, MVT::nxv4f16, 1 }, 3631 { TTI::SK_Splice, MVT::nxv8f16, 1 }, 3632 { TTI::SK_Splice, MVT::nxv2bf16, 1 }, 3633 { TTI::SK_Splice, MVT::nxv4bf16, 1 }, 3634 { TTI::SK_Splice, MVT::nxv8bf16, 1 }, 3635 { TTI::SK_Splice, MVT::nxv2f32, 1 }, 3636 { TTI::SK_Splice, MVT::nxv4f32, 1 }, 3637 { TTI::SK_Splice, MVT::nxv2f64, 1 }, 3638 }; 3639 3640 // The code-generator is currently not able to handle scalable vectors 3641 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3642 // it. This change will be removed when code-generation for these types is 3643 // sufficiently reliable. 3644 if (Tp->getElementCount() == ElementCount::getScalable(1)) 3645 return InstructionCost::getInvalid(); 3646 3647 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp); 3648 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext()); 3649 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 3650 EVT PromotedVT = LT.second.getScalarType() == MVT::i1 3651 ? TLI->getPromotedVTForPredicate(EVT(LT.second)) 3652 : LT.second; 3653 Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext()); 3654 InstructionCost LegalizationCost = 0; 3655 if (Index < 0) { 3656 LegalizationCost = 3657 getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy, 3658 CmpInst::BAD_ICMP_PREDICATE, CostKind) + 3659 getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy, 3660 CmpInst::BAD_ICMP_PREDICATE, CostKind); 3661 } 3662 3663 // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp 3664 // Cost performed on a promoted type. 3665 if (LT.second.getScalarType() == MVT::i1) { 3666 LegalizationCost += 3667 getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy, 3668 TTI::CastContextHint::None, CostKind) + 3669 getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy, 3670 TTI::CastContextHint::None, CostKind); 3671 } 3672 const auto *Entry = 3673 CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT()); 3674 assert(Entry && "Illegal Type for Splice"); 3675 LegalizationCost += Entry->Cost; 3676 return LegalizationCost * LT.first; 3677 } 3678 3679 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, 3680 VectorType *Tp, 3681 ArrayRef<int> Mask, 3682 TTI::TargetCostKind CostKind, 3683 int Index, VectorType *SubTp, 3684 ArrayRef<const Value *> Args) { 3685 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp); 3686 // If we have a Mask, and the LT is being legalized somehow, split the Mask 3687 // into smaller vectors and sum the cost of each shuffle. 
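  // For instance, a shuffle of <32 x i8> that legalizes to two v16i8 registers
  // is costed below as a series of v16i8 sub-shuffles, one per 16-element
  // chunk of the original mask.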
3688 if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() && 3689 Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() && 3690 Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) { 3691 unsigned TpNumElts = Mask.size(); 3692 unsigned LTNumElts = LT.second.getVectorNumElements(); 3693 unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts; 3694 VectorType *NTp = 3695 VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount()); 3696 InstructionCost Cost; 3697 for (unsigned N = 0; N < NumVecs; N++) { 3698 SmallVector<int> NMask; 3699 // Split the existing mask into chunks of size LTNumElts. Track the source 3700 // sub-vectors to ensure the result has at most 2 inputs. 3701 unsigned Source1, Source2; 3702 unsigned NumSources = 0; 3703 for (unsigned E = 0; E < LTNumElts; E++) { 3704 int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E] 3705 : PoisonMaskElem; 3706 if (MaskElt < 0) { 3707 NMask.push_back(PoisonMaskElem); 3708 continue; 3709 } 3710 3711 // Calculate which source from the input this comes from and whether it 3712 // is new to us. 3713 unsigned Source = MaskElt / LTNumElts; 3714 if (NumSources == 0) { 3715 Source1 = Source; 3716 NumSources = 1; 3717 } else if (NumSources == 1 && Source != Source1) { 3718 Source2 = Source; 3719 NumSources = 2; 3720 } else if (NumSources >= 2 && Source != Source1 && Source != Source2) { 3721 NumSources++; 3722 } 3723 3724 // Add to the new mask. For the NumSources>2 case these are not correct, 3725 // but are only used for the modular lane number. 3726 if (Source == Source1) 3727 NMask.push_back(MaskElt % LTNumElts); 3728 else if (Source == Source2) 3729 NMask.push_back(MaskElt % LTNumElts + LTNumElts); 3730 else 3731 NMask.push_back(MaskElt % LTNumElts); 3732 } 3733 // If the sub-mask has at most 2 input sub-vectors then re-cost it using 3734 // getShuffleCost. If not then cost it using the worst case. 3735 if (NumSources <= 2) 3736 Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc 3737 : TTI::SK_PermuteTwoSrc, 3738 NTp, NMask, CostKind, 0, nullptr, Args); 3739 else if (any_of(enumerate(NMask), [&](const auto &ME) { 3740 return ME.value() % LTNumElts == ME.index(); 3741 })) 3742 Cost += LTNumElts - 1; 3743 else 3744 Cost += LTNumElts; 3745 } 3746 return Cost; 3747 } 3748 3749 Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp); 3750 3751 // Check for broadcast loads, which are supported by the LD1R instruction. 3752 // In terms of code-size, the shuffle vector is free when a load + dup get 3753 // folded into a LD1R. That's what we check and return here. For performance 3754 // and reciprocal throughput, a LD1R is not completely free. In this case, we 3755 // return the cost for the broadcast below (i.e. 1 for most/all types), so 3756 // that we model the load + dup sequence slightly higher because LD1R is a 3757 // high latency instruction. 3758 if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) { 3759 bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]); 3760 if (IsLoad && LT.second.isVector() && 3761 isLegalBroadcastLoad(Tp->getElementType(), 3762 LT.second.getVectorElementCount())) 3763 return 0; 3764 } 3765 3766 // If we have 4 elements for the shuffle and a Mask, get the cost straight 3767 // from the perfect shuffle tables. 
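  // This covers 4-element shuffles of 16-bit or 32-bit elements whose mask
  // indices are all below 8, i.e. shuffles drawing from at most two 4-element
  // inputs (for example a v4i32 zip-like mask such as <0, 4, 1, 5>).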
3768 if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) && 3769 (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) && 3770 all_of(Mask, [](int E) { return E < 8; })) 3771 return getPerfectShuffleCost(Mask); 3772 3773 if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose || 3774 Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc || 3775 Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) { 3776 static const CostTblEntry ShuffleTbl[] = { 3777 // Broadcast shuffle kinds can be performed with 'dup'. 3778 {TTI::SK_Broadcast, MVT::v8i8, 1}, 3779 {TTI::SK_Broadcast, MVT::v16i8, 1}, 3780 {TTI::SK_Broadcast, MVT::v4i16, 1}, 3781 {TTI::SK_Broadcast, MVT::v8i16, 1}, 3782 {TTI::SK_Broadcast, MVT::v2i32, 1}, 3783 {TTI::SK_Broadcast, MVT::v4i32, 1}, 3784 {TTI::SK_Broadcast, MVT::v2i64, 1}, 3785 {TTI::SK_Broadcast, MVT::v4f16, 1}, 3786 {TTI::SK_Broadcast, MVT::v8f16, 1}, 3787 {TTI::SK_Broadcast, MVT::v2f32, 1}, 3788 {TTI::SK_Broadcast, MVT::v4f32, 1}, 3789 {TTI::SK_Broadcast, MVT::v2f64, 1}, 3790 // Transpose shuffle kinds can be performed with 'trn1/trn2' and 3791 // 'zip1/zip2' instructions. 3792 {TTI::SK_Transpose, MVT::v8i8, 1}, 3793 {TTI::SK_Transpose, MVT::v16i8, 1}, 3794 {TTI::SK_Transpose, MVT::v4i16, 1}, 3795 {TTI::SK_Transpose, MVT::v8i16, 1}, 3796 {TTI::SK_Transpose, MVT::v2i32, 1}, 3797 {TTI::SK_Transpose, MVT::v4i32, 1}, 3798 {TTI::SK_Transpose, MVT::v2i64, 1}, 3799 {TTI::SK_Transpose, MVT::v4f16, 1}, 3800 {TTI::SK_Transpose, MVT::v8f16, 1}, 3801 {TTI::SK_Transpose, MVT::v2f32, 1}, 3802 {TTI::SK_Transpose, MVT::v4f32, 1}, 3803 {TTI::SK_Transpose, MVT::v2f64, 1}, 3804 // Select shuffle kinds. 3805 // TODO: handle vXi8/vXi16. 3806 {TTI::SK_Select, MVT::v2i32, 1}, // mov. 3807 {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar). 3808 {TTI::SK_Select, MVT::v2i64, 1}, // mov. 3809 {TTI::SK_Select, MVT::v2f32, 1}, // mov. 3810 {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar). 3811 {TTI::SK_Select, MVT::v2f64, 1}, // mov. 3812 // PermuteSingleSrc shuffle kinds. 3813 {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov. 3814 {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case. 3815 {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov. 3816 {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov. 3817 {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case. 3818 {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov. 3819 {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case. 3820 {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case. 3821 {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same 3822 {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8}, // constpool + load + tbl 3823 {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8}, // constpool + load + tbl 3824 {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl 3825 {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8}, // constpool + load + tbl 3826 {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8}, // constpool + load + tbl 3827 // Reverse can be lowered with `rev`. 
3828 {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64 3829 {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT 3830 {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT 3831 {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64 3832 {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT 3833 {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT 3834 {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT 3835 {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT 3836 {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT 3837 {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64 3838 {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64 3839 {TTI::SK_Reverse, MVT::v8i8, 1}, // REV64 3840 // Splice can all be lowered as `ext`. 3841 {TTI::SK_Splice, MVT::v2i32, 1}, 3842 {TTI::SK_Splice, MVT::v4i32, 1}, 3843 {TTI::SK_Splice, MVT::v2i64, 1}, 3844 {TTI::SK_Splice, MVT::v2f32, 1}, 3845 {TTI::SK_Splice, MVT::v4f32, 1}, 3846 {TTI::SK_Splice, MVT::v2f64, 1}, 3847 {TTI::SK_Splice, MVT::v8f16, 1}, 3848 {TTI::SK_Splice, MVT::v8bf16, 1}, 3849 {TTI::SK_Splice, MVT::v8i16, 1}, 3850 {TTI::SK_Splice, MVT::v16i8, 1}, 3851 {TTI::SK_Splice, MVT::v4bf16, 1}, 3852 {TTI::SK_Splice, MVT::v4f16, 1}, 3853 {TTI::SK_Splice, MVT::v4i16, 1}, 3854 {TTI::SK_Splice, MVT::v8i8, 1}, 3855 // Broadcast shuffle kinds for scalable vectors 3856 {TTI::SK_Broadcast, MVT::nxv16i8, 1}, 3857 {TTI::SK_Broadcast, MVT::nxv8i16, 1}, 3858 {TTI::SK_Broadcast, MVT::nxv4i32, 1}, 3859 {TTI::SK_Broadcast, MVT::nxv2i64, 1}, 3860 {TTI::SK_Broadcast, MVT::nxv2f16, 1}, 3861 {TTI::SK_Broadcast, MVT::nxv4f16, 1}, 3862 {TTI::SK_Broadcast, MVT::nxv8f16, 1}, 3863 {TTI::SK_Broadcast, MVT::nxv2bf16, 1}, 3864 {TTI::SK_Broadcast, MVT::nxv4bf16, 1}, 3865 {TTI::SK_Broadcast, MVT::nxv8bf16, 1}, 3866 {TTI::SK_Broadcast, MVT::nxv2f32, 1}, 3867 {TTI::SK_Broadcast, MVT::nxv4f32, 1}, 3868 {TTI::SK_Broadcast, MVT::nxv2f64, 1}, 3869 {TTI::SK_Broadcast, MVT::nxv16i1, 1}, 3870 {TTI::SK_Broadcast, MVT::nxv8i1, 1}, 3871 {TTI::SK_Broadcast, MVT::nxv4i1, 1}, 3872 {TTI::SK_Broadcast, MVT::nxv2i1, 1}, 3873 // Handle the cases for vector.reverse with scalable vectors 3874 {TTI::SK_Reverse, MVT::nxv16i8, 1}, 3875 {TTI::SK_Reverse, MVT::nxv8i16, 1}, 3876 {TTI::SK_Reverse, MVT::nxv4i32, 1}, 3877 {TTI::SK_Reverse, MVT::nxv2i64, 1}, 3878 {TTI::SK_Reverse, MVT::nxv2f16, 1}, 3879 {TTI::SK_Reverse, MVT::nxv4f16, 1}, 3880 {TTI::SK_Reverse, MVT::nxv8f16, 1}, 3881 {TTI::SK_Reverse, MVT::nxv2bf16, 1}, 3882 {TTI::SK_Reverse, MVT::nxv4bf16, 1}, 3883 {TTI::SK_Reverse, MVT::nxv8bf16, 1}, 3884 {TTI::SK_Reverse, MVT::nxv2f32, 1}, 3885 {TTI::SK_Reverse, MVT::nxv4f32, 1}, 3886 {TTI::SK_Reverse, MVT::nxv2f64, 1}, 3887 {TTI::SK_Reverse, MVT::nxv16i1, 1}, 3888 {TTI::SK_Reverse, MVT::nxv8i1, 1}, 3889 {TTI::SK_Reverse, MVT::nxv4i1, 1}, 3890 {TTI::SK_Reverse, MVT::nxv2i1, 1}, 3891 }; 3892 if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second)) 3893 return LT.first * Entry->Cost; 3894 } 3895 3896 if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp)) 3897 return getSpliceCost(Tp, Index); 3898 3899 // Inserting a subvector can often be done with either a D, S or H register 3900 // move, so long as the inserted vector is "aligned". 
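  // Here "aligned" means the insertion index is a multiple of the legalized
  // subvector length and the subvector length divides the full vector length,
  // e.g. inserting a v2f32 into a v4f32 at index 0 or 2.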
3901 if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() && 3902 LT.second.getSizeInBits() <= 128 && SubTp) { 3903 std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp); 3904 if (SubLT.second.isVector()) { 3905 int NumElts = LT.second.getVectorNumElements(); 3906 int NumSubElts = SubLT.second.getVectorNumElements(); 3907 if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0) 3908 return SubLT.first; 3909 } 3910 } 3911 3912 return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp); 3913 } 3914 3915 static bool containsDecreasingPointers(Loop *TheLoop, 3916 PredicatedScalarEvolution *PSE) { 3917 const auto &Strides = DenseMap<Value *, const SCEV *>(); 3918 for (BasicBlock *BB : TheLoop->blocks()) { 3919 // Scan the instructions in the block and look for addresses that are 3920 // consecutive and decreasing. 3921 for (Instruction &I : *BB) { 3922 if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) { 3923 Value *Ptr = getLoadStorePointerOperand(&I); 3924 Type *AccessTy = getLoadStoreType(&I); 3925 if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true, 3926 /*ShouldCheckWrap=*/false) 3927 .value_or(0) < 0) 3928 return true; 3929 } 3930 } 3931 } 3932 return false; 3933 } 3934 3935 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) { 3936 if (!ST->hasSVE()) 3937 return false; 3938 3939 // We don't currently support vectorisation with interleaving for SVE - with 3940 // such loops we're better off not using tail-folding. This gives us a chance 3941 // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc. 3942 if (TFI->IAI->hasGroups()) 3943 return false; 3944 3945 TailFoldingOpts Required = TailFoldingOpts::Disabled; 3946 if (TFI->LVL->getReductionVars().size()) 3947 Required |= TailFoldingOpts::Reductions; 3948 if (TFI->LVL->getFixedOrderRecurrences().size()) 3949 Required |= TailFoldingOpts::Recurrences; 3950 3951 // We call this to discover whether any load/store pointers in the loop have 3952 // negative strides. This will require extra work to reverse the loop 3953 // predicate, which may be expensive. 3954 if (containsDecreasingPointers(TFI->LVL->getLoop(), 3955 TFI->LVL->getPredicatedScalarEvolution())) 3956 Required |= TailFoldingOpts::Reverse; 3957 if (Required == TailFoldingOpts::Disabled) 3958 Required |= TailFoldingOpts::Simple; 3959 3960 if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(), 3961 Required)) 3962 return false; 3963 3964 // Don't tail-fold for tight loops where we would be better off interleaving 3965 // with an unpredicated loop. 3966 unsigned NumInsns = 0; 3967 for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) { 3968 NumInsns += BB->sizeWithoutDebug(); 3969 } 3970 3971 // We expect 4 of these to be a IV PHI, IV add, IV compare and branch. 3972 return NumInsns >= SVETailFoldInsnThreshold; 3973 } 3974 3975 InstructionCost 3976 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, 3977 int64_t BaseOffset, bool HasBaseReg, 3978 int64_t Scale, unsigned AddrSpace) const { 3979 // Scaling factors are not free at all. 
3980 // Operands | Rt Latency 3981 // ------------------------------------------- 3982 // Rt, [Xn, Xm] | 4 3983 // ------------------------------------------- 3984 // Rt, [Xn, Xm, lsl #imm] | Rn: 4 Rm: 5 3985 // Rt, [Xn, Wm, <extend> #imm] | 3986 TargetLoweringBase::AddrMode AM; 3987 AM.BaseGV = BaseGV; 3988 AM.BaseOffs = BaseOffset; 3989 AM.HasBaseReg = HasBaseReg; 3990 AM.Scale = Scale; 3991 if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) 3992 // Scale represents reg2 * scale, thus account for 1 if 3993 // it is not equal to 0 or 1. 3994 return AM.Scale != 0 && AM.Scale != 1; 3995 return -1; 3996 } 3997