xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
/// This file defines a TargetTransformInfoImplBase-conforming object specific
/// to the RISC-V target machine. It uses the target's detailed information to
/// provide more precise answers to certain TTI queries, while letting the
/// target-independent and default TTI implementations handle the rest.
///
//===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
17 #define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
18 
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/Analysis/TargetTransformInfo.h"
22 #include "llvm/CodeGen/BasicTTIImpl.h"
23 #include "llvm/IR/Function.h"
24 #include <optional>
25 
26 namespace llvm {
27 
28 class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
29   using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
30   using TTI = TargetTransformInfo;
31 
32   friend BaseT;
33 
34   const RISCVSubtarget *ST;
35   const RISCVTargetLowering *TLI;
36 
getST()37   const RISCVSubtarget *getST() const { return ST; }
getTLI()38   const RISCVTargetLowering *getTLI() const { return TLI; }
39 
40   /// This function returns an estimate for VL to be used in VL based terms
41   /// of the cost model.  For fixed length vectors, this is simply the
42   /// vector length.  For scalable vectors, we return results consistent
43   /// with getVScaleForTuning under the assumption that clients are also
44   /// using that when comparing costs between scalar and vector representation.
45   /// This does unfortunately mean that we can both undershoot and overshot
46   /// the true cost significantly if getVScaleForTuning is wildly off for the
47   /// actual target hardware.
48   unsigned getEstimatedVLFor(VectorType *Ty) const;
49 
50   /// This function calculates the costs for one or more RVV opcodes based
51   /// on the vtype and the cost kind.
52   /// \param Opcodes A list of opcodes of the RVV instruction to evaluate.
53   /// \param VT The MVT of vtype associated with the RVV instructions.
54   /// For widening/narrowing instructions where the result and source types
55   /// differ, it is important to check the spec to determine whether the vtype
56   /// refers to the result or source type.
57   /// \param CostKind The type of cost to compute.
58   InstructionCost getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
59                                           TTI::TargetCostKind CostKind) const;
60 
61   /// Return the cost of accessing a constant pool entry of the specified
62   /// type.
63   InstructionCost getConstantPoolLoadCost(Type *Ty,
64                                           TTI::TargetCostKind CostKind) const;
65 
66   /// If this shuffle can be lowered as a masked slide pair (at worst),
67   /// return a cost for it.
68   InstructionCost getSlideCost(FixedVectorType *Tp, ArrayRef<int> Mask,
69                                TTI::TargetCostKind CostKind) const;
70 
71 public:
RISCVTTIImpl(const RISCVTargetMachine * TM,const Function & F)72   explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
73       : BaseT(TM, F.getDataLayout()), ST(TM->getSubtargetImpl(F)),
74         TLI(ST->getTargetLowering()) {}
75 
76   /// Return the cost of materializing an immediate for a value operand of
77   /// a store instruction.
78   InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
79                                   TTI::TargetCostKind CostKind) const;
80 
81   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
82                                 TTI::TargetCostKind CostKind) const override;
83   InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
84                                     const APInt &Imm, Type *Ty,
85                                     TTI::TargetCostKind CostKind,
86                                     Instruction *Inst = nullptr) const override;
87   InstructionCost
88   getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx, const APInt &Imm,
89                       Type *Ty, TTI::TargetCostKind CostKind) const override;
90 
91   /// \name EVL Support for predicated vectorization.
92   /// Whether the target supports the %evl parameter of VP intrinsic efficiently
93   /// in hardware. (see LLVM Language Reference - "Vector Predication
94   /// Intrinsics",
95   /// https://llvm.org/docs/LangRef.html#vector-predication-intrinsics and
96   /// "IR-level VP intrinsics",
97   /// https://llvm.org/docs/Proposals/VectorPredication.html#ir-level-vp-intrinsics).
98   bool hasActiveVectorLength() const override;
99 
100   TargetTransformInfo::PopcntSupportKind
101   getPopcntSupport(unsigned TyWidth) const override;
102 
103   InstructionCost getPartialReductionCost(
104       unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
105       ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
106       TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
107       TTI::TargetCostKind CostKind) const override;
108 
109   bool shouldExpandReduction(const IntrinsicInst *II) const override;
supportsScalableVectors()110   bool supportsScalableVectors() const override {
111     return ST->hasVInstructions();
112   }
enableOrderedReductions()113   bool enableOrderedReductions() const override { return true; }
enableScalableVectorization()114   bool enableScalableVectorization() const override {
115     return ST->hasVInstructions();
116   }
117   TailFoldingStyle
getPreferredTailFoldingStyle(bool IVUpdateMayOverflow)118   getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const override {
119     return ST->hasVInstructions() ? TailFoldingStyle::Data
120                                   : TailFoldingStyle::DataWithoutLaneMask;
121   }
122   std::optional<unsigned> getMaxVScale() const override;
123   std::optional<unsigned> getVScaleForTuning() const override;
124 
125   TypeSize
126   getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const override;
127 
128   unsigned getRegUsageForType(Type *Ty) const override;
129 
130   unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const override;
131 
preferAlternateOpcodeVectorization()132   bool preferAlternateOpcodeVectorization() const override { return false; }
133 
preferEpilogueVectorization()134   bool preferEpilogueVectorization() const override {
135     // Epilogue vectorization is usually unprofitable - tail folding or
136     // a smaller VF would have been better.  This a blunt hammer - we
137     // should re-examine this once vectorization is better tuned.
138     return false;
139   }
140 
141   InstructionCost
142   getMaskedMemoryOpCost(unsigned Opcode, Type *Src, Align Alignment,
143                         unsigned AddressSpace,
144                         TTI::TargetCostKind CostKind) const override;
145 
146   InstructionCost
147   getPointersChainCost(ArrayRef<const Value *> Ptrs, const Value *Base,
148                        const TTI::PointersChainInfo &Info, Type *AccessTy,
149                        TTI::TargetCostKind CostKind) const override;
150 
151   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
152                                TTI::UnrollingPreferences &UP,
153                                OptimizationRemarkEmitter *ORE) const override;
154 
155   void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
156                              TTI::PeelingPreferences &PP) const override;
157 
getMinVectorRegisterBitWidth()158   unsigned getMinVectorRegisterBitWidth() const override {
159     return ST->useRVVForFixedLengthVectors() ? 16 : 0;
160   }
161 
162   InstructionCost
163   getShuffleCost(TTI::ShuffleKind Kind, VectorType *DstTy, VectorType *SrcTy,
164                  ArrayRef<int> Mask, TTI::TargetCostKind CostKind, int Index,
165                  VectorType *SubTp, ArrayRef<const Value *> Args = {},
166                  const Instruction *CxtI = nullptr) const override;
167 
168   InstructionCost getScalarizationOverhead(
169       VectorType *Ty, const APInt &DemandedElts, bool Insert, bool Extract,
170       TTI::TargetCostKind CostKind, bool ForPoisonSrc = true,
171       ArrayRef<Value *> VL = {}) const override;
172 
173   InstructionCost
174   getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
175                         TTI::TargetCostKind CostKind) const override;
176 
177   InstructionCost getInterleavedMemoryOpCost(
178       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
179       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
180       bool UseMaskForCond = false, bool UseMaskForGaps = false) const override;
181 
182   InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
183                                          const Value *Ptr, bool VariableMask,
184                                          Align Alignment,
185                                          TTI::TargetCostKind CostKind,
186                                          const Instruction *I) const override;
187 
188   InstructionCost
189   getExpandCompressMemoryOpCost(unsigned Opcode, Type *Src, bool VariableMask,
190                                 Align Alignment, TTI::TargetCostKind CostKind,
191                                 const Instruction *I = nullptr) const override;
192 
193   InstructionCost getStridedMemoryOpCost(unsigned Opcode, Type *DataTy,
194                                          const Value *Ptr, bool VariableMask,
195                                          Align Alignment,
196                                          TTI::TargetCostKind CostKind,
197                                          const Instruction *I) const override;
198 
199   InstructionCost
200   getCostOfKeepingLiveOverCall(ArrayRef<Type *> Tys) const override;
201 
202   InstructionCost
203   getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
204                    TTI::CastContextHint CCH, TTI::TargetCostKind CostKind,
205                    const Instruction *I = nullptr) const override;
206 
207   InstructionCost
208   getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty, FastMathFlags FMF,
209                          TTI::TargetCostKind CostKind) const override;
210 
211   InstructionCost
212   getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
213                              std::optional<FastMathFlags> FMF,
214                              TTI::TargetCostKind CostKind) const override;
215 
216   InstructionCost
217   getExtendedReductionCost(unsigned Opcode, bool IsUnsigned, Type *ResTy,
218                            VectorType *ValTy, std::optional<FastMathFlags> FMF,
219                            TTI::TargetCostKind CostKind) const override;
220 
221   InstructionCost getMemoryOpCost(
222       unsigned Opcode, Type *Src, Align Alignment, unsigned AddressSpace,
223       TTI::TargetCostKind CostKind,
224       TTI::OperandValueInfo OpdInfo = {TTI::OK_AnyValue, TTI::OP_None},
225       const Instruction *I = nullptr) const override;
226 
227   InstructionCost getCmpSelInstrCost(
228       unsigned Opcode, Type *ValTy, Type *CondTy, CmpInst::Predicate VecPred,
229       TTI::TargetCostKind CostKind,
230       TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
231       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
232       const Instruction *I = nullptr) const override;
233 
234   InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
235                                  const Instruction *I = nullptr) const override;
236 
237   using BaseT::getVectorInstrCost;
238   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
239                                      TTI::TargetCostKind CostKind,
240                                      unsigned Index, const Value *Op0,
241                                      const Value *Op1) const override;
242 
243   InstructionCost getArithmeticInstrCost(
244       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
245       TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
246       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
247       ArrayRef<const Value *> Args = {},
248       const Instruction *CxtI = nullptr) const override;
249 
isElementTypeLegalForScalableVector(Type * Ty)250   bool isElementTypeLegalForScalableVector(Type *Ty) const override {
251     return TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty));
252   }
253 
isLegalMaskedLoadStore(Type * DataType,Align Alignment)254   bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) const {
255     if (!ST->hasVInstructions())
256       return false;
257 
258     EVT DataTypeVT = TLI->getValueType(DL, DataType);
259 
260     // Only support fixed vectors if we know the minimum vector size.
261     if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
262       return false;
263 
264     EVT ElemType = DataTypeVT.getScalarType();
265     if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
266       return false;
267 
268     return TLI->isLegalElementTypeForRVV(ElemType);
269   }
270 
isLegalMaskedLoad(Type * DataType,Align Alignment,unsigned)271   bool isLegalMaskedLoad(Type *DataType, Align Alignment,
272                          unsigned /*AddressSpace*/) const override {
273     return isLegalMaskedLoadStore(DataType, Alignment);
274   }
isLegalMaskedStore(Type * DataType,Align Alignment,unsigned)275   bool isLegalMaskedStore(Type *DataType, Align Alignment,
276                           unsigned /*AddressSpace*/) const override {
277     return isLegalMaskedLoadStore(DataType, Alignment);
278   }
279 
isLegalMaskedGatherScatter(Type * DataType,Align Alignment)280   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) const {
281     if (!ST->hasVInstructions())
282       return false;
283 
284     EVT DataTypeVT = TLI->getValueType(DL, DataType);
285 
286     // Only support fixed vectors if we know the minimum vector size.
287     if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
288       return false;
289 
290     // We also need to check if the vector of address is valid.
291     EVT PointerTypeVT = EVT(TLI->getPointerTy(DL));
292     if (DataTypeVT.isScalableVector() &&
293         !TLI->isLegalElementTypeForRVV(PointerTypeVT))
294       return false;
295 
296     EVT ElemType = DataTypeVT.getScalarType();
297     if (!ST->enableUnalignedVectorMem() && Alignment < ElemType.getStoreSize())
298       return false;
299 
300     return TLI->isLegalElementTypeForRVV(ElemType);
301   }
302 
isLegalMaskedGather(Type * DataType,Align Alignment)303   bool isLegalMaskedGather(Type *DataType, Align Alignment) const override {
304     return isLegalMaskedGatherScatter(DataType, Alignment);
305   }
isLegalMaskedScatter(Type * DataType,Align Alignment)306   bool isLegalMaskedScatter(Type *DataType, Align Alignment) const override {
307     return isLegalMaskedGatherScatter(DataType, Alignment);
308   }
309 
forceScalarizeMaskedGather(VectorType * VTy,Align Alignment)310   bool forceScalarizeMaskedGather(VectorType *VTy,
311                                   Align Alignment) const override {
312     // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
313     return ST->is64Bit() && !ST->hasVInstructionsI64();
314   }
315 
forceScalarizeMaskedScatter(VectorType * VTy,Align Alignment)316   bool forceScalarizeMaskedScatter(VectorType *VTy,
317                                    Align Alignment) const override {
318     // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
319     return ST->is64Bit() && !ST->hasVInstructionsI64();
320   }
321 
isLegalStridedLoadStore(Type * DataType,Align Alignment)322   bool isLegalStridedLoadStore(Type *DataType, Align Alignment) const override {
323     EVT DataTypeVT = TLI->getValueType(DL, DataType);
324     return TLI->isLegalStridedLoadStore(DataTypeVT, Alignment);
325   }
326 
isLegalInterleavedAccessType(VectorType * VTy,unsigned Factor,Align Alignment,unsigned AddrSpace)327   bool isLegalInterleavedAccessType(VectorType *VTy, unsigned Factor,
328                                     Align Alignment,
329                                     unsigned AddrSpace) const override {
330     return TLI->isLegalInterleavedAccessType(VTy, Factor, Alignment, AddrSpace,
331                                              DL);
332   }
333 
334   bool isLegalMaskedExpandLoad(Type *DataType, Align Alignment) const override;
335 
336   bool isLegalMaskedCompressStore(Type *DataTy, Align Alignment) const override;
337 
isVScaleKnownToBeAPowerOfTwo()338   bool isVScaleKnownToBeAPowerOfTwo() const override {
339     return TLI->isVScaleKnownToBeAPowerOfTwo();
340   }
341 
342   /// \returns How the target needs this vector-predicated operation to be
343   /// transformed.
344   TargetTransformInfo::VPLegalization
getVPLegalizationStrategy(const VPIntrinsic & PI)345   getVPLegalizationStrategy(const VPIntrinsic &PI) const override {
346     using VPLegalization = TargetTransformInfo::VPLegalization;
347     if (!ST->hasVInstructions() ||
348         (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
349          cast<VectorType>(PI.getArgOperand(1)->getType())
350                  ->getElementType()
351                  ->getIntegerBitWidth() != 1))
352       return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
353     return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
354   }
355 
isLegalToVectorizeReduction(const RecurrenceDescriptor & RdxDesc,ElementCount VF)356   bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
357                                    ElementCount VF) const override {
358     if (!VF.isScalable())
359       return true;
360 
361     Type *Ty = RdxDesc.getRecurrenceType();
362     if (!TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty)))
363       return false;
364 
365     switch (RdxDesc.getRecurrenceKind()) {
366     case RecurKind::Add:
367     case RecurKind::And:
368     case RecurKind::Or:
369     case RecurKind::Xor:
370     case RecurKind::SMin:
371     case RecurKind::SMax:
372     case RecurKind::UMin:
373     case RecurKind::UMax:
374     case RecurKind::FMin:
375     case RecurKind::FMax:
376       return true;
377     case RecurKind::AnyOf:
378     case RecurKind::FAdd:
379     case RecurKind::FMulAdd:
380       // We can't promote f16/bf16 fadd reductions and scalable vectors can't be
381       // expanded.
382       if (Ty->isBFloatTy() || (Ty->isHalfTy() && !ST->hasVInstructionsF16()))
383         return false;
384       return true;
385     default:
386       return false;
387     }
388   }
389 
getMaxInterleaveFactor(ElementCount VF)390   unsigned getMaxInterleaveFactor(ElementCount VF) const override {
391     // Don't interleave if the loop has been vectorized with scalable vectors.
392     if (VF.isScalable())
393       return 1;
394     // If the loop will not be vectorized, don't interleave the loop.
395     // Let regular unroll to unroll the loop.
396     return VF.isScalar() ? 1 : ST->getMaxInterleaveFactor();
397   }
398 
enableInterleavedAccessVectorization()399   bool enableInterleavedAccessVectorization() const override { return true; }
400 
401   unsigned getMinTripCountTailFoldingThreshold() const override;
402 
403   enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
getNumberOfRegisters(unsigned ClassID)404   unsigned getNumberOfRegisters(unsigned ClassID) const override {
405     switch (ClassID) {
406     case RISCVRegisterClass::GPRRC:
407       // 31 = 32 GPR - x0 (zero register)
408       // FIXME: Should we exclude fixed registers like SP, TP or GP?
409       return 31;
410     case RISCVRegisterClass::FPRRC:
411       if (ST->hasStdExtF())
412         return 32;
413       return 0;
414     case RISCVRegisterClass::VRRC:
415       // Although there are 32 vector registers, v0 is special in that it is the
416       // only register that can be used to hold a mask.
417       // FIXME: Should we conservatively return 31 as the number of usable
418       // vector registers?
419       return ST->hasVInstructions() ? 32 : 0;
420     }
421     llvm_unreachable("unknown register class");
422   }
423 
424   TTI::AddressingModeKind
425   getPreferredAddressingMode(const Loop *L, ScalarEvolution *SE) const override;
426 
427   unsigned getRegisterClassForType(bool Vector,
428                                    Type *Ty = nullptr) const override {
429     if (Vector)
430       return RISCVRegisterClass::VRRC;
431     if (!Ty)
432       return RISCVRegisterClass::GPRRC;
433 
434     Type *ScalarTy = Ty->getScalarType();
435     if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
436         (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
437         (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
438       return RISCVRegisterClass::FPRRC;
439     }
440 
441     return RISCVRegisterClass::GPRRC;
442   }
443 
getRegisterClassName(unsigned ClassID)444   const char *getRegisterClassName(unsigned ClassID) const override {
445     switch (ClassID) {
446     case RISCVRegisterClass::GPRRC:
447       return "RISCV::GPRRC";
448     case RISCVRegisterClass::FPRRC:
449       return "RISCV::FPRRC";
450     case RISCVRegisterClass::VRRC:
451       return "RISCV::VRRC";
452     }
453     llvm_unreachable("unknown register class");
454   }
455 
456   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
457                      const TargetTransformInfo::LSRCost &C2) const override;
458 
459   bool shouldConsiderAddressTypePromotion(
460       const Instruction &I,
461       bool &AllowPromotionWithoutCommonHeader) const override;
getMinPageSize()462   std::optional<unsigned> getMinPageSize() const override { return 4096; }
463   /// Return true if the (vector) instruction I will be lowered to an
464   /// instruction with a scalar splat operand for the given Operand number.
465   bool canSplatOperand(Instruction *I, int Operand) const;
466   /// Return true if a vector instruction will lower to a target instruction
467   /// able to splat the given operand.
468   bool canSplatOperand(unsigned Opcode, int Operand) const;
469 
470   bool isProfitableToSinkOperands(Instruction *I,
471                                   SmallVectorImpl<Use *> &Ops) const override;
472 
473   TTI::MemCmpExpansionOptions
474   enableMemCmpExpansion(bool OptSize, bool IsZeroCmp) const override;
475 };
476 
477 } // end namespace llvm
478 
479 #endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
480