xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXISelLowering.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===-- NVPTXISelLowering.h - NVPTX DAG Lowering Interface ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the interfaces that NVPTX uses to lower LLVM code into a
// selection DAG.
//
//===----------------------------------------------------------------------===//
13 
14 #ifndef LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H
15 #define LLVM_LIB_TARGET_NVPTX_NVPTXISELLOWERING_H
16 
17 #include "NVPTX.h"
18 #include "llvm/CodeGen/SelectionDAG.h"
19 #include "llvm/CodeGen/TargetLowering.h"
20 #include "llvm/Support/AtomicOrdering.h"
21 
22 namespace llvm {
23 namespace NVPTXISD {
24 enum NodeType : unsigned {
25   // Start the numbering from where ISD NodeType finishes.
26   FIRST_NUMBER = ISD::BUILTIN_OP_END,
27   RET_GLUE,
28 
29   /// These nodes represent a parameter declaration. In PTX this will look like:
30   ///   .param .align 16 .b8 param0[1024];
31   ///   .param .b32 retval0;
32   ///
33   /// DeclareArrayParam(Chain, Externalsym, Align, Size, Glue)
34   /// DeclareScalarParam(Chain, Externalsym, Size, Glue)
35   DeclareScalarParam,
36   DeclareArrayParam,
37 
38   /// This node represents a PTX call instruction. It's operands are as follows:
39   ///
40   /// CALL(Chain, IsConvergent, IsIndirectCall/IsUniform, NumReturns,
41   ///      NumParams, Callee, Proto, InGlue)
42   CALL,
43 
44   MoveParam,
45   CallPrototype,
46   ProxyReg,
47   FSHL_CLAMP,
48   FSHR_CLAMP,
49   MUL_WIDE_SIGNED,
50   MUL_WIDE_UNSIGNED,
51   SETP_F16X2,
52   SETP_BF16X2,
53   BFI,
54   PRMT,
55 
56   /// This node is similar to ISD::BUILD_VECTOR except that the output may be
57   /// implicitly bitcast to a scalar. This allows for the representation of
58   /// packing move instructions for vector types which are not legal i.e. v2i32
59   BUILD_VECTOR,
60 
61   /// This node is the inverse of NVPTX::BUILD_VECTOR. It takes a single value
62   /// which may be a scalar and unpacks it into multiple values by implicitly
63   /// converting it to a vector.
64   UNPACK_VECTOR,
65 
66   FCOPYSIGN,
67   DYNAMIC_STACKALLOC,
68   STACKRESTORE,
69   STACKSAVE,
70   BrxStart,
71   BrxItem,
72   BrxEnd,
73   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_IS_CANCELED,
74   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_X,
75   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Y,
76   CLUSTERLAUNCHCONTROL_QUERY_CANCEL_GET_FIRST_CTAID_Z,
77 
78   FIRST_MEMORY_OPCODE,
79   LoadV2 = FIRST_MEMORY_OPCODE,
80   LoadV4,
81   LoadV8,
82   LDUV2, // LDU.v2
83   LDUV4, // LDU.v4
84   StoreV2,
85   StoreV4,
86   StoreV8,
87   LoadParam,
88   LoadParamV2,
89   LoadParamV4,
90   StoreParam,
91   StoreParamV2,
92   StoreParamV4,
93   LAST_MEMORY_OPCODE = StoreParamV4,
94 };
95 }
96 
97 class NVPTXSubtarget;
98 
//===--------------------------------------------------------------------===//
// TargetLowering Implementation
//===--------------------------------------------------------------------===//
102 class NVPTXTargetLowering : public TargetLowering {
103 public:
104   explicit NVPTXTargetLowering(const NVPTXTargetMachine &TM,
105                                const NVPTXSubtarget &STI);
106   SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
107 
108   const char *getTargetNodeName(unsigned Opcode) const override;
109 
110   bool getTgtMemIntrinsic(IntrinsicInfo &Info, const CallInst &I,
111                           MachineFunction &MF,
112                           unsigned Intrinsic) const override;
113 
114   Align getFunctionArgumentAlignment(const Function *F, Type *Ty, unsigned Idx,
115                                      const DataLayout &DL) const;
116 
117   /// getFunctionParamOptimizedAlign - since function arguments are passed via
118   /// .param space, we may want to increase their alignment in a way that
119   /// ensures that we can effectively vectorize their loads & stores. We can
120   /// increase alignment only if the function has internal or has private
121   /// linkage as for other linkage types callers may already rely on default
122   /// alignment. To allow using 128-bit vectorized loads/stores, this function
123   /// ensures that alignment is 16 or greater.
124   Align getFunctionParamOptimizedAlign(const Function *F, Type *ArgTy,
125                                        const DataLayout &DL) const;
126 
127   /// Helper for computing alignment of a device function byval parameter.
128   Align getFunctionByValParamAlign(const Function *F, Type *ArgTy,
129                                    Align InitialAlign,
130                                    const DataLayout &DL) const;
131 
132   // Helper for getting a function parameter name. Name is composed from
133   // its index and the function name. Negative index corresponds to special
134   // parameter (unsized array) used for passing variable arguments.
135   std::string getParamName(const Function *F, int Idx) const;
136 
137   /// isLegalAddressingMode - Return true if the addressing mode represented
138   /// by AM is legal for this target, for a load/store of the specified type
139   /// Used to guide target specific optimizations, like loop strength
140   /// reduction (LoopStrengthReduce.cpp) and memory optimization for
141   /// address mode (CodeGenPrepare.cpp)
142   bool isLegalAddressingMode(const DataLayout &DL, const AddrMode &AM, Type *Ty,
143                              unsigned AS,
144                              Instruction *I = nullptr) const override;
145 
isTruncateFree(Type * SrcTy,Type * DstTy)146   bool isTruncateFree(Type *SrcTy, Type *DstTy) const override {
147     // Truncating 64-bit to 32-bit is free in SASS.
148     if (!SrcTy->isIntegerTy() || !DstTy->isIntegerTy())
149       return false;
150     return SrcTy->getPrimitiveSizeInBits() == 64 &&
151            DstTy->getPrimitiveSizeInBits() == 32;
152   }
153 
getSetCCResultType(const DataLayout & DL,LLVMContext & Ctx,EVT VT)154   EVT getSetCCResultType(const DataLayout &DL, LLVMContext &Ctx,
155                          EVT VT) const override {
156     if (VT.isVector())
157       return EVT::getVectorVT(Ctx, MVT::i1, VT.getVectorNumElements());
158     return MVT::i1;
159   }
160 
161   ConstraintType getConstraintType(StringRef Constraint) const override;
162   std::pair<unsigned, const TargetRegisterClass *>
163   getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
164                                StringRef Constraint, MVT VT) const override;
165 
166   SDValue LowerFormalArguments(SDValue Chain, CallingConv::ID CallConv,
167                                bool isVarArg,
168                                const SmallVectorImpl<ISD::InputArg> &Ins,
169                                const SDLoc &dl, SelectionDAG &DAG,
170                                SmallVectorImpl<SDValue> &InVals) const override;
171 
172   SDValue LowerCall(CallLoweringInfo &CLI,
173                     SmallVectorImpl<SDValue> &InVals) const override;
174 
175   SDValue LowerDYNAMIC_STACKALLOC(SDValue Op, SelectionDAG &DAG) const;
176   SDValue LowerSTACKSAVE(SDValue Op, SelectionDAG &DAG) const;
177   SDValue LowerSTACKRESTORE(SDValue Op, SelectionDAG &DAG) const;
178 
179   std::string getPrototype(const DataLayout &DL, Type *, const ArgListTy &,
180                            const SmallVectorImpl<ISD::OutputArg> &,
181                            std::optional<unsigned> FirstVAArg,
182                            const CallBase &CB, unsigned UniqueCallSite) const;
183 
184   SDValue LowerReturn(SDValue Chain, CallingConv::ID CallConv, bool isVarArg,
185                       const SmallVectorImpl<ISD::OutputArg> &Outs,
186                       const SmallVectorImpl<SDValue> &OutVals, const SDLoc &dl,
187                       SelectionDAG &DAG) const override;
188 
189   void LowerAsmOperandForConstraint(SDValue Op, StringRef Constraint,
190                                     std::vector<SDValue> &Ops,
191                                     SelectionDAG &DAG) const override;
192 
193   const NVPTXTargetMachine *nvTM;
194 
195   // PTX always uses 32-bit shift amounts
getScalarShiftAmountTy(const DataLayout &,EVT)196   MVT getScalarShiftAmountTy(const DataLayout &, EVT) const override {
197     return MVT::i32;
198   }
199 
200   TargetLoweringBase::LegalizeTypeAction
201   getPreferredVectorAction(MVT VT) const override;
202 
203   // Get the degree of precision we want from 32-bit floating point division
204   // operations.
205   NVPTX::DivPrecisionLevel getDivF32Level(const MachineFunction &MF,
206                                           const SDNode &N) const;
207 
208   // Get whether we should use a precise or approximate 32-bit floating point
209   // sqrt instruction.
210   bool usePrecSqrtF32(const MachineFunction &MF,
211                       const SDNode *N = nullptr) const;
212 
213   // Get whether we should use instructions that flush floating-point denormals
214   // to sign-preserving zero.
215   bool useF32FTZ(const MachineFunction &MF) const;
216 
217   SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
218                           int &ExtraSteps, bool &UseOneConst,
219                           bool Reciprocal) const override;
220 
combineRepeatedFPDivisors()221   unsigned combineRepeatedFPDivisors() const override { return 2; }
222 
223   bool allowFMA(MachineFunction &MF, CodeGenOptLevel OptLevel) const;
224   bool allowUnsafeFPMath(const MachineFunction &MF) const;
225 
isFMAFasterThanFMulAndFAdd(const MachineFunction & MF,EVT)226   bool isFMAFasterThanFMulAndFAdd(const MachineFunction &MF,
227                                   EVT) const override {
228     return true;
229   }
230 
231   // The default is the same as pointer type, but brx.idx only accepts i32
getJumpTableRegTy(const DataLayout &)232   MVT getJumpTableRegTy(const DataLayout &) const override { return MVT::i32; }
233 
234   unsigned getJumpTableEncoding() const override;
235 
enableAggressiveFMAFusion(EVT VT)236   bool enableAggressiveFMAFusion(EVT VT) const override { return true; }
237 
238   // The default is to transform llvm.ctlz(x, false) (where false indicates that
239   // x == 0 is not undefined behavior) into a branch that checks whether x is 0
240   // and avoids calling ctlz in that case.  We have a dedicated ctlz
241   // instruction, so we say that ctlz is cheap to speculate.
isCheapToSpeculateCtlz(Type * Ty)242   bool isCheapToSpeculateCtlz(Type *Ty) const override { return true; }
243 
shouldCastAtomicLoadInIR(LoadInst * LI)244   AtomicExpansionKind shouldCastAtomicLoadInIR(LoadInst *LI) const override {
245     return AtomicExpansionKind::None;
246   }
247 
shouldCastAtomicStoreInIR(StoreInst * SI)248   AtomicExpansionKind shouldCastAtomicStoreInIR(StoreInst *SI) const override {
249     return AtomicExpansionKind::None;
250   }
251 
252   AtomicExpansionKind
253   shouldExpandAtomicRMWInIR(AtomicRMWInst *AI) const override;
254 
aggressivelyPreferBuildVectorSources(EVT VecVT)255   bool aggressivelyPreferBuildVectorSources(EVT VecVT) const override {
256     // There's rarely any point of packing something into a vector type if we
257     // already have the source data.
258     return true;
259   }
260 
261   bool shouldInsertFencesForAtomic(const Instruction *) const override;
262 
263   AtomicOrdering
264   atomicOperationOrderAfterFenceSplit(const Instruction *I) const override;
265 
266   Instruction *emitLeadingFence(IRBuilderBase &Builder, Instruction *Inst,
267                                 AtomicOrdering Ord) const override;
268   Instruction *emitTrailingFence(IRBuilderBase &Builder, Instruction *Inst,
269                                  AtomicOrdering Ord) const override;
270 
271   unsigned getPreferredFPToIntOpcode(unsigned Op, EVT FromVT,
272                                      EVT ToVT) const override;
273 
274   void computeKnownBitsForTargetNode(const SDValue Op, KnownBits &Known,
275                                      const APInt &DemandedElts,
276                                      const SelectionDAG &DAG,
277                                      unsigned Depth = 0) const override;
278 
279 private:
280   const NVPTXSubtarget &STI; // cache the subtarget here
281   mutable unsigned GlobalUniqueCallSite;
282 
283   SDValue getParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
284   SDValue getCallParamSymbol(SelectionDAG &DAG, int I, EVT T) const;
285   SDValue LowerADDRSPACECAST(SDValue Op, SelectionDAG &DAG) const;
286   SDValue LowerBITCAST(SDValue Op, SelectionDAG &DAG) const;
287 
288   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
289   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
290   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
291   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
292   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
293 
294   SDValue LowerFCOPYSIGN(SDValue Op, SelectionDAG &DAG) const;
295 
296   SDValue LowerFROUND(SDValue Op, SelectionDAG &DAG) const;
297   SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
298   SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
299 
300   SDValue PromoteBinOpIfF32FTZ(SDValue Op, SelectionDAG &DAG) const;
301 
302   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
303   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
304 
305   SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
306   SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
307 
308   SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
309   SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
310 
311   SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
312   SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
313   SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;
314 
315   SDValue LowerShiftRightParts(SDValue Op, SelectionDAG &DAG) const;
316   SDValue LowerShiftLeftParts(SDValue Op, SelectionDAG &DAG) const;
317 
318   SDValue LowerBR_JT(SDValue Op, SelectionDAG &DAG) const;
319 
320   SDValue LowerVAARG(SDValue Op, SelectionDAG &DAG) const;
321   SDValue LowerVASTART(SDValue Op, SelectionDAG &DAG) const;
322 
323   SDValue LowerCopyToReg_128(SDValue Op, SelectionDAG &DAG) const;
324   unsigned getNumRegisters(LLVMContext &Context, EVT VT,
325                            std::optional<MVT> RegisterVT) const override;
326   bool
327   splitValueIntoRegisterParts(SelectionDAG &DAG, const SDLoc &DL, SDValue Val,
328                               SDValue *Parts, unsigned NumParts, MVT PartVT,
329                               std::optional<CallingConv::ID> CC) const override;
330 
331   void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
332                           SelectionDAG &DAG) const override;
333   SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
334 
335   Align getArgumentAlignment(const CallBase *CB, Type *Ty, unsigned Idx,
336                              const DataLayout &DL) const;
337 };
338 
339 } // namespace llvm
340 
341 #endif
342