xref: /freebsd/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===-- WebAssemblyTargetTransformInfo.cpp - WebAssembly-specific TTI -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// This file defines the WebAssembly-specific TargetTransformInfo
11 /// implementation.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "WebAssemblyTargetTransformInfo.h"
16 
17 #include "llvm/CodeGen/CostTable.h"
18 using namespace llvm;
19 
20 #define DEBUG_TYPE "wasmtti"
21 
22 TargetTransformInfo::PopcntSupportKind
23 WebAssemblyTTIImpl::getPopcntSupport(unsigned TyWidth) const {
24   assert(isPowerOf2_32(TyWidth) && "Ty width must be power of 2");
25   return TargetTransformInfo::PSK_FastHardware;
26 }
27 
28 unsigned WebAssemblyTTIImpl::getNumberOfRegisters(unsigned ClassID) const {
29   unsigned Result = BaseT::getNumberOfRegisters(ClassID);
30 
31   // For SIMD, use at least 16 registers, as a rough guess.
32   bool Vector = (ClassID == 1);
33   if (Vector)
34     Result = std::max(Result, 16u);
35 
36   return Result;
37 }
38 
39 TypeSize WebAssemblyTTIImpl::getRegisterBitWidth(
40     TargetTransformInfo::RegisterKind K) const {
41   switch (K) {
42   case TargetTransformInfo::RGK_Scalar:
43     return TypeSize::getFixed(64);
44   case TargetTransformInfo::RGK_FixedWidthVector:
45     return TypeSize::getFixed(getST()->hasSIMD128() ? 128 : 64);
46   case TargetTransformInfo::RGK_ScalableVector:
47     return TypeSize::getScalable(0);
48   }
49 
50   llvm_unreachable("Unsupported register kind");
51 }
52 
53 InstructionCost WebAssemblyTTIImpl::getArithmeticInstrCost(
54     unsigned Opcode, Type *Ty, TTI::TargetCostKind CostKind,
55     TTI::OperandValueInfo Op1Info, TTI::OperandValueInfo Op2Info,
56     ArrayRef<const Value *> Args, const Instruction *CxtI) const {
57 
58   InstructionCost Cost =
59       BasicTTIImplBase<WebAssemblyTTIImpl>::getArithmeticInstrCost(
60           Opcode, Ty, CostKind, Op1Info, Op2Info);
61 
62   if (auto *VTy = dyn_cast<VectorType>(Ty)) {
63     switch (Opcode) {
64     case Instruction::LShr:
65     case Instruction::AShr:
66     case Instruction::Shl:
67       // SIMD128's shifts currently only accept a scalar shift count. For each
68       // element, we'll need to extract, op, insert. The following is a rough
69       // approximation.
70       if (!Op2Info.isUniform())
71         Cost =
72             cast<FixedVectorType>(VTy)->getNumElements() *
73             (TargetTransformInfo::TCC_Basic +
74              getArithmeticInstrCost(Opcode, VTy->getElementType(), CostKind) +
75              TargetTransformInfo::TCC_Basic);
76       break;
77     }
78   }
79   return Cost;
80 }
81 
82 InstructionCost WebAssemblyTTIImpl::getCastInstrCost(
83     unsigned Opcode, Type *Dst, Type *Src, TTI::CastContextHint CCH,
84     TTI::TargetCostKind CostKind, const Instruction *I) const {
85   int ISD = TLI->InstructionOpcodeToISD(Opcode);
86   auto SrcTy = TLI->getValueType(DL, Src);
87   auto DstTy = TLI->getValueType(DL, Dst);
88 
89   if (!SrcTy.isSimple() || !DstTy.isSimple()) {
90     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
91   }
92 
93   if (!ST->hasSIMD128()) {
94     return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
95   }
96 
97   auto DstVT = DstTy.getSimpleVT();
98   auto SrcVT = SrcTy.getSimpleVT();
99 
100   if (I && I->hasOneUser()) {
101     auto *SingleUser = cast<Instruction>(*I->user_begin());
102     int UserISD = TLI->InstructionOpcodeToISD(SingleUser->getOpcode());
103 
104     // extmul_low support
105     if (UserISD == ISD::MUL &&
106         (ISD == ISD::ZERO_EXTEND || ISD == ISD::SIGN_EXTEND)) {
107       // Free low extensions.
108       if ((SrcVT == MVT::v8i8 && DstVT == MVT::v8i16) ||
109           (SrcVT == MVT::v4i16 && DstVT == MVT::v4i32) ||
110           (SrcVT == MVT::v2i32 && DstVT == MVT::v2i64)) {
111         return 0;
112       }
113       // Will require an additional extlow operation for the intermediate
114       // i16/i32 value.
115       if ((SrcVT == MVT::v4i8 && DstVT == MVT::v4i32) ||
116           (SrcVT == MVT::v2i16 && DstVT == MVT::v2i64)) {
117         return 1;
118       }
119     }
120   }
121 
122   // extend_low
123   static constexpr TypeConversionCostTblEntry ConversionTbl[] = {
124       {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i32, 1},
125       {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i32, 1},
126       {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i16, 1},
127       {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i16, 1},
128       {ISD::SIGN_EXTEND, MVT::v8i16, MVT::v8i8, 1},
129       {ISD::ZERO_EXTEND, MVT::v8i16, MVT::v8i8, 1},
130       {ISD::SIGN_EXTEND, MVT::v2i64, MVT::v2i16, 2},
131       {ISD::ZERO_EXTEND, MVT::v2i64, MVT::v2i16, 2},
132       {ISD::SIGN_EXTEND, MVT::v4i32, MVT::v4i8, 2},
133       {ISD::ZERO_EXTEND, MVT::v4i32, MVT::v4i8, 2},
134   };
135 
136   if (const auto *Entry =
137           ConvertCostTableLookup(ConversionTbl, ISD, DstVT, SrcVT)) {
138     return Entry->Cost;
139   }
140 
141   return BaseT::getCastInstrCost(Opcode, Dst, Src, CCH, CostKind, I);
142 }
143 
144 InstructionCost WebAssemblyTTIImpl::getMemoryOpCost(
145     unsigned Opcode, Type *Ty, Align Alignment, unsigned AddressSpace,
146     TTI::TargetCostKind CostKind, TTI::OperandValueInfo OpInfo,
147     const Instruction *I) const {
148   if (!ST->hasSIMD128() || !isa<FixedVectorType>(Ty)) {
149     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
150                                   CostKind);
151   }
152 
153   int ISD = TLI->InstructionOpcodeToISD(Opcode);
154   if (ISD != ISD::LOAD) {
155     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
156                                   CostKind);
157   }
158 
159   EVT VT = TLI->getValueType(DL, Ty, true);
160   // Type legalization can't handle structs
161   if (VT == MVT::Other)
162     return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace,
163                                   CostKind);
164 
165   auto LT = getTypeLegalizationCost(Ty);
166   if (!LT.first.isValid())
167     return InstructionCost::getInvalid();
168 
169   // 128-bit loads are a single instruction. 32-bit and 64-bit vector loads can
170   // be lowered to load32_zero and load64_zero respectively. Assume SIMD loads
171   // are twice as expensive as scalar.
172   unsigned width = VT.getSizeInBits();
173   switch (width) {
174   default:
175     break;
176   case 32:
177   case 64:
178   case 128:
179     return 2;
180   }
181 
182   return BaseT::getMemoryOpCost(Opcode, Ty, Alignment, AddressSpace, CostKind);
183 }
184 
185 InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
186     unsigned Opcode, Type *Val, TTI::TargetCostKind CostKind, unsigned Index,
187     const Value *Op0, const Value *Op1) const {
188   InstructionCost Cost = BasicTTIImplBase::getVectorInstrCost(
189       Opcode, Val, CostKind, Index, Op0, Op1);
190 
191   // SIMD128's insert/extract currently only take constant indices.
192   if (Index == -1u)
193     return Cost + 25 * TargetTransformInfo::TCC_Expensive;
194 
195   return Cost;
196 }
197 
198 InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
199     unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
200     ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
201     TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
202     TTI::TargetCostKind CostKind) const {
203   InstructionCost Invalid = InstructionCost::getInvalid();
204   if (!VF.isFixed() || !ST->hasSIMD128())
205     return Invalid;
206 
207   if (CostKind != TTI::TCK_RecipThroughput)
208     return Invalid;
209 
210   InstructionCost Cost(TTI::TCC_Basic);
211 
212   // Possible options:
213   // - i16x8.extadd_pairwise_i8x16_sx
214   // - i32x4.extadd_pairwise_i16x8_sx
215   // - i32x4.dot_i16x8_s
216   // Only try to support dot, for now.
217 
218   if (Opcode != Instruction::Add)
219     return Invalid;
220 
221   if (!BinOp || *BinOp != Instruction::Mul)
222     return Invalid;
223 
224   if (InputTypeA != InputTypeB)
225     return Invalid;
226 
227   if (OpAExtend != OpBExtend)
228     return Invalid;
229 
230   EVT InputEVT = EVT::getEVT(InputTypeA);
231   EVT AccumEVT = EVT::getEVT(AccumType);
232 
233   // TODO: Add i64 accumulator.
234   if (AccumEVT != MVT::i32)
235     return Invalid;
236 
237   // Signed inputs can lower to dot
238   if (InputEVT == MVT::i16 && VF.getFixedValue() == 8)
239     return OpAExtend == TTI::PR_SignExtend ? Cost : Cost * 2;
240 
241   // Double the size of the lowered sequence.
242   if (InputEVT == MVT::i8 && VF.getFixedValue() == 16)
243     return OpAExtend == TTI::PR_SignExtend ? Cost * 2 : Cost * 4;
244 
245   return Invalid;
246 }
247 
248 TTI::ReductionShuffle WebAssemblyTTIImpl::getPreferredExpandedReductionShuffle(
249     const IntrinsicInst *II) const {
250 
251   switch (II->getIntrinsicID()) {
252   default:
253     break;
254   case Intrinsic::vector_reduce_fadd:
255     return TTI::ReductionShuffle::Pairwise;
256   }
257   return TTI::ReductionShuffle::SplitHalf;
258 }
259 
260 void WebAssemblyTTIImpl::getUnrollingPreferences(
261     Loop *L, ScalarEvolution &SE, TTI::UnrollingPreferences &UP,
262     OptimizationRemarkEmitter *ORE) const {
263   // Scan the loop: don't unroll loops with calls. This is a standard approach
264   // for most (all?) targets.
265   for (BasicBlock *BB : L->blocks())
266     for (Instruction &I : *BB)
267       if (isa<CallInst>(I) || isa<InvokeInst>(I))
268         if (const Function *F = cast<CallBase>(I).getCalledFunction())
269           if (isLoweredToCall(F))
270             return;
271 
272   // The chosen threshold is within the range of 'LoopMicroOpBufferSize' of
273   // the various microarchitectures that use the BasicTTI implementation and
274   // has been selected through heuristics across multiple cores and runtimes.
275   UP.Partial = UP.Runtime = UP.UpperBound = true;
276   UP.PartialThreshold = 30;
277 
278   // Avoid unrolling when optimizing for size.
279   UP.OptSizeThreshold = 0;
280   UP.PartialOptSizeThreshold = 0;
281 
282   // Set number of instructions optimized when "back edge"
283   // becomes "fall through" to default value of 2.
284   UP.BEInsns = 2;
285 }
286 
287 bool WebAssemblyTTIImpl::supportsTailCalls() const {
288   return getST()->hasTailCall();
289 }
290 
291 bool WebAssemblyTTIImpl::isProfitableToSinkOperands(
292     Instruction *I, SmallVectorImpl<Use *> &Ops) const {
293   using namespace llvm::PatternMatch;
294 
295   if (!I->getType()->isVectorTy() || !I->isShift())
296     return false;
297 
298   Value *V = I->getOperand(1);
299   // We dont need to sink constant splat.
300   if (isa<Constant>(V))
301     return false;
302 
303   if (match(V, m_Shuffle(m_InsertElt(m_Value(), m_Value(), m_ZeroInt()),
304                          m_Value(), m_ZeroMask()))) {
305     // Sink insert
306     Ops.push_back(&cast<Instruction>(V)->getOperandUse(0));
307     // Sink shuffle
308     Ops.push_back(&I->getOperandUse(1));
309     return true;
310   }
311 
312   return false;
313 }
314