1 //===- AArch64TargetTransformInfo.h - AArch64 specific TTI ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// \file 9 /// This file a TargetTransformInfo::Concept conforming object specific to the 10 /// AArch64 target machine. It uses the target's detailed information to 11 /// provide more precise answers to certain TTI queries, while letting the 12 /// target independent and default TTI implementations handle the rest. 13 /// 14 //===----------------------------------------------------------------------===// 15 16 #ifndef LLVM_LIB_TARGET_AARCH64_AARCH64TARGETTRANSFORMINFO_H 17 #define LLVM_LIB_TARGET_AARCH64_AARCH64TARGETTRANSFORMINFO_H 18 19 #include "AArch64.h" 20 #include "AArch64Subtarget.h" 21 #include "AArch64TargetMachine.h" 22 #include "llvm/ADT/ArrayRef.h" 23 #include "llvm/Analysis/TargetTransformInfo.h" 24 #include "llvm/CodeGen/BasicTTIImpl.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/IR/Intrinsics.h" 27 #include <cstdint> 28 #include <optional> 29 30 namespace llvm { 31 32 class APInt; 33 class Instruction; 34 class IntrinsicInst; 35 class Loop; 36 class SCEV; 37 class ScalarEvolution; 38 class Type; 39 class Value; 40 class VectorType; 41 42 class AArch64TTIImpl : public BasicTTIImplBase<AArch64TTIImpl> { 43 using BaseT = BasicTTIImplBase<AArch64TTIImpl>; 44 using TTI = TargetTransformInfo; 45 46 friend BaseT; 47 48 const AArch64Subtarget *ST; 49 const AArch64TargetLowering *TLI; 50 getST()51 const AArch64Subtarget *getST() const { return ST; } getTLI()52 const AArch64TargetLowering *getTLI() const { return TLI; } 53 54 enum MemIntrinsicType { 55 VECTOR_LDST_TWO_ELEMENTS, 56 VECTOR_LDST_THREE_ELEMENTS, 57 VECTOR_LDST_FOUR_ELEMENTS 58 }; 59 60 bool isWideningInstruction(Type *DstTy, unsigned Opcode, 61 ArrayRef<const Value *> Args, 62 Type *SrcOverrideTy = nullptr); 63 64 // A helper function called by 'getVectorInstrCost'. 65 // 66 // 'Val' and 'Index' are forwarded from 'getVectorInstrCost'; 'HasRealUse' 67 // indicates whether the vector instruction is available in the input IR or 68 // just imaginary in vectorizer passes. 69 InstructionCost getVectorInstrCostHelper(const Instruction *I, Type *Val, 70 unsigned Index, bool HasRealUse); 71 72 public: AArch64TTIImpl(const AArch64TargetMachine * TM,const Function & F)73 explicit AArch64TTIImpl(const AArch64TargetMachine *TM, const Function &F) 74 : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)), 75 TLI(ST->getTargetLowering()) {} 76 77 bool areInlineCompatible(const Function *Caller, 78 const Function *Callee) const; 79 80 bool areTypesABICompatible(const Function *Caller, const Function *Callee, 81 const ArrayRef<Type *> &Types) const; 82 83 unsigned getInlineCallPenalty(const Function *F, const CallBase &Call, 84 unsigned DefaultCallPenalty) const; 85 86 /// \name Scalar TTI Implementations 87 /// @{ 88 89 using BaseT::getIntImmCost; 90 InstructionCost getIntImmCost(int64_t Val); 91 InstructionCost getIntImmCost(const APInt &Imm, Type *Ty, 92 TTI::TargetCostKind CostKind); 93 InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx, 94 const APInt &Imm, Type *Ty, 95 TTI::TargetCostKind CostKind, 96 Instruction *Inst = nullptr); 97 InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, 98 const APInt &Imm, Type *Ty, 99 TTI::TargetCostKind CostKind); 100 TTI::PopcntSupportKind getPopcntSupport(unsigned TyWidth); 101 102 /// @} 103 104 /// \name Vector TTI Implementations 105 /// @{ 106 enableInterleavedAccessVectorization()107 bool enableInterleavedAccessVectorization() { return true; } 108 enableMaskedInterleavedAccessVectorization()109 bool enableMaskedInterleavedAccessVectorization() { return ST->hasSVE(); } 110 getNumberOfRegisters(unsigned ClassID)111 unsigned getNumberOfRegisters(unsigned ClassID) const { 112 bool Vector = (ClassID == 1); 113 if (Vector) { 114 if (ST->hasNEON()) 115 return 32; 116 return 0; 117 } 118 return 31; 119 } 120 121 InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA, 122 TTI::TargetCostKind CostKind); 123 124 std::optional<Instruction *> instCombineIntrinsic(InstCombiner &IC, 125 IntrinsicInst &II) const; 126 127 std::optional<Value *> simplifyDemandedVectorEltsIntrinsic( 128 InstCombiner &IC, IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, 129 APInt &UndefElts2, APInt &UndefElts3, 130 std::function<void(Instruction *, unsigned, APInt, APInt &)> 131 SimplifyAndSetOp) const; 132 133 TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const; 134 getMinVectorRegisterBitWidth()135 unsigned getMinVectorRegisterBitWidth() const { 136 return ST->getMinVectorRegisterBitWidth(); 137 } 138 getVScaleForTuning()139 std::optional<unsigned> getVScaleForTuning() const { 140 return ST->getVScaleForTuning(); 141 } 142 isVScaleKnownToBeAPowerOfTwo()143 bool isVScaleKnownToBeAPowerOfTwo() const { return true; } 144 145 bool shouldMaximizeVectorBandwidth(TargetTransformInfo::RegisterKind K) const; 146 147 /// Try to return an estimate cost factor that can be used as a multiplier 148 /// when scalarizing an operation for a vector with ElementCount \p VF. 149 /// For scalable vectors this currently takes the most pessimistic view based 150 /// upon the maximum possible value for vscale. getMaxNumElements(ElementCount VF)151 unsigned getMaxNumElements(ElementCount VF) const { 152 if (!VF.isScalable()) 153 return VF.getFixedValue(); 154 155 return VF.getKnownMinValue() * ST->getVScaleForTuning(); 156 } 157 158 unsigned getMaxInterleaveFactor(ElementCount VF); 159 160 bool prefersVectorizedAddressing() const; 161 162 InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src, 163 Align Alignment, unsigned AddressSpace, 164 TTI::TargetCostKind CostKind); 165 166 InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy, 167 const Value *Ptr, bool VariableMask, 168 Align Alignment, 169 TTI::TargetCostKind CostKind, 170 const Instruction *I = nullptr); 171 172 bool isExtPartOfAvgExpr(const Instruction *ExtUser, Type *Dst, Type *Src); 173 174 InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src, 175 TTI::CastContextHint CCH, 176 TTI::TargetCostKind CostKind, 177 const Instruction *I = nullptr); 178 179 InstructionCost getExtractWithExtendCost(unsigned Opcode, Type *Dst, 180 VectorType *VecTy, unsigned Index); 181 182 InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind, 183 const Instruction *I = nullptr); 184 185 InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val, 186 TTI::TargetCostKind CostKind, 187 unsigned Index, Value *Op0, Value *Op1); 188 InstructionCost getVectorInstrCost(const Instruction &I, Type *Val, 189 TTI::TargetCostKind CostKind, 190 unsigned Index); 191 192 InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, 193 FastMathFlags FMF, 194 TTI::TargetCostKind CostKind); 195 196 InstructionCost getArithmeticReductionCostSVE(unsigned Opcode, 197 VectorType *ValTy, 198 TTI::TargetCostKind CostKind); 199 200 InstructionCost getSpliceCost(VectorType *Tp, int Index); 201 202 InstructionCost getArithmeticInstrCost( 203 unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind, 204 TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None}, 205 TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None}, 206 ArrayRef<const Value *> Args = std::nullopt, 207 const Instruction *CxtI = nullptr); 208 209 InstructionCost getAddressComputationCost(Type *Ty, ScalarEvolution *SE, 210 const SCEV *Ptr); 211 212 InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy, 213 CmpInst::Predicate VecPred, 214 TTI::TargetCostKind CostKind, 215 const Instruction *I = nullptr); 216 217 TTI::MemCmpExpansionOptions enableMemCmpExpansion(bool OptSize, 218 bool IsZeroCmp) const; 219 bool useNeonVector(const Type *Ty) const; 220 221 InstructionCost 222 getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment, 223 unsigned AddressSpace, TTI::TargetCostKind CostKind, 224 TTI::OperandValueInfo OpInfo = {TTI::OK_AnyValue, TTI::OP_None}, 225 const Instruction *I = nullptr); 226 227 InstructionCost getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys); 228 229 void getUnrollingPreferences(Loop *L, ScalarEvolution &SE, 230 TTI::UnrollingPreferences &UP, 231 OptimizationRemarkEmitter *ORE); 232 233 void getPeelingPreferences(Loop *L, ScalarEvolution &SE, 234 TTI::PeelingPreferences &PP); 235 236 Value *getOrCreateResultFromMemIntrinsic(IntrinsicInst *Inst, 237 Type *ExpectedType); 238 239 bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info); 240 isElementTypeLegalForScalableVector(Type * Ty)241 bool isElementTypeLegalForScalableVector(Type *Ty) const { 242 if (Ty->isPointerTy()) 243 return true; 244 245 if (Ty->isBFloatTy() && ST->hasBF16()) 246 return true; 247 248 if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) 249 return true; 250 251 if (Ty->isIntegerTy(1) || Ty->isIntegerTy(8) || Ty->isIntegerTy(16) || 252 Ty->isIntegerTy(32) || Ty->isIntegerTy(64)) 253 return true; 254 255 return false; 256 } 257 isLegalMaskedLoadStore(Type * DataType,Align Alignment)258 bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) { 259 if (!ST->hasSVE()) 260 return false; 261 262 // For fixed vectors, avoid scalarization if using SVE for them. 263 if (isa<FixedVectorType>(DataType) && !ST->useSVEForFixedLengthVectors() && 264 DataType->getPrimitiveSizeInBits() != 128) 265 return false; // Fall back to scalarization of masked operations. 266 267 return isElementTypeLegalForScalableVector(DataType->getScalarType()); 268 } 269 isLegalMaskedLoad(Type * DataType,Align Alignment)270 bool isLegalMaskedLoad(Type *DataType, Align Alignment) { 271 return isLegalMaskedLoadStore(DataType, Alignment); 272 } 273 isLegalMaskedStore(Type * DataType,Align Alignment)274 bool isLegalMaskedStore(Type *DataType, Align Alignment) { 275 return isLegalMaskedLoadStore(DataType, Alignment); 276 } 277 isLegalMaskedGatherScatter(Type * DataType)278 bool isLegalMaskedGatherScatter(Type *DataType) const { 279 if (!ST->isSVEAvailable()) 280 return false; 281 282 // For fixed vectors, scalarize if not using SVE for them. 283 auto *DataTypeFVTy = dyn_cast<FixedVectorType>(DataType); 284 if (DataTypeFVTy && (!ST->useSVEForFixedLengthVectors() || 285 DataTypeFVTy->getNumElements() < 2)) 286 return false; 287 288 return isElementTypeLegalForScalableVector(DataType->getScalarType()); 289 } 290 isLegalMaskedGather(Type * DataType,Align Alignment)291 bool isLegalMaskedGather(Type *DataType, Align Alignment) const { 292 return isLegalMaskedGatherScatter(DataType); 293 } 294 isLegalMaskedScatter(Type * DataType,Align Alignment)295 bool isLegalMaskedScatter(Type *DataType, Align Alignment) const { 296 return isLegalMaskedGatherScatter(DataType); 297 } 298 isLegalBroadcastLoad(Type * ElementTy,ElementCount NumElements)299 bool isLegalBroadcastLoad(Type *ElementTy, ElementCount NumElements) const { 300 // Return true if we can generate a `ld1r` splat load instruction. 301 if (!ST->hasNEON() || NumElements.isScalable()) 302 return false; 303 switch (unsigned ElementBits = ElementTy->getScalarSizeInBits()) { 304 case 8: 305 case 16: 306 case 32: 307 case 64: { 308 // We accept bit-widths >= 64bits and elements {8,16,32,64} bits. 309 unsigned VectorBits = NumElements.getFixedValue() * ElementBits; 310 return VectorBits >= 64; 311 } 312 } 313 return false; 314 } 315 isLegalNTStoreLoad(Type * DataType,Align Alignment)316 bool isLegalNTStoreLoad(Type *DataType, Align Alignment) { 317 // NOTE: The logic below is mostly geared towards LV, which calls it with 318 // vectors with 2 elements. We might want to improve that, if other 319 // users show up. 320 // Nontemporal vector loads/stores can be directly lowered to LDNP/STNP, if 321 // the vector can be halved so that each half fits into a register. That's 322 // the case if the element type fits into a register and the number of 323 // elements is a power of 2 > 1. 324 if (auto *DataTypeTy = dyn_cast<FixedVectorType>(DataType)) { 325 unsigned NumElements = DataTypeTy->getNumElements(); 326 unsigned EltSize = DataTypeTy->getElementType()->getScalarSizeInBits(); 327 return NumElements > 1 && isPowerOf2_64(NumElements) && EltSize >= 8 && 328 EltSize <= 128 && isPowerOf2_64(EltSize); 329 } 330 return BaseT::isLegalNTStore(DataType, Alignment); 331 } 332 isLegalNTStore(Type * DataType,Align Alignment)333 bool isLegalNTStore(Type *DataType, Align Alignment) { 334 return isLegalNTStoreLoad(DataType, Alignment); 335 } 336 isLegalNTLoad(Type * DataType,Align Alignment)337 bool isLegalNTLoad(Type *DataType, Align Alignment) { 338 // Only supports little-endian targets. 339 if (ST->isLittleEndian()) 340 return isLegalNTStoreLoad(DataType, Alignment); 341 return BaseT::isLegalNTLoad(DataType, Alignment); 342 } 343 enableOrderedReductions()344 bool enableOrderedReductions() const { return true; } 345 346 InstructionCost getInterleavedMemoryOpCost( 347 unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, 348 Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind, 349 bool UseMaskForCond = false, bool UseMaskForGaps = false); 350 351 bool 352 shouldConsiderAddressTypePromotion(const Instruction &I, 353 bool &AllowPromotionWithoutCommonHeader); 354 shouldExpandReduction(const IntrinsicInst * II)355 bool shouldExpandReduction(const IntrinsicInst *II) const { return false; } 356 getGISelRematGlobalCost()357 unsigned getGISelRematGlobalCost() const { 358 return 2; 359 } 360 getMinTripCountTailFoldingThreshold()361 unsigned getMinTripCountTailFoldingThreshold() const { 362 return ST->hasSVE() ? 5 : 0; 363 } 364 getPreferredTailFoldingStyle(bool IVUpdateMayOverflow)365 TailFoldingStyle getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const { 366 if (ST->hasSVE()) 367 return IVUpdateMayOverflow 368 ? TailFoldingStyle::DataAndControlFlowWithoutRuntimeCheck 369 : TailFoldingStyle::DataAndControlFlow; 370 371 return TailFoldingStyle::DataWithoutLaneMask; 372 } 373 preferFixedOverScalableIfEqualCost()374 bool preferFixedOverScalableIfEqualCost() const { 375 return ST->useFixedOverScalableIfEqualCost(); 376 } 377 378 bool preferPredicateOverEpilogue(TailFoldingInfo *TFI); 379 supportsScalableVectors()380 bool supportsScalableVectors() const { 381 return ST->isSVEorStreamingSVEAvailable(); 382 } 383 384 bool enableScalableVectorization() const; 385 386 bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc, 387 ElementCount VF) const; 388 preferPredicatedReductionSelect(unsigned Opcode,Type * Ty,TTI::ReductionFlags Flags)389 bool preferPredicatedReductionSelect(unsigned Opcode, Type *Ty, 390 TTI::ReductionFlags Flags) const { 391 return ST->hasSVE(); 392 } 393 394 InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty, 395 std::optional<FastMathFlags> FMF, 396 TTI::TargetCostKind CostKind); 397 398 InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp, 399 ArrayRef<int> Mask, 400 TTI::TargetCostKind CostKind, int Index, 401 VectorType *SubTp, 402 ArrayRef<const Value *> Args = std::nullopt, 403 const Instruction *CxtI = nullptr); 404 405 InstructionCost getScalarizationOverhead(VectorType *Ty, 406 const APInt &DemandedElts, 407 bool Insert, bool Extract, 408 TTI::TargetCostKind CostKind); 409 410 /// Return the cost of the scaling factor used in the addressing 411 /// mode represented by AM for this target, for a load/store 412 /// of the specified type. 413 /// If the AM is supported, the return value must be >= 0. 414 /// If the AM is not supported, it returns a negative value. 415 InstructionCost getScalingFactorCost(Type *Ty, GlobalValue *BaseGV, 416 StackOffset BaseOffset, bool HasBaseReg, 417 int64_t Scale, unsigned AddrSpace) const; 418 /// @} 419 enableSelectOptimize()420 bool enableSelectOptimize() { return ST->enableSelectOptimize(); } 421 422 bool shouldTreatInstructionLikeSelect(const Instruction *I); 423 getStoreMinimumVF(unsigned VF,Type * ScalarMemTy,Type * ScalarValTy)424 unsigned getStoreMinimumVF(unsigned VF, Type *ScalarMemTy, 425 Type *ScalarValTy) const { 426 // We can vectorize store v4i8. 427 if (ScalarMemTy->isIntegerTy(8) && isPowerOf2_32(VF) && VF >= 4) 428 return 4; 429 430 return BaseT::getStoreMinimumVF(VF, ScalarMemTy, ScalarValTy); 431 } 432 getMinPageSize()433 std::optional<unsigned> getMinPageSize() const { return 4096; } 434 435 bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1, 436 const TargetTransformInfo::LSRCost &C2); 437 }; 438 439 } // end namespace llvm 440 441 #endif // LLVM_LIB_TARGET_AARCH64_AARCH64TARGETTRANSFORMINFO_H 442