xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h (revision 357378bbdedf24ce2b90e9bd831af4a9db3ec70a)
1 //===- RISCVTargetTransformInfo.h - RISC-V specific TTI ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file defines a TargetTransformInfo::Concept conforming object specific
10 /// to the RISC-V target machine. It uses the target's detailed information to
11 /// provide more precise answers to certain TTI queries, while letting the
12 /// target independent and default TTI implementations handle the rest.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
17 #define LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
18 
19 #include "RISCVSubtarget.h"
20 #include "RISCVTargetMachine.h"
21 #include "llvm/Analysis/IVDescriptors.h"
22 #include "llvm/Analysis/TargetTransformInfo.h"
23 #include "llvm/CodeGen/BasicTTIImpl.h"
24 #include "llvm/IR/Function.h"
25 #include <optional>
26 
27 namespace llvm {
28 
29 class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
30   using BaseT = BasicTTIImplBase<RISCVTTIImpl>;
31   using TTI = TargetTransformInfo;
32 
33   friend BaseT;
34 
35   const RISCVSubtarget *ST;
36   const RISCVTargetLowering *TLI;
37 
38   const RISCVSubtarget *getST() const { return ST; }
39   const RISCVTargetLowering *getTLI() const { return TLI; }
40 
41   /// This function returns an estimate for VL to be used in VL based terms
42   /// of the cost model.  For fixed length vectors, this is simply the
43   /// vector length.  For scalable vectors, we return results consistent
44   /// with getVScaleForTuning under the assumption that clients are also
45   /// using that when comparing costs between scalar and vector representation.
46   /// This does unfortunately mean that we can both undershoot and overshot
47   /// the true cost significantly if getVScaleForTuning is wildly off for the
48   /// actual target hardware.
49   unsigned getEstimatedVLFor(VectorType *Ty);
50 
51   InstructionCost getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
52                                           TTI::TargetCostKind CostKind);
53 
54   /// Return the cost of accessing a constant pool entry of the specified
55   /// type.
56   InstructionCost getConstantPoolLoadCost(Type *Ty,
57                                           TTI::TargetCostKind CostKind);
58 public:
59   explicit RISCVTTIImpl(const RISCVTargetMachine *TM, const Function &F)
60       : BaseT(TM, F.getParent()->getDataLayout()), ST(TM->getSubtargetImpl(F)),
61         TLI(ST->getTargetLowering()) {}
62 
63   /// Return the cost of materializing an immediate for a value operand of
64   /// a store instruction.
65   InstructionCost getStoreImmCost(Type *VecTy, TTI::OperandValueInfo OpInfo,
66                                   TTI::TargetCostKind CostKind);
67 
68   InstructionCost getIntImmCost(const APInt &Imm, Type *Ty,
69                                 TTI::TargetCostKind CostKind);
70   InstructionCost getIntImmCostInst(unsigned Opcode, unsigned Idx,
71                                     const APInt &Imm, Type *Ty,
72                                     TTI::TargetCostKind CostKind,
73                                     Instruction *Inst = nullptr);
74   InstructionCost getIntImmCostIntrin(Intrinsic::ID IID, unsigned Idx,
75                                       const APInt &Imm, Type *Ty,
76                                       TTI::TargetCostKind CostKind);
77 
78   TargetTransformInfo::PopcntSupportKind getPopcntSupport(unsigned TyWidth);
79 
80   bool shouldExpandReduction(const IntrinsicInst *II) const;
81   bool supportsScalableVectors() const { return ST->hasVInstructions(); }
82   bool enableOrderedReductions() const { return true; }
83   bool enableScalableVectorization() const { return ST->hasVInstructions(); }
84   TailFoldingStyle
85   getPreferredTailFoldingStyle(bool IVUpdateMayOverflow) const {
86     return ST->hasVInstructions() ? TailFoldingStyle::Data
87                                   : TailFoldingStyle::DataWithoutLaneMask;
88   }
89   std::optional<unsigned> getMaxVScale() const;
90   std::optional<unsigned> getVScaleForTuning() const;
91 
92   TypeSize getRegisterBitWidth(TargetTransformInfo::RegisterKind K) const;
93 
94   unsigned getRegUsageForType(Type *Ty);
95 
96   unsigned getMaximumVF(unsigned ElemWidth, unsigned Opcode) const;
97 
98   bool preferEpilogueVectorization() const {
99     // Epilogue vectorization is usually unprofitable - tail folding or
100     // a smaller VF would have been better.  This a blunt hammer - we
101     // should re-examine this once vectorization is better tuned.
102     return false;
103   }
104 
105   InstructionCost getMaskedMemoryOpCost(unsigned Opcode, Type *Src,
106                                         Align Alignment, unsigned AddressSpace,
107                                         TTI::TargetCostKind CostKind);
108 
109   InstructionCost getPointersChainCost(ArrayRef<const Value *> Ptrs,
110                                        const Value *Base,
111                                        const TTI::PointersChainInfo &Info,
112                                        Type *AccessTy,
113                                        TTI::TargetCostKind CostKind);
114 
115   void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
116                                TTI::UnrollingPreferences &UP,
117                                OptimizationRemarkEmitter *ORE);
118 
119   void getPeelingPreferences(Loop *L, ScalarEvolution &SE,
120                              TTI::PeelingPreferences &PP);
121 
122   unsigned getMinVectorRegisterBitWidth() const {
123     return ST->useRVVForFixedLengthVectors() ? 16 : 0;
124   }
125 
126   InstructionCost getShuffleCost(TTI::ShuffleKind Kind, VectorType *Tp,
127                                  ArrayRef<int> Mask,
128                                  TTI::TargetCostKind CostKind, int Index,
129                                  VectorType *SubTp,
130                                  ArrayRef<const Value *> Args = std::nullopt);
131 
132   InstructionCost getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
133                                         TTI::TargetCostKind CostKind);
134 
135   InstructionCost getInterleavedMemoryOpCost(
136       unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices,
137       Align Alignment, unsigned AddressSpace, TTI::TargetCostKind CostKind,
138       bool UseMaskForCond = false, bool UseMaskForGaps = false);
139 
140   InstructionCost getGatherScatterOpCost(unsigned Opcode, Type *DataTy,
141                                          const Value *Ptr, bool VariableMask,
142                                          Align Alignment,
143                                          TTI::TargetCostKind CostKind,
144                                          const Instruction *I);
145 
146   InstructionCost getCastInstrCost(unsigned Opcode, Type *Dst, Type *Src,
147                                    TTI::CastContextHint CCH,
148                                    TTI::TargetCostKind CostKind,
149                                    const Instruction *I = nullptr);
150 
151   InstructionCost getMinMaxReductionCost(Intrinsic::ID IID, VectorType *Ty,
152                                          FastMathFlags FMF,
153                                          TTI::TargetCostKind CostKind);
154 
155   InstructionCost getArithmeticReductionCost(unsigned Opcode, VectorType *Ty,
156                                              std::optional<FastMathFlags> FMF,
157                                              TTI::TargetCostKind CostKind);
158 
159   InstructionCost getExtendedReductionCost(unsigned Opcode, bool IsUnsigned,
160                                            Type *ResTy, VectorType *ValTy,
161                                            FastMathFlags FMF,
162                                            TTI::TargetCostKind CostKind);
163 
164   InstructionCost
165   getMemoryOpCost(unsigned Opcode, Type *Src, MaybeAlign Alignment,
166                   unsigned AddressSpace, TTI::TargetCostKind CostKind,
167                   TTI::OperandValueInfo OpdInfo = {TTI::OK_AnyValue, TTI::OP_None},
168                   const Instruction *I = nullptr);
169 
170   InstructionCost getCmpSelInstrCost(unsigned Opcode, Type *ValTy, Type *CondTy,
171                                      CmpInst::Predicate VecPred,
172                                      TTI::TargetCostKind CostKind,
173                                      const Instruction *I = nullptr);
174 
175   InstructionCost getCFInstrCost(unsigned Opcode, TTI::TargetCostKind CostKind,
176                                  const Instruction *I = nullptr);
177 
178   using BaseT::getVectorInstrCost;
179   InstructionCost getVectorInstrCost(unsigned Opcode, Type *Val,
180                                      TTI::TargetCostKind CostKind,
181                                      unsigned Index, Value *Op0, Value *Op1);
182 
183   InstructionCost getArithmeticInstrCost(
184       unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
185       TTI::OperandValueInfo Op1Info = {TTI::OK_AnyValue, TTI::OP_None},
186       TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
187       ArrayRef<const Value *> Args = ArrayRef<const Value *>(),
188       const Instruction *CxtI = nullptr);
189 
190   bool isElementTypeLegalForScalableVector(Type *Ty) const {
191     return TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty));
192   }
193 
194   bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
195     if (!ST->hasVInstructions())
196       return false;
197 
198     EVT DataTypeVT = TLI->getValueType(DL, DataType);
199 
200     // Only support fixed vectors if we know the minimum vector size.
201     if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
202       return false;
203 
204     EVT ElemType = DataTypeVT.getScalarType();
205     if (!ST->hasFastUnalignedAccess() && Alignment < ElemType.getStoreSize())
206       return false;
207 
208     return TLI->isLegalElementTypeForRVV(ElemType);
209 
210   }
211 
212   bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
213     return isLegalMaskedLoadStore(DataType, Alignment);
214   }
215   bool isLegalMaskedStore(Type *DataType, Align Alignment) {
216     return isLegalMaskedLoadStore(DataType, Alignment);
217   }
218 
219   bool isLegalMaskedGatherScatter(Type *DataType, Align Alignment) {
220     if (!ST->hasVInstructions())
221       return false;
222 
223     EVT DataTypeVT = TLI->getValueType(DL, DataType);
224 
225     // Only support fixed vectors if we know the minimum vector size.
226     if (DataTypeVT.isFixedLengthVector() && !ST->useRVVForFixedLengthVectors())
227       return false;
228 
229     EVT ElemType = DataTypeVT.getScalarType();
230     if (!ST->hasFastUnalignedAccess() && Alignment < ElemType.getStoreSize())
231       return false;
232 
233     return TLI->isLegalElementTypeForRVV(ElemType);
234   }
235 
236   bool isLegalMaskedGather(Type *DataType, Align Alignment) {
237     return isLegalMaskedGatherScatter(DataType, Alignment);
238   }
239   bool isLegalMaskedScatter(Type *DataType, Align Alignment) {
240     return isLegalMaskedGatherScatter(DataType, Alignment);
241   }
242 
243   bool forceScalarizeMaskedGather(VectorType *VTy, Align Alignment) {
244     // Scalarize masked gather for RV64 if EEW=64 indices aren't supported.
245     return ST->is64Bit() && !ST->hasVInstructionsI64();
246   }
247 
248   bool forceScalarizeMaskedScatter(VectorType *VTy, Align Alignment) {
249     // Scalarize masked scatter for RV64 if EEW=64 indices aren't supported.
250     return ST->is64Bit() && !ST->hasVInstructionsI64();
251   }
252 
253   bool isVScaleKnownToBeAPowerOfTwo() const {
254     return TLI->isVScaleKnownToBeAPowerOfTwo();
255   }
256 
257   /// \returns How the target needs this vector-predicated operation to be
258   /// transformed.
259   TargetTransformInfo::VPLegalization
260   getVPLegalizationStrategy(const VPIntrinsic &PI) const {
261     using VPLegalization = TargetTransformInfo::VPLegalization;
262     if (!ST->hasVInstructions() ||
263         (PI.getIntrinsicID() == Intrinsic::vp_reduce_mul &&
264          cast<VectorType>(PI.getArgOperand(1)->getType())
265                  ->getElementType()
266                  ->getIntegerBitWidth() != 1))
267       return VPLegalization(VPLegalization::Discard, VPLegalization::Convert);
268     return VPLegalization(VPLegalization::Legal, VPLegalization::Legal);
269   }
270 
271   bool isLegalToVectorizeReduction(const RecurrenceDescriptor &RdxDesc,
272                                    ElementCount VF) const {
273     if (!VF.isScalable())
274       return true;
275 
276     Type *Ty = RdxDesc.getRecurrenceType();
277     if (!TLI->isLegalElementTypeForRVV(TLI->getValueType(DL, Ty)))
278       return false;
279 
280     switch (RdxDesc.getRecurrenceKind()) {
281     case RecurKind::Add:
282     case RecurKind::FAdd:
283     case RecurKind::And:
284     case RecurKind::Or:
285     case RecurKind::Xor:
286     case RecurKind::SMin:
287     case RecurKind::SMax:
288     case RecurKind::UMin:
289     case RecurKind::UMax:
290     case RecurKind::FMin:
291     case RecurKind::FMax:
292     case RecurKind::FMulAdd:
293     case RecurKind::IAnyOf:
294     case RecurKind::FAnyOf:
295       return true;
296     default:
297       return false;
298     }
299   }
300 
301   unsigned getMaxInterleaveFactor(ElementCount VF) {
302     // Don't interleave if the loop has been vectorized with scalable vectors.
303     if (VF.isScalable())
304       return 1;
305     // If the loop will not be vectorized, don't interleave the loop.
306     // Let regular unroll to unroll the loop.
307     return VF.isScalar() ? 1 : ST->getMaxInterleaveFactor();
308   }
309 
310   bool enableInterleavedAccessVectorization() { return true; }
311 
312   enum RISCVRegisterClass { GPRRC, FPRRC, VRRC };
313   unsigned getNumberOfRegisters(unsigned ClassID) const {
314     switch (ClassID) {
315     case RISCVRegisterClass::GPRRC:
316       // 31 = 32 GPR - x0 (zero register)
317       // FIXME: Should we exclude fixed registers like SP, TP or GP?
318       return 31;
319     case RISCVRegisterClass::FPRRC:
320       if (ST->hasStdExtF())
321         return 32;
322       return 0;
323     case RISCVRegisterClass::VRRC:
324       // Although there are 32 vector registers, v0 is special in that it is the
325       // only register that can be used to hold a mask.
326       // FIXME: Should we conservatively return 31 as the number of usable
327       // vector registers?
328       return ST->hasVInstructions() ? 32 : 0;
329     }
330     llvm_unreachable("unknown register class");
331   }
332 
333   unsigned getRegisterClassForType(bool Vector, Type *Ty = nullptr) const {
334     if (Vector)
335       return RISCVRegisterClass::VRRC;
336     if (!Ty)
337       return RISCVRegisterClass::GPRRC;
338 
339     Type *ScalarTy = Ty->getScalarType();
340     if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
341         (ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
342         (ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
343       return RISCVRegisterClass::FPRRC;
344     }
345 
346     return RISCVRegisterClass::GPRRC;
347   }
348 
349   const char *getRegisterClassName(unsigned ClassID) const {
350     switch (ClassID) {
351     case RISCVRegisterClass::GPRRC:
352       return "RISCV::GPRRC";
353     case RISCVRegisterClass::FPRRC:
354       return "RISCV::FPRRC";
355     case RISCVRegisterClass::VRRC:
356       return "RISCV::VRRC";
357     }
358     llvm_unreachable("unknown register class");
359   }
360 
361   bool isLSRCostLess(const TargetTransformInfo::LSRCost &C1,
362                      const TargetTransformInfo::LSRCost &C2);
363 
364   bool shouldFoldTerminatingConditionAfterLSR() const {
365     return true;
366   }
367 };
368 
369 } // end namespace llvm
370 
371 #endif // LLVM_LIB_TARGET_RISCV_RISCVTARGETTRANSFORMINFO_H
372