//===-- AArch64TargetTransformInfo.cpp - AArch64 specific TTI -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "AArch64TargetTransformInfo.h"
#include "AArch64ExpandImm.h"
#include "AArch64PerfectShuffle.h"
#include "MCTargetDesc/AArch64AddressingModes.h"
#include "llvm/Analysis/IVDescriptors.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/CostTable.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsAArch64.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Vectorize/LoopVectorizationLegality.h"
#include <algorithm>
#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;

#define DEBUG_TYPE "aarch64tti"

static cl::opt<bool> EnableFalkorHWPFUnrollFix("enable-falkor-hwpf-unroll-fix",
                                               cl::init(true), cl::Hidden);

static cl::opt<unsigned> SVEGatherOverhead("sve-gather-overhead", cl::init(10),
                                           cl::Hidden);

static cl::opt<unsigned> SVEScatterOverhead("sve-scatter-overhead",
                                            cl::init(10), cl::Hidden);

static cl::opt<unsigned> SVETailFoldInsnThreshold("sve-tail-folding-insn-threshold",
                                                  cl::init(15), cl::Hidden);

static cl::opt<unsigned>
    NeonNonConstStrideOverhead("neon-nonconst-stride-overhead", cl::init(10),
                               cl::Hidden);

static cl::opt<unsigned> CallPenaltyChangeSM(
    "call-penalty-sm-change", cl::init(5), cl::Hidden,
    cl::desc(
        "Penalty of calling a function that requires a change to PSTATE.SM"));

static cl::opt<unsigned> InlineCallPenaltyChangeSM(
    "inline-call-penalty-sm-change", cl::init(10), cl::Hidden,
    cl::desc("Penalty of inlining a call that requires a change to PSTATE.SM"));

namespace {
class TailFoldingOption {
  // These bitfields will only ever be set to something non-zero in operator=,
  // when setting the -sve-tail-folding option. This option should always be of
  // the form (default|simple|all|disable)[+(Flag1|Flag2|etc)], where here
  // InitialBits is one of (disabled|all|simple). EnableBits represents
  // additional flags we're enabling, and DisableBits for those flags we're
  // disabling. The default flag is tracked in the variable NeedsDefault, since
  // at the time of setting the option we may not know what the default value
  // for the CPU is.
  TailFoldingOpts InitialBits = TailFoldingOpts::Disabled;
  TailFoldingOpts EnableBits = TailFoldingOpts::Disabled;
  TailFoldingOpts DisableBits = TailFoldingOpts::Disabled;

  // This value needs to be initialised to true in case the user does not
  // explicitly set the -sve-tail-folding option.
  bool NeedsDefault = true;

  void setInitialBits(TailFoldingOpts Bits) { InitialBits = Bits; }

  void setNeedsDefault(bool V) { NeedsDefault = V; }

  void setEnableBit(TailFoldingOpts Bit) {
    EnableBits |= Bit;
    DisableBits &= ~Bit;
  }

  void setDisableBit(TailFoldingOpts Bit) {
    EnableBits &= ~Bit;
    DisableBits |= Bit;
  }

  TailFoldingOpts getBits(TailFoldingOpts DefaultBits) const {
    TailFoldingOpts Bits = TailFoldingOpts::Disabled;

    assert((InitialBits == TailFoldingOpts::Disabled || !NeedsDefault) &&
           "Initial bits should only include one of "
           "(disabled|all|simple|default)");
    Bits = NeedsDefault ? DefaultBits : InitialBits;
    Bits |= EnableBits;
    Bits &= ~DisableBits;

    return Bits;
  }

  void reportError(std::string Opt) {
    errs() << "invalid argument '" << Opt
           << "' to -sve-tail-folding=; the option should be of the form\n"
              "  (disabled|all|default|simple)[+(reductions|recurrences"
              "|reverse|noreductions|norecurrences|noreverse)]\n";
    report_fatal_error("Unrecognised tail-folding option");
  }

public:

  void operator=(const std::string &Val) {
    // If the user explicitly sets -sve-tail-folding= then treat as an error.
    if (Val.empty()) {
      reportError("");
      return;
    }

    // Since the user is explicitly setting the option we don't automatically
    // need the default unless they require it.
    setNeedsDefault(false);

    SmallVector<StringRef, 4> TailFoldTypes;
    StringRef(Val).split(TailFoldTypes, '+', -1, false);

    unsigned StartIdx = 1;
    if (TailFoldTypes[0] == "disabled")
      setInitialBits(TailFoldingOpts::Disabled);
    else if (TailFoldTypes[0] == "all")
      setInitialBits(TailFoldingOpts::All);
    else if (TailFoldTypes[0] == "default")
      setNeedsDefault(true);
    else if (TailFoldTypes[0] == "simple")
      setInitialBits(TailFoldingOpts::Simple);
    else {
      StartIdx = 0;
      setInitialBits(TailFoldingOpts::Disabled);
    }

    for (unsigned I = StartIdx; I < TailFoldTypes.size(); I++) {
      if (TailFoldTypes[I] == "reductions")
        setEnableBit(TailFoldingOpts::Reductions);
      else if (TailFoldTypes[I] == "recurrences")
        setEnableBit(TailFoldingOpts::Recurrences);
      else if (TailFoldTypes[I] == "reverse")
        setEnableBit(TailFoldingOpts::Reverse);
      else if (TailFoldTypes[I] == "noreductions")
        setDisableBit(TailFoldingOpts::Reductions);
      else if (TailFoldTypes[I] == "norecurrences")
        setDisableBit(TailFoldingOpts::Recurrences);
      else if (TailFoldTypes[I] == "noreverse")
        setDisableBit(TailFoldingOpts::Reverse);
      else
        reportError(Val);
    }
  }

  bool satisfies(TailFoldingOpts DefaultBits, TailFoldingOpts Required) const {
    return (getBits(DefaultBits) & Required) == Required;
  }
};
} // namespace

TailFoldingOption TailFoldingOptionLoc;

cl::opt<TailFoldingOption, true, cl::parser<std::string>> SVETailFolding(
    "sve-tail-folding",
    cl::desc(
        "Control the use of vectorisation using tail-folding for SVE where the"
        " option is specified in the form (Initial)[+(Flag1|Flag2|...)]:"
        "\ndisabled      (Initial) No loop types will vectorize using "
        "tail-folding"
        "\ndefault       (Initial) Uses the default tail-folding settings for "
        "the target CPU"
        "\nall           (Initial) All legal loop types will vectorize using "
        "tail-folding"
        "\nsimple        (Initial) Use tail-folding for simple loops (not "
        "reductions or recurrences)"
        "\nreductions    Use tail-folding for loops containing reductions"
        "\nnoreductions  Inverse of above"
        "\nrecurrences   Use tail-folding for loops containing fixed order "
        "recurrences"
        "\nnorecurrences Inverse of above"
        "\nreverse       Use tail-folding for loops requiring reversed "
        "predicates"
        "\nnoreverse     Inverse of above"),
    cl::location(TailFoldingOptionLoc));

// Experimental option that will only be fully functional when the
// code-generator is changed to use SVE instead of NEON for all fixed-width
// operations.
static cl::opt<bool> EnableFixedwidthAutovecInStreamingMode(
    "enable-fixedwidth-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

// Experimental option that will only be fully functional when the cost-model
// and code-generator have been changed to avoid using scalable vector
// instructions that are not legal in streaming SVE mode.
static cl::opt<bool> EnableScalableAutovecInStreamingMode(
    "enable-scalable-autovec-in-streaming-mode", cl::init(false), cl::Hidden);

static bool isSMEABIRoutineCall(const CallInst &CI) {
  const auto *F = CI.getCalledFunction();
  return F && StringSwitch<bool>(F->getName())
                  .Case("__arm_sme_state", true)
                  .Case("__arm_tpidr2_save", true)
                  .Case("__arm_tpidr2_restore", true)
                  .Case("__arm_za_disable", true)
                  .Default(false);
}

/// Returns true if the function has explicit operations that can only be
/// lowered using incompatible instructions for the selected mode. This also
/// returns true if the function F may use or modify ZA state.
static bool hasPossibleIncompatibleOps(const Function *F) {
  for (const BasicBlock &BB : *F) {
    for (const Instruction &I : BB) {
      // Be conservative for now and assume that any call to inline asm or to
      // intrinsics could result in non-streaming ops (e.g. calls to
      // @llvm.aarch64.* or @llvm.gather/scatter intrinsics). We can assume that
      // all native LLVM instructions can be lowered to compatible instructions.
      if (isa<CallInst>(I) && !I.isDebugOrPseudoInst() &&
          (cast<CallInst>(I).isInlineAsm() || isa<IntrinsicInst>(I) ||
           isSMEABIRoutineCall(cast<CallInst>(I))))
        return true;
    }
  }
  return false;
}

bool AArch64TTIImpl::areInlineCompatible(const Function *Caller,
                                         const Function *Callee) const {
  SMEAttrs CallerAttrs(*Caller);
  SMEAttrs CalleeAttrs(*Callee);
  if (CalleeAttrs.hasNewZABody())
    return false;

  if (CallerAttrs.requiresLazySave(CalleeAttrs) ||
      CallerAttrs.requiresSMChange(CalleeAttrs,
                                   /*BodyOverridesInterface=*/true)) {
    if (hasPossibleIncompatibleOps(Callee))
      return false;
  }

  const TargetMachine &TM = getTLI()->getTargetMachine();

  const FeatureBitset &CallerBits =
      TM.getSubtargetImpl(*Caller)->getFeatureBits();
  const FeatureBitset &CalleeBits =
      TM.getSubtargetImpl(*Callee)->getFeatureBits();

  // Inline a callee if its target-features are a subset of the caller's
  // target-features.
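  // For example, under this check a callee built for plain NEON can be
  // inlined into a caller that additionally enables SVE, but an SVE callee
  // cannot be inlined into a NEON-only caller.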
  return (CallerBits & CalleeBits) == CalleeBits;
}

bool AArch64TTIImpl::areTypesABICompatible(
    const Function *Caller, const Function *Callee,
    const ArrayRef<Type *> &Types) const {
  if (!BaseT::areTypesABICompatible(Caller, Callee, Types))
    return false;

  // We need to ensure that argument promotion does not attempt to promote
  // pointers to fixed-length vector types larger than 128 bits like
  // <8 x float> (and pointers to aggregate types which have such fixed-length
  // vector type members) into the values of the pointees. Such vector types
  // are used for SVE VLS but there is no ABI for SVE VLS arguments and the
  // backend cannot lower such value arguments. The 128-bit fixed-length SVE
  // types can be safely treated as 128-bit NEON types and they cannot be
  // distinguished in IR.
  if (ST->useSVEForFixedLengthVectors() && llvm::any_of(Types, [](Type *Ty) {
        auto FVTy = dyn_cast<FixedVectorType>(Ty);
        return FVTy &&
               FVTy->getScalarSizeInBits() * FVTy->getNumElements() > 128;
      }))
    return false;

  return true;
}

unsigned
AArch64TTIImpl::getInlineCallPenalty(const Function *F, const CallBase &Call,
                                     unsigned DefaultCallPenalty) const {
  // This function calculates a penalty for executing Call in F.
  //
  // There are two ways this function can be called:
  // (1)  F:
  //        call from F -> G (the call here is Call)
  //
  //      For (1), Call.getCaller() == F, so it will always return a high cost
  //      if a streaming-mode change is required (thus promoting the need to
  //      inline the function)
  //
  // (2)  F:
  //        call from F -> G (the call here is not Call)
  //      G:
  //        call from G -> H (the call here is Call)
  //
  //      For (2), if after inlining the body of G into F the call to H requires
  //      a streaming-mode change, and the call to G from F would also require a
  //      streaming-mode change, then there is benefit to do the streaming-mode
  //      change only once and avoid inlining of G into F.
  SMEAttrs FAttrs(*F);
  SMEAttrs CalleeAttrs(Call);
  if (FAttrs.requiresSMChange(CalleeAttrs)) {
    if (F == Call.getCaller()) // (1)
      return CallPenaltyChangeSM * DefaultCallPenalty;
    if (FAttrs.requiresSMChange(SMEAttrs(*Call.getCaller()))) // (2)
      return InlineCallPenaltyChangeSM * DefaultCallPenalty;
  }

  return DefaultCallPenalty;
}

bool AArch64TTIImpl::shouldMaximizeVectorBandwidth(
    TargetTransformInfo::RegisterKind K) const {
  assert(K != TargetTransformInfo::RGK_Scalar);
  return (K == TargetTransformInfo::RGK_FixedWidthVector &&
          ST->isNeonAvailable());
}

/// Calculate the cost of materializing a 64-bit value. This helper
/// method might only calculate a fraction of a larger immediate. Therefore it
/// is valid to return a cost of ZERO.
InstructionCost AArch64TTIImpl::getIntImmCost(int64_t Val) {
  // Check if the immediate can be encoded within an instruction.
  if (Val == 0 || AArch64_AM::isLogicalImmediate(Val, 64))
    return 0;

  if (Val < 0)
    Val = ~Val;

  // Calculate how many moves we will need to materialize this constant.
  SmallVector<AArch64_IMM::ImmInsnModel, 4> Insn;
  AArch64_IMM::expandMOVImm(Val, 64, Insn);
  return Insn.size();
}

/// Calculate the cost of materializing the given constant.
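/// The constant is sign-extended to a multiple of 64 bits and costed in
/// 64-bit chunks, so, for example, a 128-bit immediate is priced as two
/// 64-bit materializations, with an overall minimum cost of one instruction.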
InstructionCost AArch64TTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
                                              TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  if (BitSize == 0)
    return ~0U;

  // Sign-extend all constants to a multiple of 64-bit.
  APInt ImmVal = Imm;
  if (BitSize & 0x3f)
    ImmVal = Imm.sext((BitSize + 63) & ~0x3fU);

  // Split the constant into 64-bit chunks and calculate the cost for each
  // chunk.
  InstructionCost Cost = 0;
  for (unsigned ShiftVal = 0; ShiftVal < BitSize; ShiftVal += 64) {
    APInt Tmp = ImmVal.ashr(ShiftVal).sextOrTrunc(64);
    int64_t Val = Tmp.getSExtValue();
    Cost += getIntImmCost(Val);
  }
  // We need at least one instruction to materialize the constant.
  return std::max<InstructionCost>(1, Cost);
}

InstructionCost AArch64TTIImpl::getIntImmCostInst(unsigned Opcode, unsigned Idx,
                                                  const APInt &Imm, Type *Ty,
                                                  TTI::TargetCostKind CostKind,
                                                  Instruction *Inst) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  unsigned ImmIdx = ~0U;
  switch (Opcode) {
  default:
    return TTI::TCC_Free;
  case Instruction::GetElementPtr:
    // Always hoist the base address of a GetElementPtr.
    if (Idx == 0)
      return 2 * TTI::TCC_Basic;
    return TTI::TCC_Free;
  case Instruction::Store:
    ImmIdx = 0;
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::ICmp:
    ImmIdx = 1;
    break;
  // Always return TCC_Free for the shift value of a shift instruction.
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
    if (Idx == 1)
      return TTI::TCC_Free;
    break;
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::IntToPtr:
  case Instruction::PtrToInt:
  case Instruction::BitCast:
  case Instruction::PHI:
  case Instruction::Call:
  case Instruction::Select:
  case Instruction::Ret:
  case Instruction::Load:
    break;
  }

  if (Idx == ImmIdx) {
    int NumConstants = (BitSize + 63) / 64;
    InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
    return (Cost <= NumConstants * TTI::TCC_Basic)
               ? static_cast<int>(TTI::TCC_Free)
               : Cost;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}

InstructionCost
AArch64TTIImpl::getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
                                    const APInt &Imm, Type *Ty,
                                    TTI::TargetCostKind CostKind) {
  assert(Ty->isIntegerTy());

  unsigned BitSize = Ty->getPrimitiveSizeInBits();
  // There is no cost model for constants with a bit size of 0. Return TCC_Free
  // here, so that constant hoisting will ignore this constant.
  if (BitSize == 0)
    return TTI::TCC_Free;

  // Most (all?) AArch64 intrinsics do not support folding immediates into the
  // selected instruction, so we compute the materialization cost for the
  // immediate directly.
  if (IID >= Intrinsic::aarch64_addg && IID <= Intrinsic::aarch64_udiv)
    return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);

  switch (IID) {
  default:
    return TTI::TCC_Free;
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow:
    if (Idx == 1) {
      int NumConstants = (BitSize + 63) / 64;
      InstructionCost Cost = AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
      return (Cost <= NumConstants * TTI::TCC_Basic)
                 ? static_cast<int>(TTI::TCC_Free)
                 : Cost;
    }
    break;
  case Intrinsic::experimental_stackmap:
    if ((Idx < 2) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_patchpoint_void:
  case Intrinsic::experimental_patchpoint_i64:
    if ((Idx < 4) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  case Intrinsic::experimental_gc_statepoint:
    if ((Idx < 5) || (Imm.getBitWidth() <= 64 && isInt<64>(Imm.getSExtValue())))
      return TTI::TCC_Free;
    break;
  }
  return AArch64TTIImpl::getIntImmCost(Imm, Ty, CostKind);
}

TargetTransformInfo::PopcntSupportKind
AArch64TTIImpl::getPopcntSupport(unsigned TyWidth) {
  assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
  if (TyWidth == 32 || TyWidth == 64)
    return TTI::PSK_FastHardware;
  // TODO: AArch64TargetLowering::LowerCTPOP() supports 128bit popcount.
  return TTI::PSK_Software;
}

InstructionCost
AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
                                      TTI::TargetCostKind CostKind) {
  auto *RetTy = ICA.getReturnType();
  switch (ICA.getID()) {
  case Intrinsic::umin:
  case Intrinsic::umax:
  case Intrinsic::smin:
  case Intrinsic::smax: {
    static const auto ValidMinMaxTys = {MVT::v8i8,    MVT::v16i8,  MVT::v4i16,
                                        MVT::v8i16,   MVT::v2i32,  MVT::v4i32,
                                        MVT::nxv16i8, MVT::nxv8i16, MVT::nxv4i32,
                                        MVT::nxv2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    // v2i64 types get converted to cmp+bif hence the cost of 2
    if (LT.second == MVT::v2i64)
      return LT.first * 2;
    if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::sadd_sat:
  case Intrinsic::ssub_sat:
  case Intrinsic::uadd_sat:
  case Intrinsic::usub_sat: {
    static const auto ValidSatTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    // This is a base cost of 1 for the vadd, plus 3 extract shifts if we
    // need to extend the type, as it uses shr(qadd(shl, shl)).
    unsigned Instrs =
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits() ? 1 : 4;
    if (any_of(ValidSatTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first * Instrs;
    break;
  }
  case Intrinsic::abs: {
    static const auto ValidAbsTys = {MVT::v8i8,  MVT::v16i8, MVT::v4i16,
                                     MVT::v8i16, MVT::v2i32, MVT::v4i32,
                                     MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }))
      return LT.first;
    break;
  }
  case Intrinsic::bswap: {
    static const auto ValidAbsTys = {MVT::v4i16, MVT::v8i16, MVT::v2i32,
                                     MVT::v4i32, MVT::v2i64};
    auto LT = getTypeLegalizationCost(RetTy);
    if (any_of(ValidAbsTys, [&LT](MVT M) { return M == LT.second; }) &&
        LT.second.getScalarSizeInBits() == RetTy->getScalarSizeInBits())
      return LT.first;
    break;
  }
  case Intrinsic::experimental_stepvector: {
    InstructionCost Cost = 1; // Cost of the `index' instruction
    auto LT = getTypeLegalizationCost(RetTy);
    // Legalisation of illegal vectors involves an `index' instruction plus
    // (LT.first - 1) vector adds.
    if (LT.first > 1) {
      Type *LegalVTy = EVT(LT.second).getTypeForEVT(RetTy->getContext());
      InstructionCost AddCost =
          getArithmeticInstrCost(Instruction::Add, LegalVTy, CostKind);
      Cost += AddCost * (LT.first - 1);
    }
    return Cost;
  }
  case Intrinsic::bitreverse: {
    static const CostTblEntry BitreverseTbl[] = {
        {Intrinsic::bitreverse, MVT::i32, 1},
        {Intrinsic::bitreverse, MVT::i64, 1},
        {Intrinsic::bitreverse, MVT::v8i8, 1},
        {Intrinsic::bitreverse, MVT::v16i8, 1},
        {Intrinsic::bitreverse, MVT::v4i16, 2},
        {Intrinsic::bitreverse, MVT::v8i16, 2},
        {Intrinsic::bitreverse, MVT::v2i32, 2},
        {Intrinsic::bitreverse, MVT::v4i32, 2},
        {Intrinsic::bitreverse, MVT::v1i64, 2},
        {Intrinsic::bitreverse, MVT::v2i64, 2},
    };
    const auto LegalisationCost = getTypeLegalizationCost(RetTy);
    const auto *Entry =
        CostTableLookup(BitreverseTbl, ICA.getID(), LegalisationCost.second);
    if (Entry) {
      // The cost model uses the legal type (i32) that i8 and i16 will be
      // converted to, plus 1 so that we match the actual lowering cost.
      if (TLI->getValueType(DL, RetTy, true) == MVT::i8 ||
          TLI->getValueType(DL, RetTy, true) == MVT::i16)
        return LegalisationCost.first * Entry->Cost + 1;

      return LegalisationCost.first * Entry->Cost;
    }
    break;
  }
  case Intrinsic::ctpop: {
    if (!ST->hasNEON()) {
      // 32-bit or 64-bit ctpop without NEON is 12 instructions.
      return getTypeLegalizationCost(RetTy).first * 12;
    }
    static const CostTblEntry CtpopCostTbl[] = {
        {ISD::CTPOP, MVT::v2i64, 4},
        {ISD::CTPOP, MVT::v4i32, 3},
        {ISD::CTPOP, MVT::v8i16, 2},
        {ISD::CTPOP, MVT::v16i8, 1},
        {ISD::CTPOP, MVT::i64,   4},
        {ISD::CTPOP, MVT::v2i32, 3},
        {ISD::CTPOP, MVT::v4i16, 2},
        {ISD::CTPOP, MVT::v8i8,  1},
        {ISD::CTPOP, MVT::i32,   5},
    };
    auto LT = getTypeLegalizationCost(RetTy);
    MVT MTy = LT.second;
    if (const auto *Entry = CostTableLookup(CtpopCostTbl, ISD::CTPOP, MTy)) {
      // Extra cost of +1 when illegal vector types are legalized by promoting
      // the integer type.
      int ExtraCost = MTy.isVector() && MTy.getScalarSizeInBits() !=
                                            RetTy->getScalarSizeInBits()
                          ? 1
                          : 0;
      return LT.first * Entry->Cost + ExtraCost;
    }
    break;
  }
  case Intrinsic::sadd_with_overflow:
  case Intrinsic::uadd_with_overflow:
  case Intrinsic::ssub_with_overflow:
  case Intrinsic::usub_with_overflow:
  case Intrinsic::smul_with_overflow:
  case Intrinsic::umul_with_overflow: {
    static const CostTblEntry WithOverflowCostTbl[] = {
        {Intrinsic::sadd_with_overflow, MVT::i8, 3},
        {Intrinsic::uadd_with_overflow, MVT::i8, 3},
        {Intrinsic::sadd_with_overflow, MVT::i16, 3},
        {Intrinsic::uadd_with_overflow, MVT::i16, 3},
        {Intrinsic::sadd_with_overflow, MVT::i32, 1},
        {Intrinsic::uadd_with_overflow, MVT::i32, 1},
        {Intrinsic::sadd_with_overflow, MVT::i64, 1},
        {Intrinsic::uadd_with_overflow, MVT::i64, 1},
        {Intrinsic::ssub_with_overflow, MVT::i8, 3},
        {Intrinsic::usub_with_overflow, MVT::i8, 3},
        {Intrinsic::ssub_with_overflow, MVT::i16, 3},
        {Intrinsic::usub_with_overflow, MVT::i16, 3},
        {Intrinsic::ssub_with_overflow, MVT::i32, 1},
        {Intrinsic::usub_with_overflow, MVT::i32, 1},
        {Intrinsic::ssub_with_overflow, MVT::i64, 1},
        {Intrinsic::usub_with_overflow, MVT::i64, 1},
        {Intrinsic::smul_with_overflow, MVT::i8, 5},
        {Intrinsic::umul_with_overflow, MVT::i8, 4},
        {Intrinsic::smul_with_overflow, MVT::i16, 5},
        {Intrinsic::umul_with_overflow, MVT::i16, 4},
        {Intrinsic::smul_with_overflow, MVT::i32, 2}, // eg umull;tst
        {Intrinsic::umul_with_overflow, MVT::i32, 2}, // eg umull;cmp sxtw
        {Intrinsic::smul_with_overflow, MVT::i64, 3}, // eg mul;smulh;cmp
        {Intrinsic::umul_with_overflow, MVT::i64, 3}, // eg mul;umulh;cmp asr
    };
    EVT MTy = TLI->getValueType(DL, RetTy->getContainedType(0), true);
    if (MTy.isSimple())
      if (const auto *Entry = CostTableLookup(WithOverflowCostTbl, ICA.getID(),
                                              MTy.getSimpleVT()))
        return Entry->Cost;
    break;
  }
  case Intrinsic::fptosi_sat:
  case Intrinsic::fptoui_sat: {
    if (ICA.getArgTypes().empty())
      break;
    bool IsSigned = ICA.getID() == Intrinsic::fptosi_sat;
    auto LT = getTypeLegalizationCost(ICA.getArgTypes()[0]);
    EVT MTy = TLI->getValueType(DL, RetTy);
    // Check for the legal types, which are where the size of the input and the
    // output are the same, or we are using cvt f64->i32 or f32->i64.
    if ((LT.second == MVT::f32 || LT.second == MVT::f64 ||
         LT.second == MVT::v2f32 || LT.second == MVT::v4f32 ||
         LT.second == MVT::v2f64) &&
        (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits() ||
         (LT.second == MVT::f64 && MTy == MVT::i32) ||
         (LT.second == MVT::f32 && MTy == MVT::i64)))
      return LT.first;
    // Similarly for fp16 sizes
    if (ST->hasFullFP16() &&
        ((LT.second == MVT::f16 && MTy == MVT::i32) ||
         ((LT.second == MVT::v4f16 || LT.second == MVT::v8f16) &&
          (LT.second.getScalarSizeInBits() == MTy.getScalarSizeInBits()))))
      return LT.first;

    // Otherwise we use a legal convert followed by a min+max
    if ((LT.second.getScalarType() == MVT::f32 ||
         LT.second.getScalarType() == MVT::f64 ||
         (ST->hasFullFP16() && LT.second.getScalarType() == MVT::f16)) &&
        LT.second.getScalarSizeInBits() >= MTy.getScalarSizeInBits()) {
      Type *LegalTy =
          Type::getIntNTy(RetTy->getContext(), LT.second.getScalarSizeInBits());
      if (LT.second.isVector())
        LegalTy = VectorType::get(LegalTy, LT.second.getVectorElementCount());
      InstructionCost Cost = 1;
      IntrinsicCostAttributes Attrs1(IsSigned ? Intrinsic::smin : Intrinsic::umin,
                                     LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs1, CostKind);
      IntrinsicCostAttributes Attrs2(IsSigned ? Intrinsic::smax : Intrinsic::umax,
                                     LegalTy, {LegalTy, LegalTy});
      Cost += getIntrinsicInstrCost(Attrs2, CostKind);
      return LT.first * Cost;
    }
    break;
  }
  case Intrinsic::fshl:
  case Intrinsic::fshr: {
    if (ICA.getArgs().empty())
      break;

    // TODO: Add handling for fshl where third argument is not a constant.
    const TTI::OperandValueInfo OpInfoZ = TTI::getOperandInfo(ICA.getArgs()[2]);
    if (!OpInfoZ.isConstant())
      break;

    const auto LegalisationCost = getTypeLegalizationCost(RetTy);
    if (OpInfoZ.isUniform()) {
      // FIXME: The costs could be lower if the codegen is better.
      static const CostTblEntry FshlTbl[] = {
          {Intrinsic::fshl, MVT::v4i32, 3}, // ushr + shl + orr
          {Intrinsic::fshl, MVT::v2i64, 3}, {Intrinsic::fshl, MVT::v16i8, 4},
          {Intrinsic::fshl, MVT::v8i16, 4}, {Intrinsic::fshl, MVT::v2i32, 3},
          {Intrinsic::fshl, MVT::v8i8, 4},  {Intrinsic::fshl, MVT::v4i16, 4}};
      // Costs for both fshl & fshr are the same, so just pass Intrinsic::fshl
      // to avoid having to duplicate the costs.
      const auto *Entry =
          CostTableLookup(FshlTbl, Intrinsic::fshl, LegalisationCost.second);
      if (Entry)
        return LegalisationCost.first * Entry->Cost;
    }

    auto TyL = getTypeLegalizationCost(RetTy);
    if (!RetTy->isIntegerTy())
      break;

    // Estimate cost manually, as types like i8 and i16 will get promoted to
    // i32 and CostTableLookup will ignore the extra conversion cost.
    bool HigherCost = (RetTy->getScalarSizeInBits() != 32 &&
                       RetTy->getScalarSizeInBits() < 64) ||
                      (RetTy->getScalarSizeInBits() % 64 != 0);
    unsigned ExtraCost = HigherCost ? 1 : 0;
    if (RetTy->getScalarSizeInBits() == 32 ||
        RetTy->getScalarSizeInBits() == 64)
      ExtraCost = 0; // fshl/fshr for i32 and i64 can be lowered to a single
                     // extr instruction.
    else if (HigherCost)
      ExtraCost = 1;
    else
      break;
    return TyL.first + ExtraCost;
  }
  default:
    break;
  }
  return BaseT::getIntrinsicInstrCost(ICA, CostKind);
}

/// The function will remove redundant reinterpret casts in the presence of
/// control flow.
static std::optional<Instruction *> processPhiNode(InstCombiner &IC,
                                                   IntrinsicInst &II) {
  SmallVector<Instruction *, 32> Worklist;
  auto RequiredType = II.getType();

  auto *PN = dyn_cast<PHINode>(II.getArgOperand(0));
  assert(PN && "Expected Phi Node!");

  // Don't create a new Phi unless we can remove the old one.
  if (!PN->hasOneUse())
    return std::nullopt;

  for (Value *IncValPhi : PN->incoming_values()) {
    auto *Reinterpret = dyn_cast<IntrinsicInst>(IncValPhi);
    if (!Reinterpret ||
        Reinterpret->getIntrinsicID() !=
            Intrinsic::aarch64_sve_convert_to_svbool ||
        RequiredType != Reinterpret->getArgOperand(0)->getType())
      return std::nullopt;
  }

  // Create the new Phi
  IC.Builder.SetInsertPoint(PN);
  PHINode *NPN = IC.Builder.CreatePHI(RequiredType, PN->getNumIncomingValues());
  Worklist.push_back(PN);

  for (unsigned I = 0; I < PN->getNumIncomingValues(); I++) {
    auto *Reinterpret = cast<Instruction>(PN->getIncomingValue(I));
    NPN->addIncoming(Reinterpret->getOperand(0), PN->getIncomingBlock(I));
    Worklist.push_back(Reinterpret);
  }

  // Cleanup Phi Node and reinterprets
  return IC.replaceInstUsesWith(II, NPN);
}

// (from_svbool (binop (to_svbool pred) (svbool_t _) (svbool_t _)))
// => (binop (pred) (from_svbool _) (from_svbool _))
//
// The above transformation eliminates a `to_svbool` in the predicate
// operand of bitwise operation `binop` by narrowing the vector width of
// the operation. For example, it would convert a `<vscale x 16 x i1>
// and` into a `<vscale x 4 x i1> and`. This is profitable because
// to_svbool must zero the new lanes during widening, whereas
// from_svbool is free.
static std::optional<Instruction *>
tryCombineFromSVBoolBinOp(InstCombiner &IC, IntrinsicInst &II) {
  auto BinOp = dyn_cast<IntrinsicInst>(II.getOperand(0));
  if (!BinOp)
    return std::nullopt;

  auto IntrinsicID = BinOp->getIntrinsicID();
  switch (IntrinsicID) {
  case Intrinsic::aarch64_sve_and_z:
  case Intrinsic::aarch64_sve_bic_z:
  case Intrinsic::aarch64_sve_eor_z:
  case Intrinsic::aarch64_sve_nand_z:
  case Intrinsic::aarch64_sve_nor_z:
  case Intrinsic::aarch64_sve_orn_z:
  case Intrinsic::aarch64_sve_orr_z:
    break;
  default:
    return std::nullopt;
  }

  auto BinOpPred = BinOp->getOperand(0);
  auto BinOpOp1 = BinOp->getOperand(1);
  auto BinOpOp2 = BinOp->getOperand(2);

  auto PredIntr = dyn_cast<IntrinsicInst>(BinOpPred);
  if (!PredIntr ||
      PredIntr->getIntrinsicID() != Intrinsic::aarch64_sve_convert_to_svbool)
    return std::nullopt;

  auto PredOp = PredIntr->getOperand(0);
  auto PredOpTy = cast<VectorType>(PredOp->getType());
  if (PredOpTy != II.getType())
    return std::nullopt;

  SmallVector<Value *> NarrowedBinOpArgs = {PredOp};
  auto NarrowBinOpOp1 = IC.Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp1});
  NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
  if (BinOpOp1 == BinOpOp2)
    NarrowedBinOpArgs.push_back(NarrowBinOpOp1);
  else
    NarrowedBinOpArgs.push_back(IC.Builder.CreateIntrinsic(
        Intrinsic::aarch64_sve_convert_from_svbool, {PredOpTy}, {BinOpOp2}));

  auto NarrowedBinOp =
      IC.Builder.CreateIntrinsic(IntrinsicID, {PredOpTy}, NarrowedBinOpArgs);
  return IC.replaceInstUsesWith(II, NarrowedBinOp);
}

static std::optional<Instruction *>
instCombineConvertFromSVBool(InstCombiner &IC, IntrinsicInst &II) {
  // If the reinterpret instruction operand is a PHI Node
  if (isa<PHINode>(II.getArgOperand(0)))
    return processPhiNode(IC, II);

  if (auto BinOpCombine = tryCombineFromSVBoolBinOp(IC, II))
    return BinOpCombine;

  // Ignore converts to/from svcount_t.
  if (isa<TargetExtType>(II.getArgOperand(0)->getType()) ||
      isa<TargetExtType>(II.getType()))
    return std::nullopt;

  SmallVector<Instruction *, 32> CandidatesForRemoval;
  Value *Cursor = II.getOperand(0), *EarliestReplacement = nullptr;

  const auto *IVTy = cast<VectorType>(II.getType());

  // Walk the chain of conversions.
  while (Cursor) {
    // If the type of the cursor has fewer lanes than the final result, zeroing
    // must take place, which breaks the equivalence chain.
    const auto *CursorVTy = cast<VectorType>(Cursor->getType());
    if (CursorVTy->getElementCount().getKnownMinValue() <
        IVTy->getElementCount().getKnownMinValue())
      break;

    // If the cursor has the same type as I, it is a viable replacement.
    if (Cursor->getType() == IVTy)
      EarliestReplacement = Cursor;

    auto *IntrinsicCursor = dyn_cast<IntrinsicInst>(Cursor);

    // If this is not an SVE conversion intrinsic, this is the end of the chain.
    if (!IntrinsicCursor || !(IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_to_svbool ||
                              IntrinsicCursor->getIntrinsicID() ==
                                  Intrinsic::aarch64_sve_convert_from_svbool))
      break;

    CandidatesForRemoval.insert(CandidatesForRemoval.begin(), IntrinsicCursor);
    Cursor = IntrinsicCursor->getOperand(0);
  }

  // If no viable replacement in the conversion chain was found, there is
  // nothing to do.
  if (!EarliestReplacement)
    return std::nullopt;

  return IC.replaceInstUsesWith(II, EarliestReplacement);
}

static bool isAllActivePredicate(Value *Pred) {
  // Look through a convert.from.svbool(convert.to.svbool(...)) chain.
  Value *UncastedPred;
  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
                      m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
                          m_Value(UncastedPred)))))
    // If the predicate has the same or fewer lanes than the uncasted
    // predicate then we know the casting has no effect.
    if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
        cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
      Pred = UncastedPred;

  return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
                         m_ConstantInt<AArch64SVEPredPattern::all>()));
}

static std::optional<Instruction *> instCombineSVESel(InstCombiner &IC,
                                                      IntrinsicInst &II) {
  // svsel(ptrue, x, y) => x
  auto *OpPredicate = II.getOperand(0);
  if (isAllActivePredicate(OpPredicate))
    return IC.replaceInstUsesWith(II, II.getOperand(1));

  auto Select =
      IC.Builder.CreateSelect(OpPredicate, II.getOperand(1), II.getOperand(2));
  return IC.replaceInstUsesWith(II, Select);
}

static std::optional<Instruction *> instCombineSVEDup(InstCombiner &IC,
                                                      IntrinsicInst &II) {
  IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!Pg)
    return std::nullopt;

  if (Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::vl1)
    return std::nullopt;

  // The intrinsic is inserting into lane zero so use an insert instead.
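  // That is, sve.dup(passthru, ptrue(vl1), scalar) only updates lane 0, so it
  // can be rewritten as: insertelement passthru, scalar, i64 0.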
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Insert = InsertElementInst::Create(
      II.getArgOperand(0), II.getArgOperand(2), ConstantInt::get(IdxTy, 0));
  Insert->insertBefore(&II);
  Insert->takeName(&II);

  return IC.replaceInstUsesWith(II, Insert);
}

static std::optional<Instruction *> instCombineSVEDupX(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  // Replace DupX with a regular IR splat.
  auto *RetTy = cast<ScalableVectorType>(II.getType());
  Value *Splat = IC.Builder.CreateVectorSplat(RetTy->getElementCount(),
                                              II.getArgOperand(0));
  Splat->takeName(&II);
  return IC.replaceInstUsesWith(II, Splat);
}

static std::optional<Instruction *> instCombineSVECmpNE(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();

  // Check that the predicate is all active
  auto *Pg = dyn_cast<IntrinsicInst>(II.getArgOperand(0));
  if (!Pg || Pg->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(Pg->getOperand(0))->getZExtValue();
  if (PTruePattern != AArch64SVEPredPattern::all)
    return std::nullopt;

  // Check that we have a compare of zero..
  auto *SplatValue =
      dyn_cast_or_null<ConstantInt>(getSplatValue(II.getArgOperand(2)));
  if (!SplatValue || !SplatValue->isZero())
    return std::nullopt;

  // ..against a dupq
  auto *DupQLane = dyn_cast<IntrinsicInst>(II.getArgOperand(1));
  if (!DupQLane ||
      DupQLane->getIntrinsicID() != Intrinsic::aarch64_sve_dupq_lane)
    return std::nullopt;

  // Where the dupq is a lane 0 replicate of a vector insert
  if (!cast<ConstantInt>(DupQLane->getArgOperand(1))->isZero())
    return std::nullopt;

  auto *VecIns = dyn_cast<IntrinsicInst>(DupQLane->getArgOperand(0));
  if (!VecIns || VecIns->getIntrinsicID() != Intrinsic::vector_insert)
    return std::nullopt;

  // Where the vector insert is a fixed constant vector insert into undef at
  // index zero
  if (!isa<UndefValue>(VecIns->getArgOperand(0)))
    return std::nullopt;

  if (!cast<ConstantInt>(VecIns->getArgOperand(2))->isZero())
    return std::nullopt;

  auto *ConstVec = dyn_cast<Constant>(VecIns->getArgOperand(1));
  if (!ConstVec)
    return std::nullopt;

  auto *VecTy = dyn_cast<FixedVectorType>(ConstVec->getType());
  auto *OutTy = dyn_cast<ScalableVectorType>(II.getType());
  if (!VecTy || !OutTy || VecTy->getNumElements() != OutTy->getMinNumElements())
    return std::nullopt;

  unsigned NumElts = VecTy->getNumElements();
  unsigned PredicateBits = 0;

  // Expand intrinsic operands to a 16-bit byte level predicate
  for (unsigned I = 0; I < NumElts; ++I) {
    auto *Arg = dyn_cast<ConstantInt>(ConstVec->getAggregateElement(I));
    if (!Arg)
      return std::nullopt;
    if (!Arg->isZero())
      PredicateBits |= 1 << (I * (16 / NumElts));
  }

  // If all bits are zero bail early with an empty predicate
  if (PredicateBits == 0) {
    auto *PFalse = Constant::getNullValue(II.getType());
    PFalse->takeName(&II);
    return IC.replaceInstUsesWith(II, PFalse);
  }

  // Calculate largest predicate type used (where byte predicate is largest)
  unsigned Mask = 8;
  for (unsigned I = 0; I < 16; ++I)
    if ((PredicateBits & (1 << I)) != 0)
      Mask |= (I % 8);

  unsigned PredSize = Mask & -Mask;
  auto *PredType = ScalableVectorType::get(
      Type::getInt1Ty(Ctx), AArch64::SVEBitsPerBlock / (PredSize * 8));
  // Ensure all relevant bits are set
  for (unsigned I = 0; I < 16; I += PredSize)
    if ((PredicateBits & (1 << I)) == 0)
      return std::nullopt;

  auto *PTruePat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                           {PredType}, {PTruePat});
  auto *ConvertToSVBool = IC.Builder.CreateIntrinsic(
      Intrinsic::aarch64_sve_convert_to_svbool, {PredType}, {PTrue});
  auto *ConvertFromSVBool =
      IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_convert_from_svbool,
                                 {II.getType()}, {ConvertToSVBool});

  ConvertFromSVBool->takeName(&II);
  return IC.replaceInstUsesWith(II, ConvertFromSVBool);
}

static std::optional<Instruction *> instCombineSVELast(InstCombiner &IC,
                                                       IntrinsicInst &II) {
  Value *Pg = II.getArgOperand(0);
  Value *Vec = II.getArgOperand(1);
  auto IntrinsicID = II.getIntrinsicID();
  bool IsAfter = IntrinsicID == Intrinsic::aarch64_sve_lasta;

  // lastX(splat(X)) --> X
  if (auto *SplatVal = getSplatValue(Vec))
    return IC.replaceInstUsesWith(II, SplatVal);

  // If x and/or y is a splat value then:
  // lastX (binop (x, y)) --> binop(lastX(x), lastX(y))
  Value *LHS, *RHS;
  if (match(Vec, m_OneUse(m_BinOp(m_Value(LHS), m_Value(RHS))))) {
    if (isSplatValue(LHS) || isSplatValue(RHS)) {
      auto *OldBinOp = cast<BinaryOperator>(Vec);
      auto OpC = OldBinOp->getOpcode();
      auto *NewLHS =
          IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, LHS});
      auto *NewRHS =
          IC.Builder.CreateIntrinsic(IntrinsicID, {Vec->getType()}, {Pg, RHS});
      auto *NewBinOp = BinaryOperator::CreateWithCopiedFlags(
          OpC, NewLHS, NewRHS, OldBinOp, OldBinOp->getName(), &II);
      return IC.replaceInstUsesWith(II, NewBinOp);
    }
  }

  auto *C = dyn_cast<Constant>(Pg);
  if (IsAfter && C && C->isNullValue()) {
    // The intrinsic is extracting lane 0 so use an extract instead.
    auto *IdxTy = Type::getInt64Ty(II.getContext());
    auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, 0));
    Extract->insertBefore(&II);
    Extract->takeName(&II);
    return IC.replaceInstUsesWith(II, Extract);
  }

  auto *IntrPG = dyn_cast<IntrinsicInst>(Pg);
  if (!IntrPG)
    return std::nullopt;

  if (IntrPG->getIntrinsicID() != Intrinsic::aarch64_sve_ptrue)
    return std::nullopt;

  const auto PTruePattern =
      cast<ConstantInt>(IntrPG->getOperand(0))->getZExtValue();

  // Can the intrinsic's predicate be converted to a known constant index?
  unsigned MinNumElts = getNumElementsFromSVEPredPattern(PTruePattern);
  if (!MinNumElts)
    return std::nullopt;

  unsigned Idx = MinNumElts - 1;
  // Increment the index if extracting the element after the last active
  // predicate element.
  if (IsAfter)
    ++Idx;

  // Ignore extracts whose index is larger than the known minimum vector
  // length. NOTE: This is an artificial constraint where we prefer to
  // maintain what the user asked for until an alternative is proven faster.
  auto *PgVTy = cast<ScalableVectorType>(Pg->getType());
  if (Idx >= PgVTy->getMinNumElements())
    return std::nullopt;

  // The intrinsic is extracting a fixed lane so use an extract instead.
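  // For example, with a ptrue(vl4) predicate, lastb extracts lane 3 and lasta
  // extracts lane 4 (the element just past the last active one).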
  auto *IdxTy = Type::getInt64Ty(II.getContext());
  auto *Extract = ExtractElementInst::Create(Vec, ConstantInt::get(IdxTy, Idx));
  Extract->insertBefore(&II);
  Extract->takeName(&II);
  return IC.replaceInstUsesWith(II, Extract);
}

static std::optional<Instruction *> instCombineSVECondLast(InstCombiner &IC,
                                                           IntrinsicInst &II) {
  // The SIMD&FP variant of CLAST[AB] is significantly faster than the scalar
  // integer variant across a variety of micro-architectures. Replace scalar
  // integer CLAST[AB] intrinsic with optimal SIMD&FP variant. A simple
  // bitcast-to-fp + clast[ab] + bitcast-to-int will cost a cycle or two more
  // depending on the micro-architecture, but has been observed as generally
  // being faster, particularly when the CLAST[AB] op is a loop-carried
  // dependency.
  Value *Pg = II.getArgOperand(0);
  Value *Fallback = II.getArgOperand(1);
  Value *Vec = II.getArgOperand(2);
  Type *Ty = II.getType();

  if (!Ty->isIntegerTy())
    return std::nullopt;

  Type *FPTy;
  switch (cast<IntegerType>(Ty)->getBitWidth()) {
  default:
    return std::nullopt;
  case 16:
    FPTy = IC.Builder.getHalfTy();
    break;
  case 32:
    FPTy = IC.Builder.getFloatTy();
    break;
  case 64:
    FPTy = IC.Builder.getDoubleTy();
    break;
  }

  Value *FPFallBack = IC.Builder.CreateBitCast(Fallback, FPTy);
  auto *FPVTy = VectorType::get(
      FPTy, cast<VectorType>(Vec->getType())->getElementCount());
  Value *FPVec = IC.Builder.CreateBitCast(Vec, FPVTy);
  auto *FPII = IC.Builder.CreateIntrinsic(
      II.getIntrinsicID(), {FPVec->getType()}, {Pg, FPFallBack, FPVec});
  Value *FPIItoInt = IC.Builder.CreateBitCast(FPII, II.getType());
  return IC.replaceInstUsesWith(II, FPIItoInt);
}

static std::optional<Instruction *> instCombineRDFFR(InstCombiner &IC,
                                                     IntrinsicInst &II) {
  LLVMContext &Ctx = II.getContext();
  // Replace rdffr with predicated rdffr.z intrinsic, so that optimizePTestInstr
  // can work with RDFFR_PP for ptest elimination.
  auto *AllPat =
      ConstantInt::get(Type::getInt32Ty(Ctx), AArch64SVEPredPattern::all);
  auto *PTrue = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptrue,
                                           {II.getType()}, {AllPat});
  auto *RDFFR =
      IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_rdffr_z, {}, {PTrue});
  RDFFR->takeName(&II);
  return IC.replaceInstUsesWith(II, RDFFR);
}

static std::optional<Instruction *>
instCombineSVECntElts(InstCombiner &IC, IntrinsicInst &II, unsigned NumElts) {
  const auto Pattern = cast<ConstantInt>(II.getArgOperand(0))->getZExtValue();

  if (Pattern == AArch64SVEPredPattern::all) {
    Constant *StepVal = ConstantInt::get(II.getType(), NumElts);
    auto *VScale = IC.Builder.CreateVScale(StepVal);
    VScale->takeName(&II);
    return IC.replaceInstUsesWith(II, VScale);
  }

  unsigned MinNumElts = getNumElementsFromSVEPredPattern(Pattern);

  return MinNumElts && NumElts >= MinNumElts
             ? std::optional<Instruction *>(IC.replaceInstUsesWith(
                   II, ConstantInt::get(II.getType(), MinNumElts)))
             : std::nullopt;
}

static std::optional<Instruction *> instCombineSVEPTest(InstCombiner &IC,
                                                        IntrinsicInst &II) {
  Value *PgVal = II.getArgOperand(0);
  Value *OpVal = II.getArgOperand(1);

  // PTEST_<FIRST|LAST>(X, X) is equivalent to PTEST_ANY(X, X).
  // Later optimizations prefer this form.
  if (PgVal == OpVal &&
      (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_first ||
       II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_last)) {
    Value *Ops[] = {PgVal, OpVal};
    Type *Tys[] = {PgVal->getType()};

    auto *PTest =
        IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_ptest_any, Tys, Ops);
    PTest->takeName(&II);

    return IC.replaceInstUsesWith(II, PTest);
  }

  IntrinsicInst *Pg = dyn_cast<IntrinsicInst>(PgVal);
  IntrinsicInst *Op = dyn_cast<IntrinsicInst>(OpVal);

  if (!Pg || !Op)
    return std::nullopt;

  Intrinsic::ID OpIID = Op->getIntrinsicID();

  if (Pg->getIntrinsicID() == Intrinsic::aarch64_sve_convert_to_svbool &&
      OpIID == Intrinsic::aarch64_sve_convert_to_svbool &&
      Pg->getArgOperand(0)->getType() == Op->getArgOperand(0)->getType()) {
    Value *Ops[] = {Pg->getArgOperand(0), Op->getArgOperand(0)};
    Type *Tys[] = {Pg->getArgOperand(0)->getType()};

    auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);

    PTest->takeName(&II);
    return IC.replaceInstUsesWith(II, PTest);
  }

  // Transform PTEST_ANY(X=OP(PG,...), X) -> PTEST_ANY(PG, X).
  // Later optimizations may rewrite the sequence to use the flag-setting
  // variant of instruction X to remove PTEST.
  if ((Pg == Op) && (II.getIntrinsicID() == Intrinsic::aarch64_sve_ptest_any) &&
      ((OpIID == Intrinsic::aarch64_sve_brka_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkb_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkpa_z) ||
       (OpIID == Intrinsic::aarch64_sve_brkpb_z) ||
       (OpIID == Intrinsic::aarch64_sve_rdffr_z) ||
       (OpIID == Intrinsic::aarch64_sve_and_z) ||
       (OpIID == Intrinsic::aarch64_sve_bic_z) ||
       (OpIID == Intrinsic::aarch64_sve_eor_z) ||
       (OpIID == Intrinsic::aarch64_sve_nand_z) ||
       (OpIID == Intrinsic::aarch64_sve_nor_z) ||
       (OpIID == Intrinsic::aarch64_sve_orn_z) ||
       (OpIID == Intrinsic::aarch64_sve_orr_z))) {
    Value *Ops[] = {Pg->getArgOperand(0), Pg};
    Type *Tys[] = {Pg->getType()};

    auto *PTest = IC.Builder.CreateIntrinsic(II.getIntrinsicID(), Tys, Ops);
    PTest->takeName(&II);

    return IC.replaceInstUsesWith(II, PTest);
  }

  return std::nullopt;
}

template <Intrinsic::ID MulOpc, typename Intrinsic::ID FuseOpc>
static std::optional<Instruction *>
instCombineSVEVectorFuseMulAddSub(InstCombiner &IC, IntrinsicInst &II,
                                  bool MergeIntoAddendOp) {
  Value *P = II.getOperand(0);
  Value *MulOp0, *MulOp1, *AddendOp, *Mul;
  if (MergeIntoAddendOp) {
    AddendOp = II.getOperand(1);
    Mul = II.getOperand(2);
  } else {
    AddendOp = II.getOperand(2);
    Mul = II.getOperand(1);
  }

  if (!match(Mul, m_Intrinsic<MulOpc>(m_Specific(P), m_Value(MulOp0),
                                      m_Value(MulOp1))))
    return std::nullopt;

  if (!Mul->hasOneUse())
    return std::nullopt;

  Instruction *FMFSource = nullptr;
  if (II.getType()->isFPOrFPVectorTy()) {
    llvm::FastMathFlags FAddFlags = II.getFastMathFlags();
    // Stop the combine when the flags on the inputs differ in case dropping
    // flags would lead to us missing out on more beneficial optimizations.
    if (FAddFlags != cast<CallInst>(Mul)->getFastMathFlags())
      return std::nullopt;
    if (!FAddFlags.allowContract())
      return std::nullopt;
    FMFSource = &II;
  }

  CallInst *Res;
  if (MergeIntoAddendOp)
    Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
                                     {P, AddendOp, MulOp0, MulOp1}, FMFSource);
  else
    Res = IC.Builder.CreateIntrinsic(FuseOpc, {II.getType()},
                                     {P, MulOp0, MulOp1, AddendOp}, FMFSource);

  return IC.replaceInstUsesWith(II, Res);
}

static std::optional<Instruction *>
instCombineSVELD1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
  Value *Pred = II.getOperand(0);
  Value *PtrOp = II.getOperand(1);
  Type *VecTy = II.getType();

  if (isAllActivePredicate(Pred)) {
    LoadInst *Load = IC.Builder.CreateLoad(VecTy, PtrOp);
    Load->copyMetadata(II);
    return IC.replaceInstUsesWith(II, Load);
  }

  CallInst *MaskedLoad =
      IC.Builder.CreateMaskedLoad(VecTy, PtrOp, PtrOp->getPointerAlignment(DL),
                                  Pred, ConstantAggregateZero::get(VecTy));
  MaskedLoad->copyMetadata(II);
  return IC.replaceInstUsesWith(II, MaskedLoad);
}

static std::optional<Instruction *>
instCombineSVEST1(InstCombiner &IC, IntrinsicInst &II, const DataLayout &DL) {
  Value *VecOp = II.getOperand(0);
  Value *Pred = II.getOperand(1);
  Value *PtrOp = II.getOperand(2);

  if (isAllActivePredicate(Pred)) {
    StoreInst *Store = IC.Builder.CreateStore(VecOp, PtrOp);
    Store->copyMetadata(II);
    return IC.eraseInstFromFunction(II);
  }

  CallInst *MaskedStore = IC.Builder.CreateMaskedStore(
      VecOp, PtrOp, PtrOp->getPointerAlignment(DL), Pred);
  MaskedStore->copyMetadata(II);
  return IC.eraseInstFromFunction(II);
}

static Instruction::BinaryOps intrinsicIDToBinOpCode(unsigned Intrinsic) {
  switch (Intrinsic) {
  case Intrinsic::aarch64_sve_fmul_u:
    return Instruction::BinaryOps::FMul;
  case Intrinsic::aarch64_sve_fadd_u:
    return Instruction::BinaryOps::FAdd;
  case Intrinsic::aarch64_sve_fsub_u:
    return Instruction::BinaryOps::FSub;
  default:
    return Instruction::BinaryOpsEnd;
  }
}

static std::optional<Instruction *>
instCombineSVEVectorBinOp(InstCombiner &IC, IntrinsicInst &II) {
  // Bail due to missing support for ISD::STRICT_ scalable vector operations.
  if (II.isStrictFP())
    return std::nullopt;

  auto *OpPredicate = II.getOperand(0);
  auto BinOpCode = intrinsicIDToBinOpCode(II.getIntrinsicID());
  if (BinOpCode == Instruction::BinaryOpsEnd ||
      !match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
                              m_ConstantInt<AArch64SVEPredPattern::all>())))
    return std::nullopt;
  IRBuilderBase::FastMathFlagGuard FMFGuard(IC.Builder);
  IC.Builder.setFastMathFlags(II.getFastMathFlags());
  auto BinOp =
      IC.Builder.CreateBinOp(BinOpCode, II.getOperand(1), II.getOperand(2));
  return IC.replaceInstUsesWith(II, BinOp);
}

// Canonicalise operations that take an all active predicate (e.g. sve.add ->
// sve.add_u).
static std::optional<Instruction *> instCombineSVEAllActive(IntrinsicInst &II,
                                                            Intrinsic::ID IID) {
  auto *OpPredicate = II.getOperand(0);
  if (!match(OpPredicate, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
                              m_ConstantInt<AArch64SVEPredPattern::all>())))
    return std::nullopt;

  auto *Mod = II.getModule();
  auto *NewDecl = Intrinsic::getDeclaration(Mod, IID, {II.getType()});
  II.setCalledFunction(NewDecl);

  return &II;
}

// Simplify operations where the predicate has all inactive lanes, or try to
// replace with the _u form when all lanes are active.
static std::optional<Instruction *>
instCombineSVEAllOrNoActive(InstCombiner &IC, IntrinsicInst &II,
                            Intrinsic::ID IID) {
  if (match(II.getOperand(0), m_ZeroInt())) {
    // llvm_ir, pred(0), op1, op2 - Spec says to return op1 when all lanes are
    // inactive for sv[func]_m
    return IC.replaceInstUsesWith(II, II.getOperand(1));
  }
  return instCombineSVEAllActive(II, IID);
}

static std::optional<Instruction *> instCombineSVEVectorAdd(InstCombiner &IC,
                                                            IntrinsicInst &II) {
  if (auto II_U =
          instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_add_u))
    return II_U;
  if (auto MLA = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
                                                   Intrinsic::aarch64_sve_mla>(
          IC, II, true))
    return MLA;
  if (auto MAD = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
                                                   Intrinsic::aarch64_sve_mad>(
          IC, II, false))
    return MAD;
  return std::nullopt;
}

static std::optional<Instruction *>
instCombineSVEVectorFAdd(InstCombiner &IC, IntrinsicInst &II) {
  if (auto II_U =
          instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fadd_u))
    return II_U;
  if (auto FMLA =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmla>(IC, II,
                                                                         true))
    return FMLA;
  if (auto FMAD =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmad>(IC, II,
                                                                         false))
    return FMAD;
  if (auto FMLA =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                            Intrinsic::aarch64_sve_fmla>(IC, II,
                                                                         true))
    return FMLA;
  return std::nullopt;
}

static std::optional<Instruction *>
instCombineSVEVectorFAddU(InstCombiner &IC, IntrinsicInst &II) {
  if (auto FMLA =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmla>(IC, II,
                                                                         true))
    return FMLA;
  if (auto FMAD =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmad>(IC, II,
                                                                         false))
    return FMAD;
  if (auto FMLA_U =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                            Intrinsic::aarch64_sve_fmla_u>(
          IC, II, true))
    return FMLA_U;
  return instCombineSVEVectorBinOp(IC, II);
}

static std::optional<Instruction *>
instCombineSVEVectorFSub(InstCombiner &IC, IntrinsicInst &II) {
  if (auto II_U =
          instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fsub_u))
    return II_U;
  if (auto FMLS =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmls>(IC, II,
                                                                         true))
    return FMLS;
  if (auto FMSB =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fnmsb>(
          IC, II, false))
    return FMSB;
  if (auto FMLS =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                            Intrinsic::aarch64_sve_fmls>(IC, II,
                                                                         true))
    return FMLS;
  return std::nullopt;
}

static std::optional<Instruction *>
instCombineSVEVectorFSubU(InstCombiner &IC, IntrinsicInst &II) {
  if (auto FMLS =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fmls>(IC, II,
                                                                         true))
    return FMLS;
  if (auto FMSB =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul,
                                            Intrinsic::aarch64_sve_fnmsb>(
          IC, II, false))
    return FMSB;
  if (auto FMLS_U =
          instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_fmul_u,
                                            Intrinsic::aarch64_sve_fmls_u>(
          IC, II, true))
    return FMLS_U;
  return instCombineSVEVectorBinOp(IC, II);
}

static std::optional<Instruction *> instCombineSVEVectorSub(InstCombiner &IC,
                                                            IntrinsicInst &II) {
  if (auto II_U =
          instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sub_u))
    return II_U;
  if (auto MLS = instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul,
                                                   Intrinsic::aarch64_sve_mls>(
          IC, II, true))
    return MLS;
  return std::nullopt;
}

static std::optional<Instruction *> instCombineSVEVectorMul(InstCombiner &IC,
                                                            IntrinsicInst &II,
                                                            Intrinsic::ID IID) {
  auto *OpPredicate = II.getOperand(0);
  auto *OpMultiplicand = II.getOperand(1);
  auto *OpMultiplier = II.getOperand(2);

  // Return true if a given instruction is a unit splat value, false otherwise.
  auto IsUnitSplat = [](auto *I) {
    auto *SplatValue = getSplatValue(I);
    if (!SplatValue)
      return false;
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  // Return true if a given instruction is an aarch64_sve_dup intrinsic call
  // with a unit splat value, false otherwise.
  auto IsUnitDup = [](auto *I) {
    auto *IntrI = dyn_cast<IntrinsicInst>(I);
    if (!IntrI || IntrI->getIntrinsicID() != Intrinsic::aarch64_sve_dup)
      return false;

    auto *SplatValue = IntrI->getOperand(2);
    return match(SplatValue, m_FPOne()) || match(SplatValue, m_One());
  };

  if (IsUnitSplat(OpMultiplier)) {
    // [f]mul pg %n, (dupx 1) => %n
    OpMultiplicand->takeName(&II);
    return IC.replaceInstUsesWith(II, OpMultiplicand);
  } else if (IsUnitDup(OpMultiplier)) {
    // [f]mul pg %n, (dup pg 1) => %n
    auto *DupInst = cast<IntrinsicInst>(OpMultiplier);
    auto *DupPg = DupInst->getOperand(1);
    // TODO: this is naive. The optimization is still valid if DupPg
    // 'encompasses' OpPredicate, not only if they're the same predicate.
1572 if (OpPredicate == DupPg) { 1573 OpMultiplicand->takeName(&II); 1574 return IC.replaceInstUsesWith(II, OpMultiplicand); 1575 } 1576 } 1577 1578 return instCombineSVEVectorBinOp(IC, II); 1579 } 1580 1581 static std::optional<Instruction *> instCombineSVEUnpack(InstCombiner &IC, 1582 IntrinsicInst &II) { 1583 Value *UnpackArg = II.getArgOperand(0); 1584 auto *RetTy = cast<ScalableVectorType>(II.getType()); 1585 bool IsSigned = II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpkhi || 1586 II.getIntrinsicID() == Intrinsic::aarch64_sve_sunpklo; 1587 1588 // Hi = uunpkhi(splat(X)) --> Hi = splat(extend(X)) 1589 // Lo = uunpklo(splat(X)) --> Lo = splat(extend(X)) 1590 if (auto *ScalarArg = getSplatValue(UnpackArg)) { 1591 ScalarArg = 1592 IC.Builder.CreateIntCast(ScalarArg, RetTy->getScalarType(), IsSigned); 1593 Value *NewVal = 1594 IC.Builder.CreateVectorSplat(RetTy->getElementCount(), ScalarArg); 1595 NewVal->takeName(&II); 1596 return IC.replaceInstUsesWith(II, NewVal); 1597 } 1598 1599 return std::nullopt; 1600 } 1601 static std::optional<Instruction *> instCombineSVETBL(InstCombiner &IC, 1602 IntrinsicInst &II) { 1603 auto *OpVal = II.getOperand(0); 1604 auto *OpIndices = II.getOperand(1); 1605 VectorType *VTy = cast<VectorType>(II.getType()); 1606 1607 // Check whether OpIndices is a constant splat value < minimal element count 1608 // of result. 1609 auto *SplatValue = dyn_cast_or_null<ConstantInt>(getSplatValue(OpIndices)); 1610 if (!SplatValue || 1611 SplatValue->getValue().uge(VTy->getElementCount().getKnownMinValue())) 1612 return std::nullopt; 1613 1614 // Convert sve_tbl(OpVal sve_dup_x(SplatValue)) to 1615 // splat_vector(extractelement(OpVal, SplatValue)) for further optimization. 1616 auto *Extract = IC.Builder.CreateExtractElement(OpVal, SplatValue); 1617 auto *VectorSplat = 1618 IC.Builder.CreateVectorSplat(VTy->getElementCount(), Extract); 1619 1620 VectorSplat->takeName(&II); 1621 return IC.replaceInstUsesWith(II, VectorSplat); 1622 } 1623 1624 static std::optional<Instruction *> instCombineSVEZip(InstCombiner &IC, 1625 IntrinsicInst &II) { 1626 // zip1(uzp1(A, B), uzp2(A, B)) --> A 1627 // zip2(uzp1(A, B), uzp2(A, B)) --> B 1628 Value *A, *B; 1629 if (match(II.getArgOperand(0), 1630 m_Intrinsic<Intrinsic::aarch64_sve_uzp1>(m_Value(A), m_Value(B))) && 1631 match(II.getArgOperand(1), m_Intrinsic<Intrinsic::aarch64_sve_uzp2>( 1632 m_Specific(A), m_Specific(B)))) 1633 return IC.replaceInstUsesWith( 1634 II, (II.getIntrinsicID() == Intrinsic::aarch64_sve_zip1 ? A : B)); 1635 1636 return std::nullopt; 1637 } 1638 1639 static std::optional<Instruction *> 1640 instCombineLD1GatherIndex(InstCombiner &IC, IntrinsicInst &II) { 1641 Value *Mask = II.getOperand(0); 1642 Value *BasePtr = II.getOperand(1); 1643 Value *Index = II.getOperand(2); 1644 Type *Ty = II.getType(); 1645 Value *PassThru = ConstantAggregateZero::get(Ty); 1646 1647 // Contiguous gather => masked load. 
1648 // (sve.ld1.gather.index Mask BasePtr (sve.index IndexBase 1)) 1649 // => (masked.load (gep BasePtr IndexBase) Align Mask zeroinitializer) 1650 Value *IndexBase; 1651 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>( 1652 m_Value(IndexBase), m_SpecificInt(1)))) { 1653 Align Alignment = 1654 BasePtr->getPointerAlignment(II.getModule()->getDataLayout()); 1655 1656 Type *VecPtrTy = PointerType::getUnqual(Ty); 1657 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(), 1658 BasePtr, IndexBase); 1659 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy); 1660 CallInst *MaskedLoad = 1661 IC.Builder.CreateMaskedLoad(Ty, Ptr, Alignment, Mask, PassThru); 1662 MaskedLoad->takeName(&II); 1663 return IC.replaceInstUsesWith(II, MaskedLoad); 1664 } 1665 1666 return std::nullopt; 1667 } 1668 1669 static std::optional<Instruction *> 1670 instCombineST1ScatterIndex(InstCombiner &IC, IntrinsicInst &II) { 1671 Value *Val = II.getOperand(0); 1672 Value *Mask = II.getOperand(1); 1673 Value *BasePtr = II.getOperand(2); 1674 Value *Index = II.getOperand(3); 1675 Type *Ty = Val->getType(); 1676 1677 // Contiguous scatter => masked store. 1678 // (sve.st1.scatter.index Value Mask BasePtr (sve.index IndexBase 1)) 1679 // => (masked.store Value (gep BasePtr IndexBase) Align Mask) 1680 Value *IndexBase; 1681 if (match(Index, m_Intrinsic<Intrinsic::aarch64_sve_index>( 1682 m_Value(IndexBase), m_SpecificInt(1)))) { 1683 Align Alignment = 1684 BasePtr->getPointerAlignment(II.getModule()->getDataLayout()); 1685 1686 Value *Ptr = IC.Builder.CreateGEP(cast<VectorType>(Ty)->getElementType(), 1687 BasePtr, IndexBase); 1688 Type *VecPtrTy = PointerType::getUnqual(Ty); 1689 Ptr = IC.Builder.CreateBitCast(Ptr, VecPtrTy); 1690 1691 (void)IC.Builder.CreateMaskedStore(Val, Ptr, Alignment, Mask); 1692 1693 return IC.eraseInstFromFunction(II); 1694 } 1695 1696 return std::nullopt; 1697 } 1698 1699 static std::optional<Instruction *> instCombineSVESDIV(InstCombiner &IC, 1700 IntrinsicInst &II) { 1701 Type *Int32Ty = IC.Builder.getInt32Ty(); 1702 Value *Pred = II.getOperand(0); 1703 Value *Vec = II.getOperand(1); 1704 Value *DivVec = II.getOperand(2); 1705 1706 Value *SplatValue = getSplatValue(DivVec); 1707 ConstantInt *SplatConstantInt = dyn_cast_or_null<ConstantInt>(SplatValue); 1708 if (!SplatConstantInt) 1709 return std::nullopt; 1710 APInt Divisor = SplatConstantInt->getValue(); 1711 1712 if (Divisor.isPowerOf2()) { 1713 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2()); 1714 auto ASRD = IC.Builder.CreateIntrinsic( 1715 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2}); 1716 return IC.replaceInstUsesWith(II, ASRD); 1717 } 1718 if (Divisor.isNegatedPowerOf2()) { 1719 Divisor.negate(); 1720 Constant *DivisorLog2 = ConstantInt::get(Int32Ty, Divisor.logBase2()); 1721 auto ASRD = IC.Builder.CreateIntrinsic( 1722 Intrinsic::aarch64_sve_asrd, {II.getType()}, {Pred, Vec, DivisorLog2}); 1723 auto NEG = IC.Builder.CreateIntrinsic( 1724 Intrinsic::aarch64_sve_neg, {ASRD->getType()}, {ASRD, Pred, ASRD}); 1725 return IC.replaceInstUsesWith(II, NEG); 1726 } 1727 1728 return std::nullopt; 1729 } 1730 1731 bool SimplifyValuePattern(SmallVector<Value *> &Vec, bool AllowPoison) { 1732 size_t VecSize = Vec.size(); 1733 if (VecSize == 1) 1734 return true; 1735 if (!isPowerOf2_64(VecSize)) 1736 return false; 1737 size_t HalfVecSize = VecSize / 2; 1738 1739 for (auto LHS = Vec.begin(), RHS = Vec.begin() + HalfVecSize; 1740 RHS != Vec.end(); LHS++, RHS++) { 1741 if (*LHS != nullptr && 
*RHS != nullptr) { 1742 if (*LHS == *RHS) 1743 continue; 1744 else 1745 return false; 1746 } 1747 if (!AllowPoison) 1748 return false; 1749 if (*LHS == nullptr && *RHS != nullptr) 1750 *LHS = *RHS; 1751 } 1752 1753 Vec.resize(HalfVecSize); 1754 SimplifyValuePattern(Vec, AllowPoison); 1755 return true; 1756 } 1757 1758 // Try to simplify dupqlane patterns like dupqlane(f32 A, f32 B, f32 A, f32 B) 1759 // to dupqlane(f64(C)) where C is A concatenated with B 1760 static std::optional<Instruction *> instCombineSVEDupqLane(InstCombiner &IC, 1761 IntrinsicInst &II) { 1762 Value *CurrentInsertElt = nullptr, *Default = nullptr; 1763 if (!match(II.getOperand(0), 1764 m_Intrinsic<Intrinsic::vector_insert>( 1765 m_Value(Default), m_Value(CurrentInsertElt), m_Value())) || 1766 !isa<FixedVectorType>(CurrentInsertElt->getType())) 1767 return std::nullopt; 1768 auto IIScalableTy = cast<ScalableVectorType>(II.getType()); 1769 1770 // Insert the scalars into a container ordered by InsertElement index 1771 SmallVector<Value *> Elts(IIScalableTy->getMinNumElements(), nullptr); 1772 while (auto InsertElt = dyn_cast<InsertElementInst>(CurrentInsertElt)) { 1773 auto Idx = cast<ConstantInt>(InsertElt->getOperand(2)); 1774 Elts[Idx->getValue().getZExtValue()] = InsertElt->getOperand(1); 1775 CurrentInsertElt = InsertElt->getOperand(0); 1776 } 1777 1778 bool AllowPoison = 1779 isa<PoisonValue>(CurrentInsertElt) && isa<PoisonValue>(Default); 1780 if (!SimplifyValuePattern(Elts, AllowPoison)) 1781 return std::nullopt; 1782 1783 // Rebuild the simplified chain of InsertElements. e.g. (a, b, a, b) as (a, b) 1784 Value *InsertEltChain = PoisonValue::get(CurrentInsertElt->getType()); 1785 for (size_t I = 0; I < Elts.size(); I++) { 1786 if (Elts[I] == nullptr) 1787 continue; 1788 InsertEltChain = IC.Builder.CreateInsertElement(InsertEltChain, Elts[I], 1789 IC.Builder.getInt64(I)); 1790 } 1791 if (InsertEltChain == nullptr) 1792 return std::nullopt; 1793 1794 // Splat the simplified sequence, e.g. (f16 a, f16 b, f16 c, f16 d) as one i64 1795 // value or (f16 a, f16 b) as one i32 value. This requires an InsertSubvector 1796 // be bitcast to a type wide enough to fit the sequence, be splatted, and then 1797 // be narrowed back to the original type. 
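// Illustrative example (hypothetical values, not taken from the original source): if the
// result type is <vscale x 8 x half> and Elts simplifies to the pair (f16 A, f16 B), then
// PatternWidth is 32, so the pair is treated as a single i32 lane, splatted across a
// <vscale x 4 x i32> using an all-zero shuffle mask, and bitcast back to the half type.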
1798 unsigned PatternWidth = IIScalableTy->getScalarSizeInBits() * Elts.size(); 1799 unsigned PatternElementCount = IIScalableTy->getScalarSizeInBits() * 1800 IIScalableTy->getMinNumElements() / 1801 PatternWidth; 1802 1803 IntegerType *WideTy = IC.Builder.getIntNTy(PatternWidth); 1804 auto *WideScalableTy = ScalableVectorType::get(WideTy, PatternElementCount); 1805 auto *WideShuffleMaskTy = 1806 ScalableVectorType::get(IC.Builder.getInt32Ty(), PatternElementCount); 1807 1808 auto ZeroIdx = ConstantInt::get(IC.Builder.getInt64Ty(), APInt(64, 0)); 1809 auto InsertSubvector = IC.Builder.CreateInsertVector( 1810 II.getType(), PoisonValue::get(II.getType()), InsertEltChain, ZeroIdx); 1811 auto WideBitcast = 1812 IC.Builder.CreateBitOrPointerCast(InsertSubvector, WideScalableTy); 1813 auto WideShuffleMask = ConstantAggregateZero::get(WideShuffleMaskTy); 1814 auto WideShuffle = IC.Builder.CreateShuffleVector( 1815 WideBitcast, PoisonValue::get(WideScalableTy), WideShuffleMask); 1816 auto NarrowBitcast = 1817 IC.Builder.CreateBitOrPointerCast(WideShuffle, II.getType()); 1818 1819 return IC.replaceInstUsesWith(II, NarrowBitcast); 1820 } 1821 1822 static std::optional<Instruction *> instCombineMaxMinNM(InstCombiner &IC, 1823 IntrinsicInst &II) { 1824 Value *A = II.getArgOperand(0); 1825 Value *B = II.getArgOperand(1); 1826 if (A == B) 1827 return IC.replaceInstUsesWith(II, A); 1828 1829 return std::nullopt; 1830 } 1831 1832 static std::optional<Instruction *> instCombineSVESrshl(InstCombiner &IC, 1833 IntrinsicInst &II) { 1834 Value *Pred = II.getOperand(0); 1835 Value *Vec = II.getOperand(1); 1836 Value *Shift = II.getOperand(2); 1837 1838 // Convert SRSHL into the simpler LSL intrinsic when fed by an ABS intrinsic. 1839 Value *AbsPred, *MergedValue; 1840 if (!match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_sqabs>( 1841 m_Value(MergedValue), m_Value(AbsPred), m_Value())) && 1842 !match(Vec, m_Intrinsic<Intrinsic::aarch64_sve_abs>( 1843 m_Value(MergedValue), m_Value(AbsPred), m_Value()))) 1844 1845 return std::nullopt; 1846 1847 // Transform is valid if any of the following are true: 1848 // * The ABS merge value is an undef or non-negative 1849 // * The ABS predicate is all active 1850 // * The ABS predicate and the SRSHL predicates are the same 1851 if (!isa<UndefValue>(MergedValue) && !match(MergedValue, m_NonNegative()) && 1852 AbsPred != Pred && !isAllActivePredicate(AbsPred)) 1853 return std::nullopt; 1854 1855 // Only valid when the shift amount is non-negative, otherwise the rounding 1856 // behaviour of SRSHL cannot be ignored. 
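// Illustrative example (hypothetical IR): with a provably non-negative shift such as
// splat(2), srshl(pg, abs(pg, %x), splat(2)) behaves like a plain left shift and can be
// rewritten as lsl(pg, abs(pg, %x), splat(2)); the rounding behaviour only differs for
// negative (i.e. right) shift amounts.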
1857 if (!match(Shift, m_NonNegative())) 1858 return std::nullopt; 1859 1860 auto LSL = IC.Builder.CreateIntrinsic(Intrinsic::aarch64_sve_lsl, 1861 {II.getType()}, {Pred, Vec, Shift}); 1862 1863 return IC.replaceInstUsesWith(II, LSL); 1864 } 1865 1866 std::optional<Instruction *> 1867 AArch64TTIImpl::instCombineIntrinsic(InstCombiner &IC, 1868 IntrinsicInst &II) const { 1869 Intrinsic::ID IID = II.getIntrinsicID(); 1870 switch (IID) { 1871 default: 1872 break; 1873 case Intrinsic::aarch64_neon_fmaxnm: 1874 case Intrinsic::aarch64_neon_fminnm: 1875 return instCombineMaxMinNM(IC, II); 1876 case Intrinsic::aarch64_sve_convert_from_svbool: 1877 return instCombineConvertFromSVBool(IC, II); 1878 case Intrinsic::aarch64_sve_dup: 1879 return instCombineSVEDup(IC, II); 1880 case Intrinsic::aarch64_sve_dup_x: 1881 return instCombineSVEDupX(IC, II); 1882 case Intrinsic::aarch64_sve_cmpne: 1883 case Intrinsic::aarch64_sve_cmpne_wide: 1884 return instCombineSVECmpNE(IC, II); 1885 case Intrinsic::aarch64_sve_rdffr: 1886 return instCombineRDFFR(IC, II); 1887 case Intrinsic::aarch64_sve_lasta: 1888 case Intrinsic::aarch64_sve_lastb: 1889 return instCombineSVELast(IC, II); 1890 case Intrinsic::aarch64_sve_clasta_n: 1891 case Intrinsic::aarch64_sve_clastb_n: 1892 return instCombineSVECondLast(IC, II); 1893 case Intrinsic::aarch64_sve_cntd: 1894 return instCombineSVECntElts(IC, II, 2); 1895 case Intrinsic::aarch64_sve_cntw: 1896 return instCombineSVECntElts(IC, II, 4); 1897 case Intrinsic::aarch64_sve_cnth: 1898 return instCombineSVECntElts(IC, II, 8); 1899 case Intrinsic::aarch64_sve_cntb: 1900 return instCombineSVECntElts(IC, II, 16); 1901 case Intrinsic::aarch64_sve_ptest_any: 1902 case Intrinsic::aarch64_sve_ptest_first: 1903 case Intrinsic::aarch64_sve_ptest_last: 1904 return instCombineSVEPTest(IC, II); 1905 case Intrinsic::aarch64_sve_fabd: 1906 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fabd_u); 1907 case Intrinsic::aarch64_sve_fadd: 1908 return instCombineSVEVectorFAdd(IC, II); 1909 case Intrinsic::aarch64_sve_fadd_u: 1910 return instCombineSVEVectorFAddU(IC, II); 1911 case Intrinsic::aarch64_sve_fdiv: 1912 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fdiv_u); 1913 case Intrinsic::aarch64_sve_fmax: 1914 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmax_u); 1915 case Intrinsic::aarch64_sve_fmaxnm: 1916 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmaxnm_u); 1917 case Intrinsic::aarch64_sve_fmin: 1918 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmin_u); 1919 case Intrinsic::aarch64_sve_fminnm: 1920 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fminnm_u); 1921 case Intrinsic::aarch64_sve_fmla: 1922 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmla_u); 1923 case Intrinsic::aarch64_sve_fmls: 1924 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmls_u); 1925 case Intrinsic::aarch64_sve_fmul: 1926 if (auto II_U = 1927 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmul_u)) 1928 return II_U; 1929 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u); 1930 case Intrinsic::aarch64_sve_fmul_u: 1931 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_fmul_u); 1932 case Intrinsic::aarch64_sve_fmulx: 1933 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fmulx_u); 1934 case Intrinsic::aarch64_sve_fnmla: 1935 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmla_u); 1936 case 
Intrinsic::aarch64_sve_fnmls: 1937 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_fnmls_u); 1938 case Intrinsic::aarch64_sve_fsub: 1939 return instCombineSVEVectorFSub(IC, II); 1940 case Intrinsic::aarch64_sve_fsub_u: 1941 return instCombineSVEVectorFSubU(IC, II); 1942 case Intrinsic::aarch64_sve_add: 1943 return instCombineSVEVectorAdd(IC, II); 1944 case Intrinsic::aarch64_sve_add_u: 1945 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u, 1946 Intrinsic::aarch64_sve_mla_u>( 1947 IC, II, true); 1948 case Intrinsic::aarch64_sve_mla: 1949 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mla_u); 1950 case Intrinsic::aarch64_sve_mls: 1951 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mls_u); 1952 case Intrinsic::aarch64_sve_mul: 1953 if (auto II_U = 1954 instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_mul_u)) 1955 return II_U; 1956 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u); 1957 case Intrinsic::aarch64_sve_mul_u: 1958 return instCombineSVEVectorMul(IC, II, Intrinsic::aarch64_sve_mul_u); 1959 case Intrinsic::aarch64_sve_sabd: 1960 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sabd_u); 1961 case Intrinsic::aarch64_sve_smax: 1962 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smax_u); 1963 case Intrinsic::aarch64_sve_smin: 1964 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smin_u); 1965 case Intrinsic::aarch64_sve_smulh: 1966 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_smulh_u); 1967 case Intrinsic::aarch64_sve_sub: 1968 return instCombineSVEVectorSub(IC, II); 1969 case Intrinsic::aarch64_sve_sub_u: 1970 return instCombineSVEVectorFuseMulAddSub<Intrinsic::aarch64_sve_mul_u, 1971 Intrinsic::aarch64_sve_mls_u>( 1972 IC, II, true); 1973 case Intrinsic::aarch64_sve_uabd: 1974 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uabd_u); 1975 case Intrinsic::aarch64_sve_umax: 1976 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umax_u); 1977 case Intrinsic::aarch64_sve_umin: 1978 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umin_u); 1979 case Intrinsic::aarch64_sve_umulh: 1980 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_umulh_u); 1981 case Intrinsic::aarch64_sve_asr: 1982 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_asr_u); 1983 case Intrinsic::aarch64_sve_lsl: 1984 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsl_u); 1985 case Intrinsic::aarch64_sve_lsr: 1986 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_lsr_u); 1987 case Intrinsic::aarch64_sve_and: 1988 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_and_u); 1989 case Intrinsic::aarch64_sve_bic: 1990 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_bic_u); 1991 case Intrinsic::aarch64_sve_eor: 1992 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_eor_u); 1993 case Intrinsic::aarch64_sve_orr: 1994 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_orr_u); 1995 case Intrinsic::aarch64_sve_sqsub: 1996 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_sqsub_u); 1997 case Intrinsic::aarch64_sve_uqsub: 1998 return instCombineSVEAllOrNoActive(IC, II, Intrinsic::aarch64_sve_uqsub_u); 1999 case Intrinsic::aarch64_sve_tbl: 2000 return instCombineSVETBL(IC, II); 2001 case Intrinsic::aarch64_sve_uunpkhi: 2002 case Intrinsic::aarch64_sve_uunpklo: 2003 case 
Intrinsic::aarch64_sve_sunpkhi: 2004 case Intrinsic::aarch64_sve_sunpklo: 2005 return instCombineSVEUnpack(IC, II); 2006 case Intrinsic::aarch64_sve_zip1: 2007 case Intrinsic::aarch64_sve_zip2: 2008 return instCombineSVEZip(IC, II); 2009 case Intrinsic::aarch64_sve_ld1_gather_index: 2010 return instCombineLD1GatherIndex(IC, II); 2011 case Intrinsic::aarch64_sve_st1_scatter_index: 2012 return instCombineST1ScatterIndex(IC, II); 2013 case Intrinsic::aarch64_sve_ld1: 2014 return instCombineSVELD1(IC, II, DL); 2015 case Intrinsic::aarch64_sve_st1: 2016 return instCombineSVEST1(IC, II, DL); 2017 case Intrinsic::aarch64_sve_sdiv: 2018 return instCombineSVESDIV(IC, II); 2019 case Intrinsic::aarch64_sve_sel: 2020 return instCombineSVESel(IC, II); 2021 case Intrinsic::aarch64_sve_srshl: 2022 return instCombineSVESrshl(IC, II); 2023 case Intrinsic::aarch64_sve_dupq_lane: 2024 return instCombineSVEDupqLane(IC, II); 2025 } 2026 2027 return std::nullopt; 2028 } 2029 2030 std::optional<Value *> AArch64TTIImpl::simplifyDemandedVectorEltsIntrinsic( 2031 InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts, 2032 APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3, 2033 std::function<void(Instruction *, unsigned, APInt, APInt &)> 2034 SimplifyAndSetOp) const { 2035 switch (II.getIntrinsicID()) { 2036 default: 2037 break; 2038 case Intrinsic::aarch64_neon_fcvtxn: 2039 case Intrinsic::aarch64_neon_rshrn: 2040 case Intrinsic::aarch64_neon_sqrshrn: 2041 case Intrinsic::aarch64_neon_sqrshrun: 2042 case Intrinsic::aarch64_neon_sqshrn: 2043 case Intrinsic::aarch64_neon_sqshrun: 2044 case Intrinsic::aarch64_neon_sqxtn: 2045 case Intrinsic::aarch64_neon_sqxtun: 2046 case Intrinsic::aarch64_neon_uqrshrn: 2047 case Intrinsic::aarch64_neon_uqshrn: 2048 case Intrinsic::aarch64_neon_uqxtn: 2049 SimplifyAndSetOp(&II, 0, OrigDemandedElts, UndefElts); 2050 break; 2051 } 2052 2053 return std::nullopt; 2054 } 2055 2056 TypeSize 2057 AArch64TTIImpl::getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const { 2058 switch (K) { 2059 case TargetTransformInfo::RGK_Scalar: 2060 return TypeSize::getFixed(64); 2061 case TargetTransformInfo::RGK_FixedWidthVector: 2062 if (!ST->isNeonAvailable() && !EnableFixedwidthAutovecInStreamingMode) 2063 return TypeSize::getFixed(0); 2064 2065 if (ST->hasSVE()) 2066 return TypeSize::getFixed( 2067 std::max(ST->getMinSVEVectorSizeInBits(), 128u)); 2068 2069 return TypeSize::getFixed(ST->hasNEON() ? 128 : 0); 2070 case TargetTransformInfo::RGK_ScalableVector: 2071 if (!ST->isSVEAvailable() && !EnableScalableAutovecInStreamingMode) 2072 return TypeSize::getScalable(0); 2073 2074 return TypeSize::getScalable(ST->hasSVE() ? 128 : 0); 2075 } 2076 llvm_unreachable("Unsupported register kind"); 2077 } 2078 2079 bool AArch64TTIImpl::isWideningInstruction(Type *DstTy, unsigned Opcode, 2080 ArrayRef<const Value *> Args, 2081 Type *SrcOverrideTy) { 2082 // A helper that returns a vector type from the given type. The number of 2083 // elements in type Ty determines the vector width. 2084 auto toVectorTy = [&](Type *ArgTy) { 2085 return VectorType::get(ArgTy->getScalarType(), 2086 cast<VectorType>(DstTy)->getElementCount()); 2087 }; 2088 2089 // Exit early if DstTy is not a vector type whose elements are one of [i16, 2090 // i32, i64]. SVE doesn't generally have the same set of instructions to 2091 // perform an extend with the add/sub/mul. There are SMULLB style 2092 // instructions, but they operate on top/bottom, requiring some sort of lane 2093 // interleaving to be used with zext/sext. 
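// Illustrative examples (hypothetical IR): add(<4 x i32> %a, zext <4 x i16> %b to <4 x i32>)
// can use UADDW, and mul(sext <4 x i16> to <4 x i32>, sext <4 x i16> to <4 x i32>) can use
// SMULL, folding the extends into the widening instruction.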
2094 unsigned DstEltSize = DstTy->getScalarSizeInBits(); 2095 if (!useNeonVector(DstTy) || Args.size() != 2 || 2096 (DstEltSize != 16 && DstEltSize != 32 && DstEltSize != 64)) 2097 return false; 2098 2099 // Determine if the operation has a widening variant. We consider both the 2100 // "long" (e.g., usubl) and "wide" (e.g., usubw) versions of the 2101 // instructions. 2102 // 2103 // TODO: Add additional widening operations (e.g., shl, etc.) once we 2104 // verify that their extending operands are eliminated during code 2105 // generation. 2106 Type *SrcTy = SrcOverrideTy; 2107 switch (Opcode) { 2108 case Instruction::Add: // UADDL(2), SADDL(2), UADDW(2), SADDW(2). 2109 case Instruction::Sub: // USUBL(2), SSUBL(2), USUBW(2), SSUBW(2). 2110 // The second operand needs to be an extend 2111 if (isa<SExtInst>(Args[1]) || isa<ZExtInst>(Args[1])) { 2112 if (!SrcTy) 2113 SrcTy = 2114 toVectorTy(cast<Instruction>(Args[1])->getOperand(0)->getType()); 2115 } else 2116 return false; 2117 break; 2118 case Instruction::Mul: { // SMULL(2), UMULL(2) 2119 // Both operands need to be extends of the same type. 2120 if ((isa<SExtInst>(Args[0]) && isa<SExtInst>(Args[1])) || 2121 (isa<ZExtInst>(Args[0]) && isa<ZExtInst>(Args[1]))) { 2122 if (!SrcTy) 2123 SrcTy = 2124 toVectorTy(cast<Instruction>(Args[0])->getOperand(0)->getType()); 2125 } else if (isa<ZExtInst>(Args[0]) || isa<ZExtInst>(Args[1])) { 2126 // If one of the operands is a Zext and the other has enough zero bits to 2127 // be treated as unsigned, we can still generate a umull, meaning the zext 2128 // is free. 2129 KnownBits Known = 2130 computeKnownBits(isa<ZExtInst>(Args[0]) ? Args[1] : Args[0], DL); 2131 if (Args[0]->getType()->getScalarSizeInBits() - 2132 Known.Zero.countLeadingOnes() > 2133 DstTy->getScalarSizeInBits() / 2) 2134 return false; 2135 if (!SrcTy) 2136 SrcTy = toVectorTy(Type::getIntNTy(DstTy->getContext(), 2137 DstTy->getScalarSizeInBits() / 2)); 2138 } else 2139 return false; 2140 break; 2141 } 2142 default: 2143 return false; 2144 } 2145 2146 // Legalize the destination type and ensure it can be used in a widening 2147 // operation. 2148 auto DstTyL = getTypeLegalizationCost(DstTy); 2149 if (!DstTyL.second.isVector() || DstEltSize != DstTy->getScalarSizeInBits()) 2150 return false; 2151 2152 // Legalize the source type and ensure it can be used in a widening 2153 // operation. 2154 assert(SrcTy && "Expected some SrcTy"); 2155 auto SrcTyL = getTypeLegalizationCost(SrcTy); 2156 unsigned SrcElTySize = SrcTyL.second.getScalarSizeInBits(); 2157 if (!SrcTyL.second.isVector() || SrcElTySize != SrcTy->getScalarSizeInBits()) 2158 return false; 2159 2160 // Get the total number of vector elements in the legalized types. 2161 InstructionCost NumDstEls = 2162 DstTyL.first * DstTyL.second.getVectorMinNumElements(); 2163 InstructionCost NumSrcEls = 2164 SrcTyL.first * SrcTyL.second.getVectorMinNumElements(); 2165 2166 // Return true if the legalized types have the same number of vector elements 2167 // and the destination element type size is twice that of the source type. 2168 return NumDstEls == NumSrcEls && 2 * SrcElTySize == DstEltSize; 2169 } 2170 2171 // s/urhadd instructions implement the following pattern, making the 2172 // extends free: 2173 // %x = add ((zext i8 -> i16), 1) 2174 // %y = (zext i8 -> i16) 2175 // trunc i16 (lshr (add %x, %y), 1) -> i8 2176 // 2177 bool AArch64TTIImpl::isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst, 2178 Type *Src) { 2179 // The source should be a legal vector type.
2180 if (!Src->isVectorTy() || !TLI->isTypeLegal(TLI->getValueType(DL, Src)) || 2181 (Src->isScalableTy() && !ST->hasSVE2())) 2182 return false; 2183 2184 if (ExtUser->getOpcode() != Instruction::Add || !ExtUser->hasOneUse()) 2185 return false; 2186 2187 // Look for trunc/shl/add before trying to match the pattern. 2188 const Instruction *Add = ExtUser; 2189 auto *AddUser = 2190 dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser()); 2191 if (AddUser && AddUser->getOpcode() == Instruction::Add) 2192 Add = AddUser; 2193 2194 auto *Shr = dyn_cast_or_null<Instruction>(Add->getUniqueUndroppableUser()); 2195 if (!Shr || Shr->getOpcode() != Instruction::LShr) 2196 return false; 2197 2198 auto *Trunc = dyn_cast_or_null<Instruction>(Shr->getUniqueUndroppableUser()); 2199 if (!Trunc || Trunc->getOpcode() != Instruction::Trunc || 2200 Src->getScalarSizeInBits() != 2201 cast<CastInst>(Trunc)->getDestTy()->getScalarSizeInBits()) 2202 return false; 2203 2204 // Try to match the whole pattern. Ext could be either the first or second 2205 // m_ZExtOrSExt matched. 2206 Instruction *Ex1, *Ex2; 2207 if (!(match(Add, m_c_Add(m_Instruction(Ex1), 2208 m_c_Add(m_Instruction(Ex2), m_SpecificInt(1)))))) 2209 return false; 2210 2211 // Ensure both extends are of the same type 2212 if (match(Ex1, m_ZExtOrSExt(m_Value())) && 2213 Ex1->getOpcode() == Ex2->getOpcode()) 2214 return true; 2215 2216 return false; 2217 } 2218 2219 InstructionCost AArch64TTIImpl::getCastInstrCost(unsigned Opcode, Type *Dst, 2220 Type *Src, 2221 TTI::CastContextHint CCH, 2222 TTI::TargetCostKind CostKind, 2223 const Instruction *I) { 2224 int ISD = TLI->InstructionOpcodeToISD(Opcode); 2225 assert(ISD && "Invalid opcode"); 2226 // If the cast is observable, and it is used by a widening instruction (e.g., 2227 // uaddl, saddw, etc.), it may be free. 2228 if (I && I->hasOneUser()) { 2229 auto *SingleUser = cast<Instruction>(*I->user_begin()); 2230 SmallVector<const Value *, 4> Operands(SingleUser->operand_values()); 2231 if (isWideningInstruction(Dst, SingleUser->getOpcode(), Operands, Src)) { 2232 // For adds only count the second operand as free if both operands are 2233 // extends but not the same operation. (i.e both operands are not free in 2234 // add(sext, zext)). 2235 if (SingleUser->getOpcode() == Instruction::Add) { 2236 if (I == SingleUser->getOperand(1) || 2237 (isa<CastInst>(SingleUser->getOperand(1)) && 2238 cast<CastInst>(SingleUser->getOperand(1))->getOpcode() == Opcode)) 2239 return 0; 2240 } else // Others are free so long as isWideningInstruction returned true. 2241 return 0; 2242 } 2243 2244 // The cast will be free for the s/urhadd instructions 2245 if ((isa<ZExtInst>(I) || isa<SExtInst>(I)) && 2246 isExtPartOfAvgExpr(SingleUser, Dst, Src)) 2247 return 0; 2248 } 2249 2250 // TODO: Allow non-throughput costs that aren't binary. 2251 auto AdjustCost = [&CostKind](InstructionCost Cost) -> InstructionCost { 2252 if (CostKind != TTI::TCK_RecipThroughput) 2253 return Cost == 0 ? 
0 : 1; 2254 return Cost; 2255 }; 2256 2257 EVT SrcTy = TLI->getValueType(DL, Src); 2258 EVT DstTy = TLI->getValueType(DL, Dst); 2259 2260 if (!SrcTy.isSimple() || !DstTy.isSimple()) 2261 return AdjustCost( 2262 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); 2263 2264 static const TypeConversionCostTblEntry 2265 ConversionTbl[] = { 2266 { ISD::TRUNCATE, MVT::v2i8, MVT::v2i64, 1}, // xtn 2267 { ISD::TRUNCATE, MVT::v2i16, MVT::v2i64, 1}, // xtn 2268 { ISD::TRUNCATE, MVT::v2i32, MVT::v2i64, 1}, // xtn 2269 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i32, 1}, // xtn 2270 { ISD::TRUNCATE, MVT::v4i8, MVT::v4i64, 3}, // 2 xtn + 1 uzp1 2271 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i32, 1}, // xtn 2272 { ISD::TRUNCATE, MVT::v4i16, MVT::v4i64, 2}, // 1 uzp1 + 1 xtn 2273 { ISD::TRUNCATE, MVT::v4i32, MVT::v4i64, 1}, // 1 uzp1 2274 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i16, 1}, // 1 xtn 2275 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i32, 2}, // 1 uzp1 + 1 xtn 2276 { ISD::TRUNCATE, MVT::v8i8, MVT::v8i64, 4}, // 3 x uzp1 + xtn 2277 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i32, 1}, // 1 uzp1 2278 { ISD::TRUNCATE, MVT::v8i16, MVT::v8i64, 3}, // 3 x uzp1 2279 { ISD::TRUNCATE, MVT::v8i32, MVT::v8i64, 2}, // 2 x uzp1 2280 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i16, 1}, // uzp1 2281 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i32, 3}, // (2 + 1) x uzp1 2282 { ISD::TRUNCATE, MVT::v16i8, MVT::v16i64, 7}, // (4 + 2 + 1) x uzp1 2283 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i32, 2}, // 2 x uzp1 2284 { ISD::TRUNCATE, MVT::v16i16, MVT::v16i64, 6}, // (4 + 2) x uzp1 2285 { ISD::TRUNCATE, MVT::v16i32, MVT::v16i64, 4}, // 4 x uzp1 2286 2287 // Truncations on nxvmiN 2288 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i16, 1 }, 2289 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i32, 1 }, 2290 { ISD::TRUNCATE, MVT::nxv2i1, MVT::nxv2i64, 1 }, 2291 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i16, 1 }, 2292 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i32, 1 }, 2293 { ISD::TRUNCATE, MVT::nxv4i1, MVT::nxv4i64, 2 }, 2294 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i16, 1 }, 2295 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i32, 3 }, 2296 { ISD::TRUNCATE, MVT::nxv8i1, MVT::nxv8i64, 5 }, 2297 { ISD::TRUNCATE, MVT::nxv16i1, MVT::nxv16i8, 1 }, 2298 { ISD::TRUNCATE, MVT::nxv2i16, MVT::nxv2i32, 1 }, 2299 { ISD::TRUNCATE, MVT::nxv2i32, MVT::nxv2i64, 1 }, 2300 { ISD::TRUNCATE, MVT::nxv4i16, MVT::nxv4i32, 1 }, 2301 { ISD::TRUNCATE, MVT::nxv4i32, MVT::nxv4i64, 2 }, 2302 { ISD::TRUNCATE, MVT::nxv8i16, MVT::nxv8i32, 3 }, 2303 { ISD::TRUNCATE, MVT::nxv8i32, MVT::nxv8i64, 6 }, 2304 2305 // The number of shll instructions for the extension. 
2306 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i16, 3 }, 2307 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i16, 3 }, 2308 { ISD::SIGN_EXTEND, MVT::v4i64, MVT::v4i32, 2 }, 2309 { ISD::ZERO_EXTEND, MVT::v4i64, MVT::v4i32, 2 }, 2310 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i8, 3 }, 2311 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i8, 3 }, 2312 { ISD::SIGN_EXTEND, MVT::v8i32, MVT::v8i16, 2 }, 2313 { ISD::ZERO_EXTEND, MVT::v8i32, MVT::v8i16, 2 }, 2314 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i8, 7 }, 2315 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i8, 7 }, 2316 { ISD::SIGN_EXTEND, MVT::v8i64, MVT::v8i16, 6 }, 2317 { ISD::ZERO_EXTEND, MVT::v8i64, MVT::v8i16, 6 }, 2318 { ISD::SIGN_EXTEND, MVT::v16i16, MVT::v16i8, 2 }, 2319 { ISD::ZERO_EXTEND, MVT::v16i16, MVT::v16i8, 2 }, 2320 { ISD::SIGN_EXTEND, MVT::v16i32, MVT::v16i8, 6 }, 2321 { ISD::ZERO_EXTEND, MVT::v16i32, MVT::v16i8, 6 }, 2322 2323 // LowerVectorINT_TO_FP: 2324 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 }, 2325 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 }, 2326 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 }, 2327 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i32, 1 }, 2328 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i32, 1 }, 2329 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i64, 1 }, 2330 2331 // Complex: to v2f32 2332 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 }, 2333 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 }, 2334 { ISD::SINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 }, 2335 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i8, 3 }, 2336 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i16, 3 }, 2337 { ISD::UINT_TO_FP, MVT::v2f32, MVT::v2i64, 2 }, 2338 2339 // Complex: to v4f32 2340 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i8, 4 }, 2341 { ISD::SINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 }, 2342 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i8, 3 }, 2343 { ISD::UINT_TO_FP, MVT::v4f32, MVT::v4i16, 2 }, 2344 2345 // Complex: to v8f32 2346 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 }, 2347 { ISD::SINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 }, 2348 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i8, 10 }, 2349 { ISD::UINT_TO_FP, MVT::v8f32, MVT::v8i16, 4 }, 2350 2351 // Complex: to v16f32 2352 { ISD::SINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 }, 2353 { ISD::UINT_TO_FP, MVT::v16f32, MVT::v16i8, 21 }, 2354 2355 // Complex: to v2f64 2356 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 }, 2357 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 }, 2358 { ISD::SINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 }, 2359 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i8, 4 }, 2360 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i16, 4 }, 2361 { ISD::UINT_TO_FP, MVT::v2f64, MVT::v2i32, 2 }, 2362 2363 // Complex: to v4f64 2364 { ISD::SINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 }, 2365 { ISD::UINT_TO_FP, MVT::v4f64, MVT::v4i32, 4 }, 2366 2367 // LowerVectorFP_TO_INT 2368 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f32, 1 }, 2369 { ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f32, 1 }, 2370 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f64, 1 }, 2371 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f32, 1 }, 2372 { ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f32, 1 }, 2373 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f64, 1 }, 2374 2375 // Complex, from v2f32: legal type is v2i32 (no cost) or v2i64 (1 ext). 
2376 { ISD::FP_TO_SINT, MVT::v2i64, MVT::v2f32, 2 }, 2377 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f32, 1 }, 2378 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f32, 1 }, 2379 { ISD::FP_TO_UINT, MVT::v2i64, MVT::v2f32, 2 }, 2380 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f32, 1 }, 2381 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f32, 1 }, 2382 2383 // Complex, from v4f32: legal type is v4i16, 1 narrowing => ~2 2384 { ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f32, 2 }, 2385 { ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f32, 2 }, 2386 { ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f32, 2 }, 2387 { ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f32, 2 }, 2388 2389 // Complex, from nxv2f32. 2390 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f32, 1 }, 2391 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f32, 1 }, 2392 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f32, 1 }, 2393 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f32, 1 }, 2394 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f32, 1 }, 2395 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f32, 1 }, 2396 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f32, 1 }, 2397 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f32, 1 }, 2398 2399 // Complex, from v2f64: legal type is v2i32, 1 narrowing => ~2. 2400 { ISD::FP_TO_SINT, MVT::v2i32, MVT::v2f64, 2 }, 2401 { ISD::FP_TO_SINT, MVT::v2i16, MVT::v2f64, 2 }, 2402 { ISD::FP_TO_SINT, MVT::v2i8, MVT::v2f64, 2 }, 2403 { ISD::FP_TO_UINT, MVT::v2i32, MVT::v2f64, 2 }, 2404 { ISD::FP_TO_UINT, MVT::v2i16, MVT::v2f64, 2 }, 2405 { ISD::FP_TO_UINT, MVT::v2i8, MVT::v2f64, 2 }, 2406 2407 // Complex, from nxv2f64. 2408 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f64, 1 }, 2409 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f64, 1 }, 2410 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f64, 1 }, 2411 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f64, 1 }, 2412 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f64, 1 }, 2413 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f64, 1 }, 2414 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f64, 1 }, 2415 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f64, 1 }, 2416 2417 // Complex, from nxv4f32. 2418 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f32, 4 }, 2419 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f32, 1 }, 2420 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f32, 1 }, 2421 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f32, 1 }, 2422 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f32, 4 }, 2423 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f32, 1 }, 2424 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f32, 1 }, 2425 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f32, 1 }, 2426 2427 // Complex, from nxv8f64. Illegal -> illegal conversions not required. 2428 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f64, 7 }, 2429 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f64, 7 }, 2430 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f64, 7 }, 2431 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f64, 7 }, 2432 2433 // Complex, from nxv4f64. Illegal -> illegal conversions not required. 2434 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f64, 3 }, 2435 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f64, 3 }, 2436 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f64, 3 }, 2437 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f64, 3 }, 2438 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f64, 3 }, 2439 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f64, 3 }, 2440 2441 // Complex, from nxv8f32. Illegal -> illegal conversions not required. 2442 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f32, 3 }, 2443 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f32, 3 }, 2444 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f32, 3 }, 2445 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f32, 3 }, 2446 2447 // Complex, from nxv8f16. 
2448 { ISD::FP_TO_SINT, MVT::nxv8i64, MVT::nxv8f16, 10 }, 2449 { ISD::FP_TO_SINT, MVT::nxv8i32, MVT::nxv8f16, 4 }, 2450 { ISD::FP_TO_SINT, MVT::nxv8i16, MVT::nxv8f16, 1 }, 2451 { ISD::FP_TO_SINT, MVT::nxv8i8, MVT::nxv8f16, 1 }, 2452 { ISD::FP_TO_UINT, MVT::nxv8i64, MVT::nxv8f16, 10 }, 2453 { ISD::FP_TO_UINT, MVT::nxv8i32, MVT::nxv8f16, 4 }, 2454 { ISD::FP_TO_UINT, MVT::nxv8i16, MVT::nxv8f16, 1 }, 2455 { ISD::FP_TO_UINT, MVT::nxv8i8, MVT::nxv8f16, 1 }, 2456 2457 // Complex, from nxv4f16. 2458 { ISD::FP_TO_SINT, MVT::nxv4i64, MVT::nxv4f16, 4 }, 2459 { ISD::FP_TO_SINT, MVT::nxv4i32, MVT::nxv4f16, 1 }, 2460 { ISD::FP_TO_SINT, MVT::nxv4i16, MVT::nxv4f16, 1 }, 2461 { ISD::FP_TO_SINT, MVT::nxv4i8, MVT::nxv4f16, 1 }, 2462 { ISD::FP_TO_UINT, MVT::nxv4i64, MVT::nxv4f16, 4 }, 2463 { ISD::FP_TO_UINT, MVT::nxv4i32, MVT::nxv4f16, 1 }, 2464 { ISD::FP_TO_UINT, MVT::nxv4i16, MVT::nxv4f16, 1 }, 2465 { ISD::FP_TO_UINT, MVT::nxv4i8, MVT::nxv4f16, 1 }, 2466 2467 // Complex, from nxv2f16. 2468 { ISD::FP_TO_SINT, MVT::nxv2i64, MVT::nxv2f16, 1 }, 2469 { ISD::FP_TO_SINT, MVT::nxv2i32, MVT::nxv2f16, 1 }, 2470 { ISD::FP_TO_SINT, MVT::nxv2i16, MVT::nxv2f16, 1 }, 2471 { ISD::FP_TO_SINT, MVT::nxv2i8, MVT::nxv2f16, 1 }, 2472 { ISD::FP_TO_UINT, MVT::nxv2i64, MVT::nxv2f16, 1 }, 2473 { ISD::FP_TO_UINT, MVT::nxv2i32, MVT::nxv2f16, 1 }, 2474 { ISD::FP_TO_UINT, MVT::nxv2i16, MVT::nxv2f16, 1 }, 2475 { ISD::FP_TO_UINT, MVT::nxv2i8, MVT::nxv2f16, 1 }, 2476 2477 // Truncate from nxvmf32 to nxvmf16. 2478 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f32, 1 }, 2479 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f32, 1 }, 2480 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f32, 3 }, 2481 2482 // Truncate from nxvmf64 to nxvmf16. 2483 { ISD::FP_ROUND, MVT::nxv2f16, MVT::nxv2f64, 1 }, 2484 { ISD::FP_ROUND, MVT::nxv4f16, MVT::nxv4f64, 3 }, 2485 { ISD::FP_ROUND, MVT::nxv8f16, MVT::nxv8f64, 7 }, 2486 2487 // Truncate from nxvmf64 to nxvmf32. 2488 { ISD::FP_ROUND, MVT::nxv2f32, MVT::nxv2f64, 1 }, 2489 { ISD::FP_ROUND, MVT::nxv4f32, MVT::nxv4f64, 3 }, 2490 { ISD::FP_ROUND, MVT::nxv8f32, MVT::nxv8f64, 6 }, 2491 2492 // Extend from nxvmf16 to nxvmf32. 2493 { ISD::FP_EXTEND, MVT::nxv2f32, MVT::nxv2f16, 1}, 2494 { ISD::FP_EXTEND, MVT::nxv4f32, MVT::nxv4f16, 1}, 2495 { ISD::FP_EXTEND, MVT::nxv8f32, MVT::nxv8f16, 2}, 2496 2497 // Extend from nxvmf16 to nxvmf64. 2498 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f16, 1}, 2499 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f16, 2}, 2500 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f16, 4}, 2501 2502 // Extend from nxvmf32 to nxvmf64. 2503 { ISD::FP_EXTEND, MVT::nxv2f64, MVT::nxv2f32, 1}, 2504 { ISD::FP_EXTEND, MVT::nxv4f64, MVT::nxv4f32, 2}, 2505 { ISD::FP_EXTEND, MVT::nxv8f64, MVT::nxv8f32, 6}, 2506 2507 // Bitcasts from float to integer 2508 { ISD::BITCAST, MVT::nxv2f16, MVT::nxv2i16, 0 }, 2509 { ISD::BITCAST, MVT::nxv4f16, MVT::nxv4i16, 0 }, 2510 { ISD::BITCAST, MVT::nxv2f32, MVT::nxv2i32, 0 }, 2511 2512 // Bitcasts from integer to float 2513 { ISD::BITCAST, MVT::nxv2i16, MVT::nxv2f16, 0 }, 2514 { ISD::BITCAST, MVT::nxv4i16, MVT::nxv4f16, 0 }, 2515 { ISD::BITCAST, MVT::nxv2i32, MVT::nxv2f32, 0 }, 2516 2517 // Add cost for extending to illegal -too wide- scalable vectors. 2518 // zero/sign extend are implemented by multiple unpack operations, 2519 // where each operation has a cost of 1. 
2520 { ISD::ZERO_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2}, 2521 { ISD::ZERO_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6}, 2522 { ISD::ZERO_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14}, 2523 { ISD::ZERO_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2}, 2524 { ISD::ZERO_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6}, 2525 { ISD::ZERO_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2}, 2526 2527 { ISD::SIGN_EXTEND, MVT::nxv16i16, MVT::nxv16i8, 2}, 2528 { ISD::SIGN_EXTEND, MVT::nxv16i32, MVT::nxv16i8, 6}, 2529 { ISD::SIGN_EXTEND, MVT::nxv16i64, MVT::nxv16i8, 14}, 2530 { ISD::SIGN_EXTEND, MVT::nxv8i32, MVT::nxv8i16, 2}, 2531 { ISD::SIGN_EXTEND, MVT::nxv8i64, MVT::nxv8i16, 6}, 2532 { ISD::SIGN_EXTEND, MVT::nxv4i64, MVT::nxv4i32, 2}, 2533 }; 2534 2535 // We have to estimate a cost of fixed length operation upon 2536 // SVE registers(operations) with the number of registers required 2537 // for a fixed type to be represented upon SVE registers. 2538 EVT WiderTy = SrcTy.bitsGT(DstTy) ? SrcTy : DstTy; 2539 if (SrcTy.isFixedLengthVector() && DstTy.isFixedLengthVector() && 2540 SrcTy.getVectorNumElements() == DstTy.getVectorNumElements() && 2541 ST->useSVEForFixedLengthVectors(WiderTy)) { 2542 std::pair<InstructionCost, MVT> LT = 2543 getTypeLegalizationCost(WiderTy.getTypeForEVT(Dst->getContext())); 2544 unsigned NumElements = AArch64::SVEBitsPerBlock / 2545 LT.second.getVectorElementType().getSizeInBits(); 2546 return AdjustCost( 2547 LT.first * 2548 getCastInstrCost( 2549 Opcode, ScalableVectorType::get(Dst->getScalarType(), NumElements), 2550 ScalableVectorType::get(Src->getScalarType(), NumElements), CCH, 2551 CostKind, I)); 2552 } 2553 2554 if (const auto *Entry = ConvertCostTableLookup(ConversionTbl, ISD, 2555 DstTy.getSimpleVT(), 2556 SrcTy.getSimpleVT())) 2557 return AdjustCost(Entry->Cost); 2558 2559 static const TypeConversionCostTblEntry FP16Tbl[] = { 2560 {ISD::FP_TO_SINT, MVT::v4i8, MVT::v4f16, 1}, // fcvtzs 2561 {ISD::FP_TO_UINT, MVT::v4i8, MVT::v4f16, 1}, 2562 {ISD::FP_TO_SINT, MVT::v4i16, MVT::v4f16, 1}, // fcvtzs 2563 {ISD::FP_TO_UINT, MVT::v4i16, MVT::v4f16, 1}, 2564 {ISD::FP_TO_SINT, MVT::v4i32, MVT::v4f16, 2}, // fcvtl+fcvtzs 2565 {ISD::FP_TO_UINT, MVT::v4i32, MVT::v4f16, 2}, 2566 {ISD::FP_TO_SINT, MVT::v8i8, MVT::v8f16, 2}, // fcvtzs+xtn 2567 {ISD::FP_TO_UINT, MVT::v8i8, MVT::v8f16, 2}, 2568 {ISD::FP_TO_SINT, MVT::v8i16, MVT::v8f16, 1}, // fcvtzs 2569 {ISD::FP_TO_UINT, MVT::v8i16, MVT::v8f16, 1}, 2570 {ISD::FP_TO_SINT, MVT::v8i32, MVT::v8f16, 4}, // 2*fcvtl+2*fcvtzs 2571 {ISD::FP_TO_UINT, MVT::v8i32, MVT::v8f16, 4}, 2572 {ISD::FP_TO_SINT, MVT::v16i8, MVT::v16f16, 3}, // 2*fcvtzs+xtn 2573 {ISD::FP_TO_UINT, MVT::v16i8, MVT::v16f16, 3}, 2574 {ISD::FP_TO_SINT, MVT::v16i16, MVT::v16f16, 2}, // 2*fcvtzs 2575 {ISD::FP_TO_UINT, MVT::v16i16, MVT::v16f16, 2}, 2576 {ISD::FP_TO_SINT, MVT::v16i32, MVT::v16f16, 8}, // 4*fcvtl+4*fcvtzs 2577 {ISD::FP_TO_UINT, MVT::v16i32, MVT::v16f16, 8}, 2578 {ISD::UINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // ushll + ucvtf 2579 {ISD::SINT_TO_FP, MVT::v8f16, MVT::v8i8, 2}, // sshll + scvtf 2580 {ISD::UINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * ushl(2) + 2 * ucvtf 2581 {ISD::SINT_TO_FP, MVT::v16f16, MVT::v16i8, 4}, // 2 * sshl(2) + 2 * scvtf 2582 }; 2583 2584 if (ST->hasFullFP16()) 2585 if (const auto *Entry = ConvertCostTableLookup( 2586 FP16Tbl, ISD, DstTy.getSimpleVT(), SrcTy.getSimpleVT())) 2587 return AdjustCost(Entry->Cost); 2588 2589 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) && 2590 CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() && 2591 
TLI->getTypeAction(Src->getContext(), SrcTy) == 2592 TargetLowering::TypePromoteInteger && 2593 TLI->getTypeAction(Dst->getContext(), DstTy) == 2594 TargetLowering::TypeSplitVector) { 2595 // The standard behaviour in the backend for these cases is to split the 2596 // extend up into two parts: 2597 // 1. Perform an extending load or masked load up to the legal type. 2598 // 2. Extend the loaded data to the final type. 2599 std::pair<InstructionCost, MVT> SrcLT = getTypeLegalizationCost(Src); 2600 Type *LegalTy = EVT(SrcLT.second).getTypeForEVT(Src->getContext()); 2601 InstructionCost Part1 = AArch64TTIImpl::getCastInstrCost( 2602 Opcode, LegalTy, Src, CCH, CostKind, I); 2603 InstructionCost Part2 = AArch64TTIImpl::getCastInstrCost( 2604 Opcode, Dst, LegalTy, TTI::CastContextHint::None, CostKind, I); 2605 return Part1 + Part2; 2606 } 2607 2608 // The BasicTTIImpl version only deals with CCH==TTI::CastContextHint::Normal, 2609 // but we also want to include the TTI::CastContextHint::Masked case too. 2610 if ((ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND) && 2611 CCH == TTI::CastContextHint::Masked && ST->hasSVEorSME() && 2612 TLI->isTypeLegal(DstTy)) 2613 CCH = TTI::CastContextHint::Normal; 2614 2615 return AdjustCost( 2616 BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I)); 2617 } 2618 2619 InstructionCost AArch64TTIImpl::getExtractWithExtendCost(unsigned Opcode, 2620 Type *Dst, 2621 VectorType *VecTy, 2622 unsigned Index) { 2623 2624 // Make sure we were given a valid extend opcode. 2625 assert((Opcode == Instruction::SExt || Opcode == Instruction::ZExt) && 2626 "Invalid opcode"); 2627 2628 // We are extending an element we extract from a vector, so the source type 2629 // of the extend is the element type of the vector. 2630 auto *Src = VecTy->getElementType(); 2631 2632 // Sign- and zero-extends are for integer types only. 2633 assert(isa<IntegerType>(Dst) && isa<IntegerType>(Src) && "Invalid type"); 2634 2635 // Get the cost for the extract. We compute the cost (if any) for the extend 2636 // below. 2637 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 2638 InstructionCost Cost = getVectorInstrCost(Instruction::ExtractElement, VecTy, 2639 CostKind, Index, nullptr, nullptr); 2640 2641 // Legalize the types. 2642 auto VecLT = getTypeLegalizationCost(VecTy); 2643 auto DstVT = TLI->getValueType(DL, Dst); 2644 auto SrcVT = TLI->getValueType(DL, Src); 2645 2646 // If the resulting type is still a vector and the destination type is legal, 2647 // we may get the extension for free. If not, get the default cost for the 2648 // extend. 2649 if (!VecLT.second.isVector() || !TLI->isTypeLegal(DstVT)) 2650 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2651 CostKind); 2652 2653 // The destination type should be larger than the element type. If not, get 2654 // the default cost for the extend. 2655 if (DstVT.getFixedSizeInBits() < SrcVT.getFixedSizeInBits()) 2656 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2657 CostKind); 2658 2659 switch (Opcode) { 2660 default: 2661 llvm_unreachable("Opcode should be either SExt or ZExt"); 2662 2663 // For sign-extends, we only need a smov, which performs the extension 2664 // automatically. 2665 case Instruction::SExt: 2666 return Cost; 2667 2668 // For zero-extends, the extend is performed automatically by a umov unless 2669 // the destination type is i64 and the element type is i8 or i16. 
2670 case Instruction::ZExt: 2671 if (DstVT.getSizeInBits() != 64u || SrcVT.getSizeInBits() == 32u) 2672 return Cost; 2673 } 2674 2675 // If we are unable to perform the extend for free, get the default cost. 2676 return Cost + getCastInstrCost(Opcode, Dst, Src, TTI::CastContextHint::None, 2677 CostKind); 2678 } 2679 2680 InstructionCost AArch64TTIImpl::getCFInstrCost(unsigned Opcode, 2681 TTI::TargetCostKind CostKind, 2682 const Instruction *I) { 2683 if (CostKind != TTI::TCK_RecipThroughput) 2684 return Opcode == Instruction::PHI ? 0 : 1; 2685 assert(CostKind == TTI::TCK_RecipThroughput && "unexpected CostKind"); 2686 // Branches are assumed to be predicted. 2687 return 0; 2688 } 2689 2690 InstructionCost AArch64TTIImpl::getVectorInstrCostHelper(const Instruction *I, 2691 Type *Val, 2692 unsigned Index, 2693 bool HasRealUse) { 2694 assert(Val->isVectorTy() && "This must be a vector type"); 2695 2696 if (Index != -1U) { 2697 // Legalize the type. 2698 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Val); 2699 2700 // This type is legalized to a scalar type. 2701 if (!LT.second.isVector()) 2702 return 0; 2703 2704 // The type may be split. For fixed-width vectors we can normalize the 2705 // index to the new type. 2706 if (LT.second.isFixedLengthVector()) { 2707 unsigned Width = LT.second.getVectorNumElements(); 2708 Index = Index % Width; 2709 } 2710 2711 // The element at index zero is already inside the vector. 2712 // - For a physical (HasRealUse==true) insert-element or extract-element 2713 // instruction that extracts integers, an explicit FPR -> GPR move is 2714 // needed. So it has non-zero cost. 2715 // - For the rest of cases (virtual instruction or element type is float), 2716 // consider the instruction free. 2717 if (Index == 0 && (!HasRealUse || !Val->getScalarType()->isIntegerTy())) 2718 return 0; 2719 2720 // This is recognising a LD1 single-element structure to one lane of one 2721 // register instruction. I.e., if this is an `insertelement` instruction, 2722 // and its second operand is a load, then we will generate a LD1, which 2723 // are expensive instructions. 2724 if (I && dyn_cast<LoadInst>(I->getOperand(1))) 2725 return ST->getVectorInsertExtractBaseCost() + 1; 2726 2727 // i1 inserts and extract will include an extra cset or cmp of the vector 2728 // value. Increase the cost by 1 to account. 2729 if (Val->getScalarSizeInBits() == 1) 2730 return ST->getVectorInsertExtractBaseCost() + 1; 2731 2732 // FIXME: 2733 // If the extract-element and insert-element instructions could be 2734 // simplified away (e.g., could be combined into users by looking at use-def 2735 // context), they have no cost. This is not done in the first place for 2736 // compile-time considerations. 2737 } 2738 2739 // All other insert/extracts cost this much. 
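// Illustrative example (the base cost is subtarget-dependent): with a base cost of 2,
// extracting lane 1 of a <4 x i32> costs 2, whereas a lane-0 float extract, or a lane-0
// extract with no real use, was already treated as free above.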
2740 return ST->getVectorInsertExtractBaseCost(); 2741 } 2742 2743 InstructionCost AArch64TTIImpl::getVectorInstrCost(unsigned Opcode, Type *Val, 2744 TTI::TargetCostKind CostKind, 2745 unsigned Index, Value *Op0, 2746 Value *Op1) { 2747 bool HasRealUse = 2748 Opcode == Instruction::InsertElement && Op0 && !isa<UndefValue>(Op0); 2749 return getVectorInstrCostHelper(nullptr, Val, Index, HasRealUse); 2750 } 2751 2752 InstructionCost AArch64TTIImpl::getVectorInstrCost(const Instruction &I, 2753 Type *Val, 2754 TTI::TargetCostKind CostKind, 2755 unsigned Index) { 2756 return getVectorInstrCostHelper(&I, Val, Index, true /* HasRealUse */); 2757 } 2758 2759 InstructionCost AArch64TTIImpl::getScalarizationOverhead( 2760 VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract, 2761 TTI::TargetCostKind CostKind) { 2762 if (isa<ScalableVectorType>(Ty)) 2763 return InstructionCost::getInvalid(); 2764 if (Ty->getElementType()->isFloatingPointTy()) 2765 return BaseT::getScalarizationOverhead(Ty, DemandedElts, Insert, Extract, 2766 CostKind); 2767 return DemandedElts.popcount() * (Insert + Extract) * 2768 ST->getVectorInsertExtractBaseCost(); 2769 } 2770 2771 InstructionCost AArch64TTIImpl::getArithmeticInstrCost( 2772 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, 2773 TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info, 2774 ArrayRef<const Value *> Args, 2775 const Instruction *CxtI) { 2776 2777 // TODO: Handle more cost kinds. 2778 if (CostKind != TTI::TCK_RecipThroughput) 2779 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, 2780 Op2Info, Args, CxtI); 2781 2782 // Legalize the type. 2783 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty); 2784 int ISD = TLI->InstructionOpcodeToISD(Opcode); 2785 2786 switch (ISD) { 2787 default: 2788 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info, 2789 Op2Info); 2790 case ISD::SDIV: 2791 if (Op2Info.isConstant() && Op2Info.isUniform() && Op2Info.isPowerOf2()) { 2792 // On AArch64, scalar signed division by constants power-of-two are 2793 // normally expanded to the sequence ADD + CMP + SELECT + SRA. 2794 // The OperandValue properties many not be same as that of previous 2795 // operation; conservatively assume OP_None. 2796 InstructionCost Cost = getArithmeticInstrCost( 2797 Instruction::Add, Ty, CostKind, 2798 Op1Info.getNoProps(), Op2Info.getNoProps()); 2799 Cost += getArithmeticInstrCost(Instruction::Sub, Ty, CostKind, 2800 Op1Info.getNoProps(), Op2Info.getNoProps()); 2801 Cost += getArithmeticInstrCost( 2802 Instruction::Select, Ty, CostKind, 2803 Op1Info.getNoProps(), Op2Info.getNoProps()); 2804 Cost += getArithmeticInstrCost(Instruction::AShr, Ty, CostKind, 2805 Op1Info.getNoProps(), Op2Info.getNoProps()); 2806 return Cost; 2807 } 2808 [[fallthrough]]; 2809 case ISD::UDIV: { 2810 if (Op2Info.isConstant() && Op2Info.isUniform()) { 2811 auto VT = TLI->getValueType(DL, Ty); 2812 if (TLI->isOperationLegalOrCustom(ISD::MULHU, VT)) { 2813 // Vector signed division by constant are expanded to the 2814 // sequence MULHS + ADD/SUB + SRA + SRL + ADD, and unsigned division 2815 // to MULHS + SUB + SRL + ADD + SRL. 
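// Illustrative cost sketch (an approximation rather than an exact per-type sequence):
// the return below models that expansion as two multiplies, two adds and two shifts,
// plus one extra operation.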
2816 InstructionCost MulCost = getArithmeticInstrCost( 2817 Instruction::Mul, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2818 InstructionCost AddCost = getArithmeticInstrCost( 2819 Instruction::Add, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2820 InstructionCost ShrCost = getArithmeticInstrCost( 2821 Instruction::AShr, Ty, CostKind, Op1Info.getNoProps(), Op2Info.getNoProps()); 2822 return MulCost * 2 + AddCost * 2 + ShrCost * 2 + 1; 2823 } 2824 } 2825 2826 InstructionCost Cost = BaseT::getArithmeticInstrCost( 2827 Opcode, Ty, CostKind, Op1Info, Op2Info); 2828 if (Ty->isVectorTy()) { 2829 if (TLI->isOperationLegalOrCustom(ISD, LT.second) && ST->hasSVE()) { 2830 // SDIV/UDIV operations are lowered using SVE, then we can have less 2831 // costs. 2832 if (isa<FixedVectorType>(Ty) && cast<FixedVectorType>(Ty) 2833 ->getPrimitiveSizeInBits() 2834 .getFixedValue() < 128) { 2835 EVT VT = TLI->getValueType(DL, Ty); 2836 static const CostTblEntry DivTbl[]{ 2837 {ISD::SDIV, MVT::v2i8, 5}, {ISD::SDIV, MVT::v4i8, 8}, 2838 {ISD::SDIV, MVT::v8i8, 8}, {ISD::SDIV, MVT::v2i16, 5}, 2839 {ISD::SDIV, MVT::v4i16, 5}, {ISD::SDIV, MVT::v2i32, 1}, 2840 {ISD::UDIV, MVT::v2i8, 5}, {ISD::UDIV, MVT::v4i8, 8}, 2841 {ISD::UDIV, MVT::v8i8, 8}, {ISD::UDIV, MVT::v2i16, 5}, 2842 {ISD::UDIV, MVT::v4i16, 5}, {ISD::UDIV, MVT::v2i32, 1}}; 2843 2844 const auto *Entry = CostTableLookup(DivTbl, ISD, VT.getSimpleVT()); 2845 if (nullptr != Entry) 2846 return Entry->Cost; 2847 } 2848 // For 8/16-bit elements, the cost is higher because the type 2849 // requires promotion and possibly splitting: 2850 if (LT.second.getScalarType() == MVT::i8) 2851 Cost *= 8; 2852 else if (LT.second.getScalarType() == MVT::i16) 2853 Cost *= 4; 2854 return Cost; 2855 } else { 2856 // If one of the operands is a uniform constant then the cost for each 2857 // element is Cost for insertion, extraction and division. 2858 // Insertion cost = 2, Extraction Cost = 2, Division = cost for the 2859 // operation with scalar type 2860 if ((Op1Info.isConstant() && Op1Info.isUniform()) || 2861 (Op2Info.isConstant() && Op2Info.isUniform())) { 2862 if (auto *VTy = dyn_cast<FixedVectorType>(Ty)) { 2863 InstructionCost DivCost = BaseT::getArithmeticInstrCost( 2864 Opcode, Ty->getScalarType(), CostKind, Op1Info, Op2Info); 2865 return (4 + DivCost) * VTy->getNumElements(); 2866 } 2867 } 2868 // On AArch64, without SVE, vector divisions are expanded 2869 // into scalar divisions of each pair of elements. 2870 Cost += getArithmeticInstrCost(Instruction::ExtractElement, Ty, 2871 CostKind, Op1Info, Op2Info); 2872 Cost += getArithmeticInstrCost(Instruction::InsertElement, Ty, CostKind, 2873 Op1Info, Op2Info); 2874 } 2875 2876 // TODO: if one of the arguments is scalar, then it's not necessary to 2877 // double the cost of handling the vector elements. 2878 Cost += Cost; 2879 } 2880 return Cost; 2881 } 2882 case ISD::MUL: 2883 // When SVE is available, then we can lower the v2i64 operation using 2884 // the SVE mul instruction, which has a lower cost. 2885 if (LT.second == MVT::v2i64 && ST->hasSVE()) 2886 return LT.first; 2887 2888 // When SVE is not available, there is no MUL.2d instruction, 2889 // which means mul <2 x i64> is expensive as elements are extracted 2890 // from the vectors and the muls scalarized. 2891 // As getScalarizationOverhead is a bit too pessimistic, we 2892 // estimate the cost for a i64 vector directly here, which is: 2893 // - four 2-cost i64 extracts, 2894 // - two 2-cost i64 inserts, and 2895 // - two 1-cost muls. 
2896 // So, for a v2i64 with LT.first = 1 the cost is 14, and for a v4i64 with
2897 // LT.first = 2 the cost is 28. If both operands are extensions it will not
2898 // need to scalarize so the cost can be cheaper (smull or umull).
2900 if (LT.second != MVT::v2i64 || isWideningInstruction(Ty, Opcode, Args))
2901 return LT.first;
2902 return LT.first * 14;
2903 case ISD::ADD:
2904 case ISD::XOR:
2905 case ISD::OR:
2906 case ISD::AND:
2907 case ISD::SRL:
2908 case ISD::SRA:
2909 case ISD::SHL:
2910 // These nodes are marked as 'custom' for combining purposes only.
2911 // We know that they are legal. See LowerAdd in ISelLowering.
2912 return LT.first;
2913
2914 case ISD::FNEG:
2915 case ISD::FADD:
2916 case ISD::FSUB:
2917 // Increase the cost for half and bfloat types if not architecturally
2918 // supported.
2919 if ((Ty->getScalarType()->isHalfTy() && !ST->hasFullFP16()) ||
2920 (Ty->getScalarType()->isBFloatTy() && !ST->hasBF16()))
2921 return 2 * LT.first;
2922 if (!Ty->getScalarType()->isFP128Ty())
2923 return LT.first;
2924 [[fallthrough]];
2925 case ISD::FMUL:
2926 case ISD::FDIV:
2927 // These nodes are marked as 'custom' just to lower them to SVE.
2928 // We know said lowering will incur no additional cost.
2929 if (!Ty->getScalarType()->isFP128Ty())
2930 return 2 * LT.first;
2931
2932 return BaseT::getArithmeticInstrCost(Opcode, Ty, CostKind, Op1Info,
2933 Op2Info);
2934 }
2935 }
2936
2937 InstructionCost AArch64TTIImpl::getAddressComputationCost(Type *Ty,
2938 ScalarEvolution *SE,
2939 const SCEV *Ptr) {
2940 // Address computations in vectorized code with non-consecutive addresses will
2941 // likely result in more instructions compared to scalar code where the
2942 // computation can more often be merged into the index mode. The resulting
2943 // extra micro-ops can significantly decrease throughput.
2944 unsigned NumVectorInstToHideOverhead = NeonNonConstStrideOverhead;
2945 int MaxMergeDistance = 64;
2946
2947 if (Ty->isVectorTy() && SE &&
2948 !BaseT::isConstantStridedAccessLessThan(SE, Ptr, MaxMergeDistance + 1))
2949 return NumVectorInstToHideOverhead;
2950
2951 // In many cases the address computation is not merged into the instruction
2952 // addressing mode.
2953 return 1;
2954 }
2955
2956 InstructionCost AArch64TTIImpl::getCmpSelInstrCost(unsigned Opcode, Type *ValTy,
2957 Type *CondTy,
2958 CmpInst::Predicate VecPred,
2959 TTI::TargetCostKind CostKind,
2960 const Instruction *I) {
2961 // TODO: Handle other cost kinds.
2962 if (CostKind != TTI::TCK_RecipThroughput)
2963 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind,
2964 I);
2965
2966 int ISD = TLI->InstructionOpcodeToISD(Opcode);
2967 // We don't lower some vector selects well when they are wider than the
2968 // register width.
2969 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SELECT) {
2970 // We would need this many instructions to hide the scalarization happening.
2971 const int AmortizationCost = 20;
2972
2973 // If VecPred is not set, check if we can get a predicate from the context
2974 // instruction, if its type matches the requested ValTy.
2975 if (VecPred == CmpInst::BAD_ICMP_PREDICATE && I && I->getType() == ValTy) {
2976 CmpInst::Predicate CurrentPred;
2977 if (match(I, m_Select(m_Cmp(CurrentPred, m_Value(), m_Value()), m_Value(),
2978 m_Value())))
2979 VecPred = CurrentPred;
2980 }
2981 // Check if we have a compare/select chain that can be lowered using
2982 // a (F)CMxx & BFI pair.
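// As a hedged illustration (hypothetical IR, not tied to a specific CPU), a
// chain such as
//   %c = fcmp olt <4 x float> %a, %b
//   %m = select <4 x i1> %c, <4 x float> %a, <4 x float> %b
// matches the predicate check below, and for a legal type such as v4f32 the
// select is costed as LT.first, reflecting the (F)CMxx & BFI style lowering
// mentioned above.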
2983 if (CmpInst::isIntPredicate(VecPred) || VecPred == CmpInst::FCMP_OLE ||
2984 VecPred == CmpInst::FCMP_OLT || VecPred == CmpInst::FCMP_OGT ||
2985 VecPred == CmpInst::FCMP_OGE || VecPred == CmpInst::FCMP_OEQ ||
2986 VecPred == CmpInst::FCMP_UNE) {
2987 static const auto ValidMinMaxTys = {
2988 MVT::v8i8, MVT::v16i8, MVT::v4i16, MVT::v8i16, MVT::v2i32,
2989 MVT::v4i32, MVT::v2i64, MVT::v2f32, MVT::v4f32, MVT::v2f64};
2990 static const auto ValidFP16MinMaxTys = {MVT::v4f16, MVT::v8f16};
2991
2992 auto LT = getTypeLegalizationCost(ValTy);
2993 if (any_of(ValidMinMaxTys, [&LT](MVT M) { return M == LT.second; }) ||
2994 (ST->hasFullFP16() &&
2995 any_of(ValidFP16MinMaxTys, [&LT](MVT M) { return M == LT.second; })))
2996 return LT.first;
2997 }
2998
2999 static const TypeConversionCostTblEntry
3000 VectorSelectTbl[] = {
3001 { ISD::SELECT, MVT::v2i1, MVT::v2f32, 2 },
3002 { ISD::SELECT, MVT::v2i1, MVT::v2f64, 2 },
3003 { ISD::SELECT, MVT::v4i1, MVT::v4f32, 2 },
3004 { ISD::SELECT, MVT::v4i1, MVT::v4f16, 2 },
3005 { ISD::SELECT, MVT::v8i1, MVT::v8f16, 2 },
3006 { ISD::SELECT, MVT::v16i1, MVT::v16i16, 16 },
3007 { ISD::SELECT, MVT::v8i1, MVT::v8i32, 8 },
3008 { ISD::SELECT, MVT::v16i1, MVT::v16i32, 16 },
3009 { ISD::SELECT, MVT::v4i1, MVT::v4i64, 4 * AmortizationCost },
3010 { ISD::SELECT, MVT::v8i1, MVT::v8i64, 8 * AmortizationCost },
3011 { ISD::SELECT, MVT::v16i1, MVT::v16i64, 16 * AmortizationCost }
3012 };
3013
3014 EVT SelCondTy = TLI->getValueType(DL, CondTy);
3015 EVT SelValTy = TLI->getValueType(DL, ValTy);
3016 if (SelCondTy.isSimple() && SelValTy.isSimple()) {
3017 if (const auto *Entry = ConvertCostTableLookup(VectorSelectTbl, ISD,
3018 SelCondTy.getSimpleVT(),
3019 SelValTy.getSimpleVT()))
3020 return Entry->Cost;
3021 }
3022 }
3023
3024 if (isa<FixedVectorType>(ValTy) && ISD == ISD::SETCC) {
3025 auto LT = getTypeLegalizationCost(ValTy);
3026 // Cost the v4f16 FCmp without FP16 support by converting to v4f32 and back.
3027 if (LT.second == MVT::v4f16 && !ST->hasFullFP16())
3028 return LT.first * 4; // fcvtl + fcvtl + fcmp + xtn
3029 }
3030
3031 // Treat the icmp in icmp(and, 0) as free, as we can make use of ands.
3032 // FIXME: This can apply to more conditions and add/sub if it can be shown to
3033 // be profitable.
3034 if (ValTy->isIntegerTy() && ISD == ISD::SETCC && I &&
3035 ICmpInst::isEquality(VecPred) &&
3036 TLI->isTypeLegal(TLI->getValueType(DL, ValTy)) &&
3037 match(I->getOperand(1), m_Zero()) &&
3038 match(I->getOperand(0), m_And(m_Value(), m_Value())))
3039 return 0;
3040
3041 // The base case handles scalable vectors fine for now, since it treats the
3042 // cost as 1 * legalization cost.
3043 return BaseT::getCmpSelInstrCost(Opcode, ValTy, CondTy, VecPred, CostKind, I);
3044 }
3045
3046 AArch64TTIImpl::TTI::MemCmpExpansionOptions
3047 AArch64TTIImpl::enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const {
3048 TTI::MemCmpExpansionOptions Options;
3049 if (ST->requiresStrictAlign()) {
3050 // TODO: Add cost modeling for strict align. Misaligned loads expand to
3051 // a bunch of instructions when strict align is enabled.
3052 return Options;
3053 }
3054 Options.AllowOverlappingLoads = true;
3055 Options.MaxNumLoads = TLI->getMaxExpandSizeMemcmp(OptSize);
3056 Options.NumLoadsPerBlock = Options.MaxNumLoads;
3057 // TODO: Though vector loads usually perform well on AArch64, in some targets
3058 // they may wake up the FP unit, which raises the power consumption. Perhaps
3059 // they could be used with no holds barred (-O3).
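// As a rough sketch only (the exact shape is chosen by the generic
// MemCmpExpansion logic, not here): with the load sizes below and overlapping
// loads allowed, a 15-byte memcmp can be lowered as two overlapping 8-byte
// loads per buffer rather than a call to the library routine.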
3060 Options.LoadSizes = {8, 4, 2, 1}; 3061 Options.AllowedTailExpansions = {3, 5, 6}; 3062 return Options; 3063 } 3064 3065 bool AArch64TTIImpl::prefersVectorizedAddressing() const { 3066 return ST->hasSVE(); 3067 } 3068 3069 InstructionCost 3070 AArch64TTIImpl::getMaskedMemoryOpCost(unsigned Opcode, Type *Src, 3071 Align Alignment, unsigned AddressSpace, 3072 TTI::TargetCostKind CostKind) { 3073 if (useNeonVector(Src)) 3074 return BaseT::getMaskedMemoryOpCost(Opcode, Src, Alignment, AddressSpace, 3075 CostKind); 3076 auto LT = getTypeLegalizationCost(Src); 3077 if (!LT.first.isValid()) 3078 return InstructionCost::getInvalid(); 3079 3080 // The code-generator is currently not able to handle scalable vectors 3081 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3082 // it. This change will be removed when code-generation for these types is 3083 // sufficiently reliable. 3084 if (cast<VectorType>(Src)->getElementCount() == ElementCount::getScalable(1)) 3085 return InstructionCost::getInvalid(); 3086 3087 return LT.first; 3088 } 3089 3090 static unsigned getSVEGatherScatterOverhead(unsigned Opcode) { 3091 return Opcode == Instruction::Load ? SVEGatherOverhead : SVEScatterOverhead; 3092 } 3093 3094 InstructionCost AArch64TTIImpl::getGatherScatterOpCost( 3095 unsigned Opcode, Type *DataTy, const Value *Ptr, bool VariableMask, 3096 Align Alignment, TTI::TargetCostKind CostKind, const Instruction *I) { 3097 if (useNeonVector(DataTy) || !isLegalMaskedGatherScatter(DataTy)) 3098 return BaseT::getGatherScatterOpCost(Opcode, DataTy, Ptr, VariableMask, 3099 Alignment, CostKind, I); 3100 auto *VT = cast<VectorType>(DataTy); 3101 auto LT = getTypeLegalizationCost(DataTy); 3102 if (!LT.first.isValid()) 3103 return InstructionCost::getInvalid(); 3104 3105 if (!LT.second.isVector() || 3106 !isElementTypeLegalForScalableVector(VT->getElementType())) 3107 return InstructionCost::getInvalid(); 3108 3109 // The code-generator is currently not able to handle scalable vectors 3110 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3111 // it. This change will be removed when code-generation for these types is 3112 // sufficiently reliable. 3113 if (cast<VectorType>(DataTy)->getElementCount() == 3114 ElementCount::getScalable(1)) 3115 return InstructionCost::getInvalid(); 3116 3117 ElementCount LegalVF = LT.second.getVectorElementCount(); 3118 InstructionCost MemOpCost = 3119 getMemoryOpCost(Opcode, VT->getElementType(), Alignment, 0, CostKind, 3120 {TTI::OK_AnyValue, TTI::OP_None}, I); 3121 // Add on an overhead cost for using gathers/scatters. 3122 // TODO: At the moment this is applied unilaterally for all CPUs, but at some 3123 // point we may want a per-CPU overhead. 
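// Worked example under stated assumptions (the default sve-gather-overhead of
// 10 and a maximum vscale of 2, both illustrative): for a <vscale x 4 x i32>
// gather whose scalar i32 load costs 1, the value returned below is
// 1 /*LT.first*/ * (1 * 10) /*MemOpCost*/ * 8 /*max elements*/ = 80.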
3124 MemOpCost *= getSVEGatherScatterOverhead(Opcode);
3125 return LT.first * MemOpCost * getMaxNumElements(LegalVF);
3126 }
3127
3128 bool AArch64TTIImpl::useNeonVector(const Type *Ty) const {
3129 return isa<FixedVectorType>(Ty) && !ST->useSVEForFixedLengthVectors();
3130 }
3131
3132 InstructionCost AArch64TTIImpl::getMemoryOpCost(unsigned Opcode, Type *Ty,
3133 MaybeAlign Alignment,
3134 unsigned AddressSpace,
3135 TTI::TargetCostKind CostKind,
3136 TTI::OperandValueInfo OpInfo,
3137 const Instruction *I) {
3138 EVT VT = TLI->getValueType(DL, Ty, true);
3139 // Type legalization can't handle structs.
3140 if (VT == MVT::Other)
3141 return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
3142 CostKind);
3143
3144 auto LT = getTypeLegalizationCost(Ty);
3145 if (!LT.first.isValid())
3146 return InstructionCost::getInvalid();
3147
3148 // The code-generator is currently not able to handle scalable vectors
3149 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting
3150 // it. This change will be removed when code-generation for these types is
3151 // sufficiently reliable.
3152 if (auto *VTy = dyn_cast<ScalableVectorType>(Ty))
3153 if (VTy->getElementCount() == ElementCount::getScalable(1))
3154 return InstructionCost::getInvalid();
3155
3156 // TODO: consider latency as well for TCK_SizeAndLatency.
3157 if (CostKind == TTI::TCK_CodeSize || CostKind == TTI::TCK_SizeAndLatency)
3158 return LT.first;
3159
3160 if (CostKind != TTI::TCK_RecipThroughput)
3161 return 1;
3162
3163 if (ST->isMisaligned128StoreSlow() && Opcode == Instruction::Store &&
3164 LT.second.is128BitVector() && (!Alignment || *Alignment < Align(16))) {
3165 // Unaligned stores are extremely inefficient. We don't split all
3166 // unaligned 128-bit stores because of the negative impact that has been
3167 // shown in practice on inlined block copy code.
3168 // We make such stores expensive so that we will only vectorize if there
3169 // are 6 other instructions getting vectorized.
3170 const int AmortizationCost = 6;
3171
3172 return LT.first * 2 * AmortizationCost;
3173 }
3174
3175 // Opaque ptr or ptr vector types are i64s and can be lowered to STP/LDPs.
3176 if (Ty->isPtrOrPtrVectorTy())
3177 return LT.first;
3178
3179 // Check truncating stores and extending loads.
3180 if (useNeonVector(Ty) &&
3181 Ty->getScalarSizeInBits() != LT.second.getScalarSizeInBits()) {
3182 // v4i8 types are lowered to a scalar load/store and sshll/xtn.
3183 if (VT == MVT::v4i8)
3184 return 2;
3185 // Otherwise we need to scalarize.
3186 return cast<FixedVectorType>(Ty)->getNumElements() * 2;
3187 }
3188
3189 return LT.first;
3190 }
3191
3192 InstructionCost AArch64TTIImpl::getInterleavedMemoryOpCost(
3193 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
3194 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
3195 bool UseMaskForCond, bool UseMaskForGaps) {
3196 assert(Factor >= 2 && "Invalid interleave factor");
3197 auto *VecVTy = cast<VectorType>(VecTy);
3198
3199 if (VecTy->isScalableTy() && (!ST->hasSVE() || Factor != 2))
3200 return InstructionCost::getInvalid();
3201
3202 // Vectorization for masked interleaved accesses is only enabled for scalable
3203 // VF.
3204 if (!VecTy->isScalableTy() && (UseMaskForCond || UseMaskForGaps)) 3205 return InstructionCost::getInvalid(); 3206 3207 if (!UseMaskForGaps && Factor <= TLI->getMaxSupportedInterleaveFactor()) { 3208 unsigned MinElts = VecVTy->getElementCount().getKnownMinValue(); 3209 auto *SubVecTy = 3210 VectorType::get(VecVTy->getElementType(), 3211 VecVTy->getElementCount().divideCoefficientBy(Factor)); 3212 3213 // ldN/stN only support legal vector types of size 64 or 128 in bits. 3214 // Accesses having vector types that are a multiple of 128 bits can be 3215 // matched to more than one ldN/stN instruction. 3216 bool UseScalable; 3217 if (MinElts % Factor == 0 && 3218 TLI->isLegalInterleavedAccessType(SubVecTy, DL, UseScalable)) 3219 return Factor * TLI->getNumInterleavedAccesses(SubVecTy, DL, UseScalable); 3220 } 3221 3222 return BaseT::getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, 3223 Alignment, AddressSpace, CostKind, 3224 UseMaskForCond, UseMaskForGaps); 3225 } 3226 3227 InstructionCost 3228 AArch64TTIImpl::getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) { 3229 InstructionCost Cost = 0; 3230 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 3231 for (auto *I : Tys) { 3232 if (!I->isVectorTy()) 3233 continue; 3234 if (I->getScalarSizeInBits() * cast<FixedVectorType>(I)->getNumElements() == 3235 128) 3236 Cost += getMemoryOpCost(Instruction::Store, I, Align(128), 0, CostKind) + 3237 getMemoryOpCost(Instruction::Load, I, Align(128), 0, CostKind); 3238 } 3239 return Cost; 3240 } 3241 3242 unsigned AArch64TTIImpl::getMaxInterleaveFactor(ElementCount VF) { 3243 return ST->getMaxInterleaveFactor(); 3244 } 3245 3246 // For Falkor, we want to avoid having too many strided loads in a loop since 3247 // that can exhaust the HW prefetcher resources. We adjust the unroller 3248 // MaxCount preference below to attempt to ensure unrolling doesn't create too 3249 // many strided loads. 3250 static void 3251 getFalkorUnrollingPreferences(Loop *L, ScalarEvolution &SE, 3252 TargetTransformInfo::UnrollingPreferences &UP) { 3253 enum { MaxStridedLoads = 7 }; 3254 auto countStridedLoads = [](Loop *L, ScalarEvolution &SE) { 3255 int StridedLoads = 0; 3256 // FIXME? We could make this more precise by looking at the CFG and 3257 // e.g. not counting loads in each side of an if-then-else diamond. 3258 for (const auto BB : L->blocks()) { 3259 for (auto &I : *BB) { 3260 LoadInst *LMemI = dyn_cast<LoadInst>(&I); 3261 if (!LMemI) 3262 continue; 3263 3264 Value *PtrValue = LMemI->getPointerOperand(); 3265 if (L->isLoopInvariant(PtrValue)) 3266 continue; 3267 3268 const SCEV *LSCEV = SE.getSCEV(PtrValue); 3269 const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV); 3270 if (!LSCEVAddRec || !LSCEVAddRec->isAffine()) 3271 continue; 3272 3273 // FIXME? We could take pairing of unrolled load copies into account 3274 // by looking at the AddRec, but we would probably have to limit this 3275 // to loops with no stores or other memory optimization barriers. 3276 ++StridedLoads; 3277 // We've seen enough strided loads that seeing more won't make a 3278 // difference. 3279 if (StridedLoads > MaxStridedLoads / 2) 3280 return StridedLoads; 3281 } 3282 } 3283 return StridedLoads; 3284 }; 3285 3286 int StridedLoads = countStridedLoads(L, SE); 3287 LLVM_DEBUG(dbgs() << "falkor-hwpf: detected " << StridedLoads 3288 << " strided loads\n"); 3289 // Pick the largest power of 2 unroll count that won't result in too many 3290 // strided loads. 
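// Worked example of the clamp below (numbers purely illustrative): with
// MaxStridedLoads = 7, a loop containing 3 strided loads gets
// UP.MaxCount = 1 << Log2_32(7 / 3) = 2, while a loop with a single strided
// load gets 1 << Log2_32(7) = 4.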
3291 if (StridedLoads) { 3292 UP.MaxCount = 1 << Log2_32(MaxStridedLoads / StridedLoads); 3293 LLVM_DEBUG(dbgs() << "falkor-hwpf: setting unroll MaxCount to " 3294 << UP.MaxCount << '\n'); 3295 } 3296 } 3297 3298 void AArch64TTIImpl::getUnrollingPreferences(Loop *L, ScalarEvolution &SE, 3299 TTI::UnrollingPreferences &UP, 3300 OptimizationRemarkEmitter *ORE) { 3301 // Enable partial unrolling and runtime unrolling. 3302 BaseT::getUnrollingPreferences(L, SE, UP, ORE); 3303 3304 UP.UpperBound = true; 3305 3306 // For inner loop, it is more likely to be a hot one, and the runtime check 3307 // can be promoted out from LICM pass, so the overhead is less, let's try 3308 // a larger threshold to unroll more loops. 3309 if (L->getLoopDepth() > 1) 3310 UP.PartialThreshold *= 2; 3311 3312 // Disable partial & runtime unrolling on -Os. 3313 UP.PartialOptSizeThreshold = 0; 3314 3315 if (ST->getProcFamily() == AArch64Subtarget::Falkor && 3316 EnableFalkorHWPFUnrollFix) 3317 getFalkorUnrollingPreferences(L, SE, UP); 3318 3319 // Scan the loop: don't unroll loops with calls as this could prevent 3320 // inlining. Don't unroll vector loops either, as they don't benefit much from 3321 // unrolling. 3322 for (auto *BB : L->getBlocks()) { 3323 for (auto &I : *BB) { 3324 // Don't unroll vectorised loop. 3325 if (I.getType()->isVectorTy()) 3326 return; 3327 3328 if (isa<CallInst>(I) || isa<InvokeInst>(I)) { 3329 if (const Function *F = cast<CallBase>(I).getCalledFunction()) { 3330 if (!isLoweredToCall(F)) 3331 continue; 3332 } 3333 return; 3334 } 3335 } 3336 } 3337 3338 // Enable runtime unrolling for in-order models 3339 // If mcpu is omitted, getProcFamily() returns AArch64Subtarget::Others, so by 3340 // checking for that case, we can ensure that the default behaviour is 3341 // unchanged 3342 if (ST->getProcFamily() != AArch64Subtarget::Others && 3343 !ST->getSchedModel().isOutOfOrder()) { 3344 UP.Runtime = true; 3345 UP.Partial = true; 3346 UP.UnrollRemainder = true; 3347 UP.DefaultUnrollRuntimeCount = 4; 3348 3349 UP.UnrollAndJam = true; 3350 UP.UnrollAndJamInnerLoopThreshold = 60; 3351 } 3352 } 3353 3354 void AArch64TTIImpl::getPeelingPreferences(Loop *L, ScalarEvolution &SE, 3355 TTI::PeelingPreferences &PP) { 3356 BaseT::getPeelingPreferences(L, SE, PP); 3357 } 3358 3359 Value *AArch64TTIImpl::getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, 3360 Type *ExpectedType) { 3361 switch (Inst->getIntrinsicID()) { 3362 default: 3363 return nullptr; 3364 case Intrinsic::aarch64_neon_st2: 3365 case Intrinsic::aarch64_neon_st3: 3366 case Intrinsic::aarch64_neon_st4: { 3367 // Create a struct type 3368 StructType *ST = dyn_cast<StructType>(ExpectedType); 3369 if (!ST) 3370 return nullptr; 3371 unsigned NumElts = Inst->arg_size() - 1; 3372 if (ST->getNumElements() != NumElts) 3373 return nullptr; 3374 for (unsigned i = 0, e = NumElts; i != e; ++i) { 3375 if (Inst->getArgOperand(i)->getType() != ST->getElementType(i)) 3376 return nullptr; 3377 } 3378 Value *Res = PoisonValue::get(ExpectedType); 3379 IRBuilder<> Builder(Inst); 3380 for (unsigned i = 0, e = NumElts; i != e; ++i) { 3381 Value *L = Inst->getArgOperand(i); 3382 Res = Builder.CreateInsertValue(Res, L, i); 3383 } 3384 return Res; 3385 } 3386 case Intrinsic::aarch64_neon_ld2: 3387 case Intrinsic::aarch64_neon_ld3: 3388 case Intrinsic::aarch64_neon_ld4: 3389 if (Inst->getType() == ExpectedType) 3390 return Inst; 3391 return nullptr; 3392 } 3393 } 3394 3395 bool AArch64TTIImpl::getTgtMemIntrinsic(IntrinsicInst *Inst, 3396 MemIntrinsicInfo &Info) { 
3397 switch (Inst->getIntrinsicID()) {
3398 default:
3399 break;
3400 case Intrinsic::aarch64_neon_ld2:
3401 case Intrinsic::aarch64_neon_ld3:
3402 case Intrinsic::aarch64_neon_ld4:
3403 Info.ReadMem = true;
3404 Info.WriteMem = false;
3405 Info.PtrVal = Inst->getArgOperand(0);
3406 break;
3407 case Intrinsic::aarch64_neon_st2:
3408 case Intrinsic::aarch64_neon_st3:
3409 case Intrinsic::aarch64_neon_st4:
3410 Info.ReadMem = false;
3411 Info.WriteMem = true;
3412 Info.PtrVal = Inst->getArgOperand(Inst->arg_size() - 1);
3413 break;
3414 }
3415
3416 switch (Inst->getIntrinsicID()) {
3417 default:
3418 return false;
3419 case Intrinsic::aarch64_neon_ld2:
3420 case Intrinsic::aarch64_neon_st2:
3421 Info.MatchingId = VECTOR_LDST_TWO_ELEMENTS;
3422 break;
3423 case Intrinsic::aarch64_neon_ld3:
3424 case Intrinsic::aarch64_neon_st3:
3425 Info.MatchingId = VECTOR_LDST_THREE_ELEMENTS;
3426 break;
3427 case Intrinsic::aarch64_neon_ld4:
3428 case Intrinsic::aarch64_neon_st4:
3429 Info.MatchingId = VECTOR_LDST_FOUR_ELEMENTS;
3430 break;
3431 }
3432 return true;
3433 }
3434
3435 /// See if \p I should be considered for address type promotion. We check if
3436 /// \p I is a sext with the right type and used in memory accesses. If it is
3437 /// used in a "complex" getelementptr, we allow it to be promoted without
3438 /// finding other sext instructions that sign extended the same initial value.
3439 /// A getelementptr is considered "complex" if it has more than 2 operands.
3440 bool AArch64TTIImpl::shouldConsiderAddressTypePromotion(
3441 const Instruction &I, bool &AllowPromotionWithoutCommonHeader) {
3442 bool Considerable = false;
3443 AllowPromotionWithoutCommonHeader = false;
3444 if (!isa<SExtInst>(&I))
3445 return false;
3446 Type *ConsideredSExtType =
3447 Type::getInt64Ty(I.getParent()->getParent()->getContext());
3448 if (I.getType() != ConsideredSExtType)
3449 return false;
3450 // See if the sext is the one with the right type and used in at least one
3451 // GetElementPtrInst.
3452 for (const User *U : I.users()) {
3453 if (const GetElementPtrInst *GEPInst = dyn_cast<GetElementPtrInst>(U)) {
3454 Considerable = true;
3455 // A getelementptr is considered "complex" if it has more than 2
3456 // operands. We will promote a SExt used in such a complex GEP, as we
3457 // expect some of the computation to be merged if it is done on 64 bits.
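// For illustration (hypothetical IR): a sext whose result feeds
//   getelementptr inbounds [256 x i32], ptr %base, i64 0, i64 %idx
// has three GEP operands and therefore counts as "complex", so the check
// below allows promotion without a common header.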
3458 if (GEPInst->getNumOperands() > 2) { 3459 AllowPromotionWithoutCommonHeader = true; 3460 break; 3461 } 3462 } 3463 } 3464 return Considerable; 3465 } 3466 3467 bool AArch64TTIImpl::isLegalToVectorizeReduction( 3468 const RecurrenceDescriptor &RdxDesc, ElementCount VF) const { 3469 if (!VF.isScalable()) 3470 return true; 3471 3472 Type *Ty = RdxDesc.getRecurrenceType(); 3473 if (Ty->isBFloatTy() || !isElementTypeLegalForScalableVector(Ty)) 3474 return false; 3475 3476 switch (RdxDesc.getRecurrenceKind()) { 3477 case RecurKind::Add: 3478 case RecurKind::FAdd: 3479 case RecurKind::And: 3480 case RecurKind::Or: 3481 case RecurKind::Xor: 3482 case RecurKind::SMin: 3483 case RecurKind::SMax: 3484 case RecurKind::UMin: 3485 case RecurKind::UMax: 3486 case RecurKind::FMin: 3487 case RecurKind::FMax: 3488 case RecurKind::FMulAdd: 3489 case RecurKind::IAnyOf: 3490 case RecurKind::FAnyOf: 3491 return true; 3492 default: 3493 return false; 3494 } 3495 } 3496 3497 InstructionCost 3498 AArch64TTIImpl::getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, 3499 FastMathFlags FMF, 3500 TTI::TargetCostKind CostKind) { 3501 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Ty); 3502 3503 if (LT.second.getScalarType() == MVT::f16 && !ST->hasFullFP16()) 3504 return BaseT::getMinMaxReductionCost(IID, Ty, FMF, CostKind); 3505 3506 InstructionCost LegalizationCost = 0; 3507 if (LT.first > 1) { 3508 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Ty->getContext()); 3509 IntrinsicCostAttributes Attrs(IID, LegalVTy, {LegalVTy, LegalVTy}, FMF); 3510 LegalizationCost = getIntrinsicInstrCost(Attrs, CostKind) * (LT.first - 1); 3511 } 3512 3513 return LegalizationCost + /*Cost of horizontal reduction*/ 2; 3514 } 3515 3516 InstructionCost AArch64TTIImpl::getArithmeticReductionCostSVE( 3517 unsigned Opcode, VectorType *ValTy, TTI::TargetCostKind CostKind) { 3518 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy); 3519 InstructionCost LegalizationCost = 0; 3520 if (LT.first > 1) { 3521 Type *LegalVTy = EVT(LT.second).getTypeForEVT(ValTy->getContext()); 3522 LegalizationCost = getArithmeticInstrCost(Opcode, LegalVTy, CostKind); 3523 LegalizationCost *= LT.first - 1; 3524 } 3525 3526 int ISD = TLI->InstructionOpcodeToISD(Opcode); 3527 assert(ISD && "Invalid opcode"); 3528 // Add the final reduction cost for the legal horizontal reduction 3529 switch (ISD) { 3530 case ISD::ADD: 3531 case ISD::AND: 3532 case ISD::OR: 3533 case ISD::XOR: 3534 case ISD::FADD: 3535 return LegalizationCost + 2; 3536 default: 3537 return InstructionCost::getInvalid(); 3538 } 3539 } 3540 3541 InstructionCost 3542 AArch64TTIImpl::getArithmeticReductionCost(unsigned Opcode, VectorType *ValTy, 3543 std::optional<FastMathFlags> FMF, 3544 TTI::TargetCostKind CostKind) { 3545 if (TTI::requiresOrderedReduction(FMF)) { 3546 if (auto *FixedVTy = dyn_cast<FixedVectorType>(ValTy)) { 3547 InstructionCost BaseCost = 3548 BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind); 3549 // Add on extra cost to reflect the extra overhead on some CPUs. We still 3550 // end up vectorizing for more computationally intensive loops. 
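// For example (the numbers merely illustrate the formula below): an ordered
// fadd reduction over a <4 x float> is charged the generic strict-reduction
// cost plus 4, i.e. one extra unit per element.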
3551 return BaseCost + FixedVTy->getNumElements(); 3552 } 3553 3554 if (Opcode != Instruction::FAdd) 3555 return InstructionCost::getInvalid(); 3556 3557 auto *VTy = cast<ScalableVectorType>(ValTy); 3558 InstructionCost Cost = 3559 getArithmeticInstrCost(Opcode, VTy->getScalarType(), CostKind); 3560 Cost *= getMaxNumElements(VTy->getElementCount()); 3561 return Cost; 3562 } 3563 3564 if (isa<ScalableVectorType>(ValTy)) 3565 return getArithmeticReductionCostSVE(Opcode, ValTy, CostKind); 3566 3567 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(ValTy); 3568 MVT MTy = LT.second; 3569 int ISD = TLI->InstructionOpcodeToISD(Opcode); 3570 assert(ISD && "Invalid opcode"); 3571 3572 // Horizontal adds can use the 'addv' instruction. We model the cost of these 3573 // instructions as twice a normal vector add, plus 1 for each legalization 3574 // step (LT.first). This is the only arithmetic vector reduction operation for 3575 // which we have an instruction. 3576 // OR, XOR and AND costs should match the codegen from: 3577 // OR: llvm/test/CodeGen/AArch64/reduce-or.ll 3578 // XOR: llvm/test/CodeGen/AArch64/reduce-xor.ll 3579 // AND: llvm/test/CodeGen/AArch64/reduce-and.ll 3580 static const CostTblEntry CostTblNoPairwise[]{ 3581 {ISD::ADD, MVT::v8i8, 2}, 3582 {ISD::ADD, MVT::v16i8, 2}, 3583 {ISD::ADD, MVT::v4i16, 2}, 3584 {ISD::ADD, MVT::v8i16, 2}, 3585 {ISD::ADD, MVT::v4i32, 2}, 3586 {ISD::ADD, MVT::v2i64, 2}, 3587 {ISD::OR, MVT::v8i8, 15}, 3588 {ISD::OR, MVT::v16i8, 17}, 3589 {ISD::OR, MVT::v4i16, 7}, 3590 {ISD::OR, MVT::v8i16, 9}, 3591 {ISD::OR, MVT::v2i32, 3}, 3592 {ISD::OR, MVT::v4i32, 5}, 3593 {ISD::OR, MVT::v2i64, 3}, 3594 {ISD::XOR, MVT::v8i8, 15}, 3595 {ISD::XOR, MVT::v16i8, 17}, 3596 {ISD::XOR, MVT::v4i16, 7}, 3597 {ISD::XOR, MVT::v8i16, 9}, 3598 {ISD::XOR, MVT::v2i32, 3}, 3599 {ISD::XOR, MVT::v4i32, 5}, 3600 {ISD::XOR, MVT::v2i64, 3}, 3601 {ISD::AND, MVT::v8i8, 15}, 3602 {ISD::AND, MVT::v16i8, 17}, 3603 {ISD::AND, MVT::v4i16, 7}, 3604 {ISD::AND, MVT::v8i16, 9}, 3605 {ISD::AND, MVT::v2i32, 3}, 3606 {ISD::AND, MVT::v4i32, 5}, 3607 {ISD::AND, MVT::v2i64, 3}, 3608 }; 3609 switch (ISD) { 3610 default: 3611 break; 3612 case ISD::ADD: 3613 if (const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy)) 3614 return (LT.first - 1) + Entry->Cost; 3615 break; 3616 case ISD::XOR: 3617 case ISD::AND: 3618 case ISD::OR: 3619 const auto *Entry = CostTableLookup(CostTblNoPairwise, ISD, MTy); 3620 if (!Entry) 3621 break; 3622 auto *ValVTy = cast<FixedVectorType>(ValTy); 3623 if (MTy.getVectorNumElements() <= ValVTy->getNumElements() && 3624 isPowerOf2_32(ValVTy->getNumElements())) { 3625 InstructionCost ExtraCost = 0; 3626 if (LT.first != 1) { 3627 // Type needs to be split, so there is an extra cost of LT.first - 1 3628 // arithmetic ops. 3629 auto *Ty = FixedVectorType::get(ValTy->getElementType(), 3630 MTy.getVectorNumElements()); 3631 ExtraCost = getArithmeticInstrCost(Opcode, Ty, CostKind); 3632 ExtraCost *= LT.first - 1; 3633 } 3634 // All and/or/xor of i1 will be lowered with maxv/minv/addv + fmov 3635 auto Cost = ValVTy->getElementType()->isIntegerTy(1) ? 
2 : Entry->Cost; 3636 return Cost + ExtraCost; 3637 } 3638 break; 3639 } 3640 return BaseT::getArithmeticReductionCost(Opcode, ValTy, FMF, CostKind); 3641 } 3642 3643 InstructionCost AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index) { 3644 static const CostTblEntry ShuffleTbl[] = { 3645 { TTI::SK_Splice, MVT::nxv16i8, 1 }, 3646 { TTI::SK_Splice, MVT::nxv8i16, 1 }, 3647 { TTI::SK_Splice, MVT::nxv4i32, 1 }, 3648 { TTI::SK_Splice, MVT::nxv2i64, 1 }, 3649 { TTI::SK_Splice, MVT::nxv2f16, 1 }, 3650 { TTI::SK_Splice, MVT::nxv4f16, 1 }, 3651 { TTI::SK_Splice, MVT::nxv8f16, 1 }, 3652 { TTI::SK_Splice, MVT::nxv2bf16, 1 }, 3653 { TTI::SK_Splice, MVT::nxv4bf16, 1 }, 3654 { TTI::SK_Splice, MVT::nxv8bf16, 1 }, 3655 { TTI::SK_Splice, MVT::nxv2f32, 1 }, 3656 { TTI::SK_Splice, MVT::nxv4f32, 1 }, 3657 { TTI::SK_Splice, MVT::nxv2f64, 1 }, 3658 }; 3659 3660 // The code-generator is currently not able to handle scalable vectors 3661 // of <vscale x 1 x eltty> yet, so return an invalid cost to avoid selecting 3662 // it. This change will be removed when code-generation for these types is 3663 // sufficiently reliable. 3664 if (Tp->getElementCount() == ElementCount::getScalable(1)) 3665 return InstructionCost::getInvalid(); 3666 3667 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp); 3668 Type *LegalVTy = EVT(LT.second).getTypeForEVT(Tp->getContext()); 3669 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput; 3670 EVT PromotedVT = LT.second.getScalarType() == MVT::i1 3671 ? TLI->getPromotedVTForPredicate(EVT(LT.second)) 3672 : LT.second; 3673 Type *PromotedVTy = EVT(PromotedVT).getTypeForEVT(Tp->getContext()); 3674 InstructionCost LegalizationCost = 0; 3675 if (Index < 0) { 3676 LegalizationCost = 3677 getCmpSelInstrCost(Instruction::ICmp, PromotedVTy, PromotedVTy, 3678 CmpInst::BAD_ICMP_PREDICATE, CostKind) + 3679 getCmpSelInstrCost(Instruction::Select, PromotedVTy, LegalVTy, 3680 CmpInst::BAD_ICMP_PREDICATE, CostKind); 3681 } 3682 3683 // Predicated splice are promoted when lowering. See AArch64ISelLowering.cpp 3684 // Cost performed on a promoted type. 3685 if (LT.second.getScalarType() == MVT::i1) { 3686 LegalizationCost += 3687 getCastInstrCost(Instruction::ZExt, PromotedVTy, LegalVTy, 3688 TTI::CastContextHint::None, CostKind) + 3689 getCastInstrCost(Instruction::Trunc, LegalVTy, PromotedVTy, 3690 TTI::CastContextHint::None, CostKind); 3691 } 3692 const auto *Entry = 3693 CostTableLookup(ShuffleTbl, TTI::SK_Splice, PromotedVT.getSimpleVT()); 3694 assert(Entry && "Illegal Type for Splice"); 3695 LegalizationCost += Entry->Cost; 3696 return LegalizationCost * LT.first; 3697 } 3698 3699 InstructionCost AArch64TTIImpl::getShuffleCost(TTI::ShuffleKind Kind, 3700 VectorType *Tp, 3701 ArrayRef<int> Mask, 3702 TTI::TargetCostKind CostKind, 3703 int Index, VectorType *SubTp, 3704 ArrayRef<const Value *> Args) { 3705 std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp); 3706 // If we have a Mask, and the LT is being legalized somehow, split the Mask 3707 // into smaller vectors and sum the cost of each shuffle. 
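// For instance (an illustrative case): a <16 x i32> shuffle legalizes to four
// v4i32 registers, so a 16-element mask is cut into four 4-element sub-masks
// and each piece is costed separately, either via a recursive getShuffleCost
// call or a worst-case estimate.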
3708 if (!Mask.empty() && isa<FixedVectorType>(Tp) && LT.second.isVector() && 3709 Tp->getScalarSizeInBits() == LT.second.getScalarSizeInBits() && 3710 Mask.size() > LT.second.getVectorNumElements() && !Index && !SubTp) { 3711 unsigned TpNumElts = Mask.size(); 3712 unsigned LTNumElts = LT.second.getVectorNumElements(); 3713 unsigned NumVecs = (TpNumElts + LTNumElts - 1) / LTNumElts; 3714 VectorType *NTp = 3715 VectorType::get(Tp->getScalarType(), LT.second.getVectorElementCount()); 3716 InstructionCost Cost; 3717 for (unsigned N = 0; N < NumVecs; N++) { 3718 SmallVector<int> NMask; 3719 // Split the existing mask into chunks of size LTNumElts. Track the source 3720 // sub-vectors to ensure the result has at most 2 inputs. 3721 unsigned Source1, Source2; 3722 unsigned NumSources = 0; 3723 for (unsigned E = 0; E < LTNumElts; E++) { 3724 int MaskElt = (N * LTNumElts + E < TpNumElts) ? Mask[N * LTNumElts + E] 3725 : PoisonMaskElem; 3726 if (MaskElt < 0) { 3727 NMask.push_back(PoisonMaskElem); 3728 continue; 3729 } 3730 3731 // Calculate which source from the input this comes from and whether it 3732 // is new to us. 3733 unsigned Source = MaskElt / LTNumElts; 3734 if (NumSources == 0) { 3735 Source1 = Source; 3736 NumSources = 1; 3737 } else if (NumSources == 1 && Source != Source1) { 3738 Source2 = Source; 3739 NumSources = 2; 3740 } else if (NumSources >= 2 && Source != Source1 && Source != Source2) { 3741 NumSources++; 3742 } 3743 3744 // Add to the new mask. For the NumSources>2 case these are not correct, 3745 // but are only used for the modular lane number. 3746 if (Source == Source1) 3747 NMask.push_back(MaskElt % LTNumElts); 3748 else if (Source == Source2) 3749 NMask.push_back(MaskElt % LTNumElts + LTNumElts); 3750 else 3751 NMask.push_back(MaskElt % LTNumElts); 3752 } 3753 // If the sub-mask has at most 2 input sub-vectors then re-cost it using 3754 // getShuffleCost. If not then cost it using the worst case. 3755 if (NumSources <= 2) 3756 Cost += getShuffleCost(NumSources <= 1 ? TTI::SK_PermuteSingleSrc 3757 : TTI::SK_PermuteTwoSrc, 3758 NTp, NMask, CostKind, 0, nullptr, Args); 3759 else if (any_of(enumerate(NMask), [&](const auto &ME) { 3760 return ME.value() % LTNumElts == ME.index(); 3761 })) 3762 Cost += LTNumElts - 1; 3763 else 3764 Cost += LTNumElts; 3765 } 3766 return Cost; 3767 } 3768 3769 Kind = improveShuffleKindFromMask(Kind, Mask, Tp, Index, SubTp); 3770 3771 // Check for broadcast loads, which are supported by the LD1R instruction. 3772 // In terms of code-size, the shuffle vector is free when a load + dup get 3773 // folded into a LD1R. That's what we check and return here. For performance 3774 // and reciprocal throughput, a LD1R is not completely free. In this case, we 3775 // return the cost for the broadcast below (i.e. 1 for most/all types), so 3776 // that we model the load + dup sequence slightly higher because LD1R is a 3777 // high latency instruction. 3778 if (CostKind == TTI::TCK_CodeSize && Kind == TTI::SK_Broadcast) { 3779 bool IsLoad = !Args.empty() && isa<LoadInst>(Args[0]); 3780 if (IsLoad && LT.second.isVector() && 3781 isLegalBroadcastLoad(Tp->getElementType(), 3782 LT.second.getVectorElementCount())) 3783 return 0; 3784 } 3785 3786 // If we have 4 elements for the shuffle and a Mask, get the cost straight 3787 // from the perfect shuffle tables. 
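// For example (illustrative only): a single-source reverse of a v4i32 with
// mask <3, 2, 1, 0> satisfies the conditions below and takes its cost
// directly from the perfect shuffle tables rather than the generic lookup
// further down.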
3788 if (Mask.size() == 4 && Tp->getElementCount() == ElementCount::getFixed(4) && 3789 (Tp->getScalarSizeInBits() == 16 || Tp->getScalarSizeInBits() == 32) && 3790 all_of(Mask, [](int E) { return E < 8; })) 3791 return getPerfectShuffleCost(Mask); 3792 3793 if (Kind == TTI::SK_Broadcast || Kind == TTI::SK_Transpose || 3794 Kind == TTI::SK_Select || Kind == TTI::SK_PermuteSingleSrc || 3795 Kind == TTI::SK_Reverse || Kind == TTI::SK_Splice) { 3796 static const CostTblEntry ShuffleTbl[] = { 3797 // Broadcast shuffle kinds can be performed with 'dup'. 3798 {TTI::SK_Broadcast, MVT::v8i8, 1}, 3799 {TTI::SK_Broadcast, MVT::v16i8, 1}, 3800 {TTI::SK_Broadcast, MVT::v4i16, 1}, 3801 {TTI::SK_Broadcast, MVT::v8i16, 1}, 3802 {TTI::SK_Broadcast, MVT::v2i32, 1}, 3803 {TTI::SK_Broadcast, MVT::v4i32, 1}, 3804 {TTI::SK_Broadcast, MVT::v2i64, 1}, 3805 {TTI::SK_Broadcast, MVT::v4f16, 1}, 3806 {TTI::SK_Broadcast, MVT::v8f16, 1}, 3807 {TTI::SK_Broadcast, MVT::v2f32, 1}, 3808 {TTI::SK_Broadcast, MVT::v4f32, 1}, 3809 {TTI::SK_Broadcast, MVT::v2f64, 1}, 3810 // Transpose shuffle kinds can be performed with 'trn1/trn2' and 3811 // 'zip1/zip2' instructions. 3812 {TTI::SK_Transpose, MVT::v8i8, 1}, 3813 {TTI::SK_Transpose, MVT::v16i8, 1}, 3814 {TTI::SK_Transpose, MVT::v4i16, 1}, 3815 {TTI::SK_Transpose, MVT::v8i16, 1}, 3816 {TTI::SK_Transpose, MVT::v2i32, 1}, 3817 {TTI::SK_Transpose, MVT::v4i32, 1}, 3818 {TTI::SK_Transpose, MVT::v2i64, 1}, 3819 {TTI::SK_Transpose, MVT::v4f16, 1}, 3820 {TTI::SK_Transpose, MVT::v8f16, 1}, 3821 {TTI::SK_Transpose, MVT::v2f32, 1}, 3822 {TTI::SK_Transpose, MVT::v4f32, 1}, 3823 {TTI::SK_Transpose, MVT::v2f64, 1}, 3824 // Select shuffle kinds. 3825 // TODO: handle vXi8/vXi16. 3826 {TTI::SK_Select, MVT::v2i32, 1}, // mov. 3827 {TTI::SK_Select, MVT::v4i32, 2}, // rev+trn (or similar). 3828 {TTI::SK_Select, MVT::v2i64, 1}, // mov. 3829 {TTI::SK_Select, MVT::v2f32, 1}, // mov. 3830 {TTI::SK_Select, MVT::v4f32, 2}, // rev+trn (or similar). 3831 {TTI::SK_Select, MVT::v2f64, 1}, // mov. 3832 // PermuteSingleSrc shuffle kinds. 3833 {TTI::SK_PermuteSingleSrc, MVT::v2i32, 1}, // mov. 3834 {TTI::SK_PermuteSingleSrc, MVT::v4i32, 3}, // perfectshuffle worst case. 3835 {TTI::SK_PermuteSingleSrc, MVT::v2i64, 1}, // mov. 3836 {TTI::SK_PermuteSingleSrc, MVT::v2f32, 1}, // mov. 3837 {TTI::SK_PermuteSingleSrc, MVT::v4f32, 3}, // perfectshuffle worst case. 3838 {TTI::SK_PermuteSingleSrc, MVT::v2f64, 1}, // mov. 3839 {TTI::SK_PermuteSingleSrc, MVT::v4i16, 3}, // perfectshuffle worst case. 3840 {TTI::SK_PermuteSingleSrc, MVT::v4f16, 3}, // perfectshuffle worst case. 3841 {TTI::SK_PermuteSingleSrc, MVT::v4bf16, 3}, // same 3842 {TTI::SK_PermuteSingleSrc, MVT::v8i16, 8}, // constpool + load + tbl 3843 {TTI::SK_PermuteSingleSrc, MVT::v8f16, 8}, // constpool + load + tbl 3844 {TTI::SK_PermuteSingleSrc, MVT::v8bf16, 8}, // constpool + load + tbl 3845 {TTI::SK_PermuteSingleSrc, MVT::v8i8, 8}, // constpool + load + tbl 3846 {TTI::SK_PermuteSingleSrc, MVT::v16i8, 8}, // constpool + load + tbl 3847 // Reverse can be lowered with `rev`. 
3848 {TTI::SK_Reverse, MVT::v2i32, 1}, // REV64 3849 {TTI::SK_Reverse, MVT::v4i32, 2}, // REV64; EXT 3850 {TTI::SK_Reverse, MVT::v2i64, 1}, // EXT 3851 {TTI::SK_Reverse, MVT::v2f32, 1}, // REV64 3852 {TTI::SK_Reverse, MVT::v4f32, 2}, // REV64; EXT 3853 {TTI::SK_Reverse, MVT::v2f64, 1}, // EXT 3854 {TTI::SK_Reverse, MVT::v8f16, 2}, // REV64; EXT 3855 {TTI::SK_Reverse, MVT::v8i16, 2}, // REV64; EXT 3856 {TTI::SK_Reverse, MVT::v16i8, 2}, // REV64; EXT 3857 {TTI::SK_Reverse, MVT::v4f16, 1}, // REV64 3858 {TTI::SK_Reverse, MVT::v4i16, 1}, // REV64 3859 {TTI::SK_Reverse, MVT::v8i8, 1}, // REV64 3860 // Splice can all be lowered as `ext`. 3861 {TTI::SK_Splice, MVT::v2i32, 1}, 3862 {TTI::SK_Splice, MVT::v4i32, 1}, 3863 {TTI::SK_Splice, MVT::v2i64, 1}, 3864 {TTI::SK_Splice, MVT::v2f32, 1}, 3865 {TTI::SK_Splice, MVT::v4f32, 1}, 3866 {TTI::SK_Splice, MVT::v2f64, 1}, 3867 {TTI::SK_Splice, MVT::v8f16, 1}, 3868 {TTI::SK_Splice, MVT::v8bf16, 1}, 3869 {TTI::SK_Splice, MVT::v8i16, 1}, 3870 {TTI::SK_Splice, MVT::v16i8, 1}, 3871 {TTI::SK_Splice, MVT::v4bf16, 1}, 3872 {TTI::SK_Splice, MVT::v4f16, 1}, 3873 {TTI::SK_Splice, MVT::v4i16, 1}, 3874 {TTI::SK_Splice, MVT::v8i8, 1}, 3875 // Broadcast shuffle kinds for scalable vectors 3876 {TTI::SK_Broadcast, MVT::nxv16i8, 1}, 3877 {TTI::SK_Broadcast, MVT::nxv8i16, 1}, 3878 {TTI::SK_Broadcast, MVT::nxv4i32, 1}, 3879 {TTI::SK_Broadcast, MVT::nxv2i64, 1}, 3880 {TTI::SK_Broadcast, MVT::nxv2f16, 1}, 3881 {TTI::SK_Broadcast, MVT::nxv4f16, 1}, 3882 {TTI::SK_Broadcast, MVT::nxv8f16, 1}, 3883 {TTI::SK_Broadcast, MVT::nxv2bf16, 1}, 3884 {TTI::SK_Broadcast, MVT::nxv4bf16, 1}, 3885 {TTI::SK_Broadcast, MVT::nxv8bf16, 1}, 3886 {TTI::SK_Broadcast, MVT::nxv2f32, 1}, 3887 {TTI::SK_Broadcast, MVT::nxv4f32, 1}, 3888 {TTI::SK_Broadcast, MVT::nxv2f64, 1}, 3889 {TTI::SK_Broadcast, MVT::nxv16i1, 1}, 3890 {TTI::SK_Broadcast, MVT::nxv8i1, 1}, 3891 {TTI::SK_Broadcast, MVT::nxv4i1, 1}, 3892 {TTI::SK_Broadcast, MVT::nxv2i1, 1}, 3893 // Handle the cases for vector.reverse with scalable vectors 3894 {TTI::SK_Reverse, MVT::nxv16i8, 1}, 3895 {TTI::SK_Reverse, MVT::nxv8i16, 1}, 3896 {TTI::SK_Reverse, MVT::nxv4i32, 1}, 3897 {TTI::SK_Reverse, MVT::nxv2i64, 1}, 3898 {TTI::SK_Reverse, MVT::nxv2f16, 1}, 3899 {TTI::SK_Reverse, MVT::nxv4f16, 1}, 3900 {TTI::SK_Reverse, MVT::nxv8f16, 1}, 3901 {TTI::SK_Reverse, MVT::nxv2bf16, 1}, 3902 {TTI::SK_Reverse, MVT::nxv4bf16, 1}, 3903 {TTI::SK_Reverse, MVT::nxv8bf16, 1}, 3904 {TTI::SK_Reverse, MVT::nxv2f32, 1}, 3905 {TTI::SK_Reverse, MVT::nxv4f32, 1}, 3906 {TTI::SK_Reverse, MVT::nxv2f64, 1}, 3907 {TTI::SK_Reverse, MVT::nxv16i1, 1}, 3908 {TTI::SK_Reverse, MVT::nxv8i1, 1}, 3909 {TTI::SK_Reverse, MVT::nxv4i1, 1}, 3910 {TTI::SK_Reverse, MVT::nxv2i1, 1}, 3911 }; 3912 if (const auto *Entry = CostTableLookup(ShuffleTbl, Kind, LT.second)) 3913 return LT.first * Entry->Cost; 3914 } 3915 3916 if (Kind == TTI::SK_Splice && isa<ScalableVectorType>(Tp)) 3917 return getSpliceCost(Tp, Index); 3918 3919 // Inserting a subvector can often be done with either a D, S or H register 3920 // move, so long as the inserted vector is "aligned". 
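// For example (an illustrative case): inserting a v2f32 subvector into a
// v4f32 at element 0 or 2 passes the alignment check below and is costed as a
// simple register move (SubLT.first), whereas an insertion at element 1 falls
// through to the generic costing.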
3921 if (Kind == TTI::SK_InsertSubvector && LT.second.isFixedLengthVector() && 3922 LT.second.getSizeInBits() <= 128 && SubTp) { 3923 std::pair<InstructionCost, MVT> SubLT = getTypeLegalizationCost(SubTp); 3924 if (SubLT.second.isVector()) { 3925 int NumElts = LT.second.getVectorNumElements(); 3926 int NumSubElts = SubLT.second.getVectorNumElements(); 3927 if ((Index % NumSubElts) == 0 && (NumElts % NumSubElts) == 0) 3928 return SubLT.first; 3929 } 3930 } 3931 3932 return BaseT::getShuffleCost(Kind, Tp, Mask, CostKind, Index, SubTp); 3933 } 3934 3935 static bool containsDecreasingPointers(Loop *TheLoop, 3936 PredicatedScalarEvolution *PSE) { 3937 const auto &Strides = DenseMap<Value *, const SCEV *>(); 3938 for (BasicBlock *BB : TheLoop->blocks()) { 3939 // Scan the instructions in the block and look for addresses that are 3940 // consecutive and decreasing. 3941 for (Instruction &I : *BB) { 3942 if (isa<LoadInst>(&I) || isa<StoreInst>(&I)) { 3943 Value *Ptr = getLoadStorePointerOperand(&I); 3944 Type *AccessTy = getLoadStoreType(&I); 3945 if (getPtrStride(*PSE, AccessTy, Ptr, TheLoop, Strides, /*Assume=*/true, 3946 /*ShouldCheckWrap=*/false) 3947 .value_or(0) < 0) 3948 return true; 3949 } 3950 } 3951 } 3952 return false; 3953 } 3954 3955 bool AArch64TTIImpl::preferPredicateOverEpilogue(TailFoldingInfo *TFI) { 3956 if (!ST->hasSVE()) 3957 return false; 3958 3959 // We don't currently support vectorisation with interleaving for SVE - with 3960 // such loops we're better off not using tail-folding. This gives us a chance 3961 // to fall back on fixed-width vectorisation using NEON's ld2/st2/etc. 3962 if (TFI->IAI->hasGroups()) 3963 return false; 3964 3965 TailFoldingOpts Required = TailFoldingOpts::Disabled; 3966 if (TFI->LVL->getReductionVars().size()) 3967 Required |= TailFoldingOpts::Reductions; 3968 if (TFI->LVL->getFixedOrderRecurrences().size()) 3969 Required |= TailFoldingOpts::Recurrences; 3970 3971 // We call this to discover whether any load/store pointers in the loop have 3972 // negative strides. This will require extra work to reverse the loop 3973 // predicate, which may be expensive. 3974 if (containsDecreasingPointers(TFI->LVL->getLoop(), 3975 TFI->LVL->getPredicatedScalarEvolution())) 3976 Required |= TailFoldingOpts::Reverse; 3977 if (Required == TailFoldingOpts::Disabled) 3978 Required |= TailFoldingOpts::Simple; 3979 3980 if (!TailFoldingOptionLoc.satisfies(ST->getSVETailFoldingDefaultOpts(), 3981 Required)) 3982 return false; 3983 3984 // Don't tail-fold for tight loops where we would be better off interleaving 3985 // with an unpredicated loop. 3986 unsigned NumInsns = 0; 3987 for (BasicBlock *BB : TFI->LVL->getLoop()->blocks()) { 3988 NumInsns += BB->sizeWithoutDebug(); 3989 } 3990 3991 // We expect 4 of these to be a IV PHI, IV add, IV compare and branch. 3992 return NumInsns >= SVETailFoldInsnThreshold; 3993 } 3994 3995 InstructionCost 3996 AArch64TTIImpl::getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, 3997 int64_t BaseOffset, bool HasBaseReg, 3998 int64_t Scale, unsigned AddrSpace) const { 3999 // Scaling factors are not free at all. 
4000 // Operands | Rt Latency 4001 // ------------------------------------------- 4002 // Rt, [Xn, Xm] | 4 4003 // ------------------------------------------- 4004 // Rt, [Xn, Xm, lsl #imm] | Rn: 4 Rm: 5 4005 // Rt, [Xn, Wm, <extend> #imm] | 4006 TargetLoweringBase::AddrMode AM; 4007 AM.BaseGV = BaseGV; 4008 AM.BaseOffs = BaseOffset; 4009 AM.HasBaseReg = HasBaseReg; 4010 AM.Scale = Scale; 4011 if (getTLI()->isLegalAddressingMode(DL, AM, Ty, AddrSpace)) 4012 // Scale represents reg2 * scale, thus account for 1 if 4013 // it is not equal to 0 or 1. 4014 return AM.Scale != 0 && AM.Scale != 1; 4015 return -1; 4016 } 4017
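// For example (the values follow directly from the logic above): a legal
// addressing mode of the form [Xn, Xm, lsl #2] (Scale == 4) is charged 1, a
// plain [Xn] or [Xn, Xm] (Scale == 0 or 1) is charged 0, and an illegal
// combination is reported as -1.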