xref: /freebsd/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===-- NVPTXISelDAGToDAG.cpp - A dag to dag inst selector for NVPTX ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines an instruction selector for the NVPTX target.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "NVPTXISelDAGToDAG.h"
14 #include "MCTargetDesc/NVPTXBaseInfo.h"
15 #include "NVPTXUtilities.h"
16 #include "llvm/Analysis/ValueTracking.h"
17 #include "llvm/CodeGen/ISDOpcodes.h"
18 #include "llvm/IR/GlobalValue.h"
19 #include "llvm/IR/Instructions.h"
20 #include "llvm/IR/IntrinsicsNVPTX.h"
21 #include "llvm/Support/AtomicOrdering.h"
22 #include "llvm/Support/CommandLine.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/ErrorHandling.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "llvm/Target/TargetIntrinsicInfo.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "nvptx-isel"
31 #define PASS_NAME "NVPTX DAG->DAG Pattern Instruction Selection"
32 
33 static cl::opt<bool>
34     EnableRsqrtOpt("nvptx-rsqrt-approx-opt", cl::init(true), cl::Hidden,
35                    cl::desc("Enable reciprocal sqrt optimization"));
36 
37 /// createNVPTXISelDag - This pass converts a legalized DAG into a
38 /// NVPTX-specific DAG, ready for instruction scheduling.
createNVPTXISelDag(NVPTXTargetMachine & TM,llvm::CodeGenOptLevel OptLevel)39 FunctionPass *llvm::createNVPTXISelDag(NVPTXTargetMachine &TM,
40                                        llvm::CodeGenOptLevel OptLevel) {
41   return new NVPTXDAGToDAGISelLegacy(TM, OptLevel);
42 }
43 
NVPTXDAGToDAGISelLegacy(NVPTXTargetMachine & tm,CodeGenOptLevel OptLevel)44 NVPTXDAGToDAGISelLegacy::NVPTXDAGToDAGISelLegacy(NVPTXTargetMachine &tm,
45                                                  CodeGenOptLevel OptLevel)
46     : SelectionDAGISelLegacy(
47           ID, std::make_unique<NVPTXDAGToDAGISel>(tm, OptLevel)) {}
48 
49 char NVPTXDAGToDAGISelLegacy::ID = 0;
50 
INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy,DEBUG_TYPE,PASS_NAME,false,false)51 INITIALIZE_PASS(NVPTXDAGToDAGISelLegacy, DEBUG_TYPE, PASS_NAME, false, false)
52 
53 NVPTXDAGToDAGISel::NVPTXDAGToDAGISel(NVPTXTargetMachine &tm,
54                                      CodeGenOptLevel OptLevel)
55     : SelectionDAGISel(tm, OptLevel), TM(tm) {
56   doMulWide = (OptLevel > CodeGenOptLevel::None);
57 }
58 
runOnMachineFunction(MachineFunction & MF)59 bool NVPTXDAGToDAGISel::runOnMachineFunction(MachineFunction &MF) {
60   Subtarget = &MF.getSubtarget<NVPTXSubtarget>();
61   return SelectionDAGISel::runOnMachineFunction(MF);
62 }
63 
getDivF32Level() const64 int NVPTXDAGToDAGISel::getDivF32Level() const {
65   return Subtarget->getTargetLowering()->getDivF32Level();
66 }
67 
usePrecSqrtF32() const68 bool NVPTXDAGToDAGISel::usePrecSqrtF32() const {
69   return Subtarget->getTargetLowering()->usePrecSqrtF32();
70 }
71 
useF32FTZ() const72 bool NVPTXDAGToDAGISel::useF32FTZ() const {
73   return Subtarget->getTargetLowering()->useF32FTZ(*MF);
74 }
75 
allowFMA() const76 bool NVPTXDAGToDAGISel::allowFMA() const {
77   const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
78   return TL->allowFMA(*MF, OptLevel);
79 }
80 
allowUnsafeFPMath() const81 bool NVPTXDAGToDAGISel::allowUnsafeFPMath() const {
82   const NVPTXTargetLowering *TL = Subtarget->getTargetLowering();
83   return TL->allowUnsafeFPMath(*MF);
84 }
85 
doRsqrtOpt() const86 bool NVPTXDAGToDAGISel::doRsqrtOpt() const { return EnableRsqrtOpt; }
87 
88 /// Select - Select instructions not customized! Used for
89 /// expanded, promoted and normal instructions.
Select(SDNode * N)90 void NVPTXDAGToDAGISel::Select(SDNode *N) {
91 
92   if (N->isMachineOpcode()) {
93     N->setNodeId(-1);
94     return; // Already selected.
95   }
96 
97   switch (N->getOpcode()) {
98   case ISD::LOAD:
99   case ISD::ATOMIC_LOAD:
100     if (tryLoad(N))
101       return;
102     break;
103   case ISD::STORE:
104   case ISD::ATOMIC_STORE:
105     if (tryStore(N))
106       return;
107     break;
108   case ISD::EXTRACT_VECTOR_ELT:
109     if (tryEXTRACT_VECTOR_ELEMENT(N))
110       return;
111     break;
112   case NVPTXISD::SETP_F16X2:
113     SelectSETP_F16X2(N);
114     return;
115   case NVPTXISD::SETP_BF16X2:
116     SelectSETP_BF16X2(N);
117     return;
118   case NVPTXISD::LoadV2:
119   case NVPTXISD::LoadV4:
120     if (tryLoadVector(N))
121       return;
122     break;
123   case NVPTXISD::LDGV2:
124   case NVPTXISD::LDGV4:
125   case NVPTXISD::LDUV2:
126   case NVPTXISD::LDUV4:
127     if (tryLDGLDU(N))
128       return;
129     break;
130   case NVPTXISD::StoreV2:
131   case NVPTXISD::StoreV4:
132     if (tryStoreVector(N))
133       return;
134     break;
135   case NVPTXISD::LoadParam:
136   case NVPTXISD::LoadParamV2:
137   case NVPTXISD::LoadParamV4:
138     if (tryLoadParam(N))
139       return;
140     break;
141   case NVPTXISD::StoreRetval:
142   case NVPTXISD::StoreRetvalV2:
143   case NVPTXISD::StoreRetvalV4:
144     if (tryStoreRetval(N))
145       return;
146     break;
147   case NVPTXISD::StoreParam:
148   case NVPTXISD::StoreParamV2:
149   case NVPTXISD::StoreParamV4:
150   case NVPTXISD::StoreParamS32:
151   case NVPTXISD::StoreParamU32:
152     if (tryStoreParam(N))
153       return;
154     break;
155   case ISD::INTRINSIC_WO_CHAIN:
156     if (tryIntrinsicNoChain(N))
157       return;
158     break;
159   case ISD::INTRINSIC_W_CHAIN:
160     if (tryIntrinsicChain(N))
161       return;
162     break;
163   case NVPTXISD::Tex1DFloatS32:
164   case NVPTXISD::Tex1DFloatFloat:
165   case NVPTXISD::Tex1DFloatFloatLevel:
166   case NVPTXISD::Tex1DFloatFloatGrad:
167   case NVPTXISD::Tex1DS32S32:
168   case NVPTXISD::Tex1DS32Float:
169   case NVPTXISD::Tex1DS32FloatLevel:
170   case NVPTXISD::Tex1DS32FloatGrad:
171   case NVPTXISD::Tex1DU32S32:
172   case NVPTXISD::Tex1DU32Float:
173   case NVPTXISD::Tex1DU32FloatLevel:
174   case NVPTXISD::Tex1DU32FloatGrad:
175   case NVPTXISD::Tex1DArrayFloatS32:
176   case NVPTXISD::Tex1DArrayFloatFloat:
177   case NVPTXISD::Tex1DArrayFloatFloatLevel:
178   case NVPTXISD::Tex1DArrayFloatFloatGrad:
179   case NVPTXISD::Tex1DArrayS32S32:
180   case NVPTXISD::Tex1DArrayS32Float:
181   case NVPTXISD::Tex1DArrayS32FloatLevel:
182   case NVPTXISD::Tex1DArrayS32FloatGrad:
183   case NVPTXISD::Tex1DArrayU32S32:
184   case NVPTXISD::Tex1DArrayU32Float:
185   case NVPTXISD::Tex1DArrayU32FloatLevel:
186   case NVPTXISD::Tex1DArrayU32FloatGrad:
187   case NVPTXISD::Tex2DFloatS32:
188   case NVPTXISD::Tex2DFloatFloat:
189   case NVPTXISD::Tex2DFloatFloatLevel:
190   case NVPTXISD::Tex2DFloatFloatGrad:
191   case NVPTXISD::Tex2DS32S32:
192   case NVPTXISD::Tex2DS32Float:
193   case NVPTXISD::Tex2DS32FloatLevel:
194   case NVPTXISD::Tex2DS32FloatGrad:
195   case NVPTXISD::Tex2DU32S32:
196   case NVPTXISD::Tex2DU32Float:
197   case NVPTXISD::Tex2DU32FloatLevel:
198   case NVPTXISD::Tex2DU32FloatGrad:
199   case NVPTXISD::Tex2DArrayFloatS32:
200   case NVPTXISD::Tex2DArrayFloatFloat:
201   case NVPTXISD::Tex2DArrayFloatFloatLevel:
202   case NVPTXISD::Tex2DArrayFloatFloatGrad:
203   case NVPTXISD::Tex2DArrayS32S32:
204   case NVPTXISD::Tex2DArrayS32Float:
205   case NVPTXISD::Tex2DArrayS32FloatLevel:
206   case NVPTXISD::Tex2DArrayS32FloatGrad:
207   case NVPTXISD::Tex2DArrayU32S32:
208   case NVPTXISD::Tex2DArrayU32Float:
209   case NVPTXISD::Tex2DArrayU32FloatLevel:
210   case NVPTXISD::Tex2DArrayU32FloatGrad:
211   case NVPTXISD::Tex3DFloatS32:
212   case NVPTXISD::Tex3DFloatFloat:
213   case NVPTXISD::Tex3DFloatFloatLevel:
214   case NVPTXISD::Tex3DFloatFloatGrad:
215   case NVPTXISD::Tex3DS32S32:
216   case NVPTXISD::Tex3DS32Float:
217   case NVPTXISD::Tex3DS32FloatLevel:
218   case NVPTXISD::Tex3DS32FloatGrad:
219   case NVPTXISD::Tex3DU32S32:
220   case NVPTXISD::Tex3DU32Float:
221   case NVPTXISD::Tex3DU32FloatLevel:
222   case NVPTXISD::Tex3DU32FloatGrad:
223   case NVPTXISD::TexCubeFloatFloat:
224   case NVPTXISD::TexCubeFloatFloatLevel:
225   case NVPTXISD::TexCubeS32Float:
226   case NVPTXISD::TexCubeS32FloatLevel:
227   case NVPTXISD::TexCubeU32Float:
228   case NVPTXISD::TexCubeU32FloatLevel:
229   case NVPTXISD::TexCubeArrayFloatFloat:
230   case NVPTXISD::TexCubeArrayFloatFloatLevel:
231   case NVPTXISD::TexCubeArrayS32Float:
232   case NVPTXISD::TexCubeArrayS32FloatLevel:
233   case NVPTXISD::TexCubeArrayU32Float:
234   case NVPTXISD::TexCubeArrayU32FloatLevel:
235   case NVPTXISD::Tld4R2DFloatFloat:
236   case NVPTXISD::Tld4G2DFloatFloat:
237   case NVPTXISD::Tld4B2DFloatFloat:
238   case NVPTXISD::Tld4A2DFloatFloat:
239   case NVPTXISD::Tld4R2DS64Float:
240   case NVPTXISD::Tld4G2DS64Float:
241   case NVPTXISD::Tld4B2DS64Float:
242   case NVPTXISD::Tld4A2DS64Float:
243   case NVPTXISD::Tld4R2DU64Float:
244   case NVPTXISD::Tld4G2DU64Float:
245   case NVPTXISD::Tld4B2DU64Float:
246   case NVPTXISD::Tld4A2DU64Float:
247   case NVPTXISD::TexUnified1DFloatS32:
248   case NVPTXISD::TexUnified1DFloatFloat:
249   case NVPTXISD::TexUnified1DFloatFloatLevel:
250   case NVPTXISD::TexUnified1DFloatFloatGrad:
251   case NVPTXISD::TexUnified1DS32S32:
252   case NVPTXISD::TexUnified1DS32Float:
253   case NVPTXISD::TexUnified1DS32FloatLevel:
254   case NVPTXISD::TexUnified1DS32FloatGrad:
255   case NVPTXISD::TexUnified1DU32S32:
256   case NVPTXISD::TexUnified1DU32Float:
257   case NVPTXISD::TexUnified1DU32FloatLevel:
258   case NVPTXISD::TexUnified1DU32FloatGrad:
259   case NVPTXISD::TexUnified1DArrayFloatS32:
260   case NVPTXISD::TexUnified1DArrayFloatFloat:
261   case NVPTXISD::TexUnified1DArrayFloatFloatLevel:
262   case NVPTXISD::TexUnified1DArrayFloatFloatGrad:
263   case NVPTXISD::TexUnified1DArrayS32S32:
264   case NVPTXISD::TexUnified1DArrayS32Float:
265   case NVPTXISD::TexUnified1DArrayS32FloatLevel:
266   case NVPTXISD::TexUnified1DArrayS32FloatGrad:
267   case NVPTXISD::TexUnified1DArrayU32S32:
268   case NVPTXISD::TexUnified1DArrayU32Float:
269   case NVPTXISD::TexUnified1DArrayU32FloatLevel:
270   case NVPTXISD::TexUnified1DArrayU32FloatGrad:
271   case NVPTXISD::TexUnified2DFloatS32:
272   case NVPTXISD::TexUnified2DFloatFloat:
273   case NVPTXISD::TexUnified2DFloatFloatLevel:
274   case NVPTXISD::TexUnified2DFloatFloatGrad:
275   case NVPTXISD::TexUnified2DS32S32:
276   case NVPTXISD::TexUnified2DS32Float:
277   case NVPTXISD::TexUnified2DS32FloatLevel:
278   case NVPTXISD::TexUnified2DS32FloatGrad:
279   case NVPTXISD::TexUnified2DU32S32:
280   case NVPTXISD::TexUnified2DU32Float:
281   case NVPTXISD::TexUnified2DU32FloatLevel:
282   case NVPTXISD::TexUnified2DU32FloatGrad:
283   case NVPTXISD::TexUnified2DArrayFloatS32:
284   case NVPTXISD::TexUnified2DArrayFloatFloat:
285   case NVPTXISD::TexUnified2DArrayFloatFloatLevel:
286   case NVPTXISD::TexUnified2DArrayFloatFloatGrad:
287   case NVPTXISD::TexUnified2DArrayS32S32:
288   case NVPTXISD::TexUnified2DArrayS32Float:
289   case NVPTXISD::TexUnified2DArrayS32FloatLevel:
290   case NVPTXISD::TexUnified2DArrayS32FloatGrad:
291   case NVPTXISD::TexUnified2DArrayU32S32:
292   case NVPTXISD::TexUnified2DArrayU32Float:
293   case NVPTXISD::TexUnified2DArrayU32FloatLevel:
294   case NVPTXISD::TexUnified2DArrayU32FloatGrad:
295   case NVPTXISD::TexUnified3DFloatS32:
296   case NVPTXISD::TexUnified3DFloatFloat:
297   case NVPTXISD::TexUnified3DFloatFloatLevel:
298   case NVPTXISD::TexUnified3DFloatFloatGrad:
299   case NVPTXISD::TexUnified3DS32S32:
300   case NVPTXISD::TexUnified3DS32Float:
301   case NVPTXISD::TexUnified3DS32FloatLevel:
302   case NVPTXISD::TexUnified3DS32FloatGrad:
303   case NVPTXISD::TexUnified3DU32S32:
304   case NVPTXISD::TexUnified3DU32Float:
305   case NVPTXISD::TexUnified3DU32FloatLevel:
306   case NVPTXISD::TexUnified3DU32FloatGrad:
307   case NVPTXISD::TexUnifiedCubeFloatFloat:
308   case NVPTXISD::TexUnifiedCubeFloatFloatLevel:
309   case NVPTXISD::TexUnifiedCubeS32Float:
310   case NVPTXISD::TexUnifiedCubeS32FloatLevel:
311   case NVPTXISD::TexUnifiedCubeU32Float:
312   case NVPTXISD::TexUnifiedCubeU32FloatLevel:
313   case NVPTXISD::TexUnifiedCubeArrayFloatFloat:
314   case NVPTXISD::TexUnifiedCubeArrayFloatFloatLevel:
315   case NVPTXISD::TexUnifiedCubeArrayS32Float:
316   case NVPTXISD::TexUnifiedCubeArrayS32FloatLevel:
317   case NVPTXISD::TexUnifiedCubeArrayU32Float:
318   case NVPTXISD::TexUnifiedCubeArrayU32FloatLevel:
319   case NVPTXISD::TexUnifiedCubeFloatFloatGrad:
320   case NVPTXISD::TexUnifiedCubeS32FloatGrad:
321   case NVPTXISD::TexUnifiedCubeU32FloatGrad:
322   case NVPTXISD::TexUnifiedCubeArrayFloatFloatGrad:
323   case NVPTXISD::TexUnifiedCubeArrayS32FloatGrad:
324   case NVPTXISD::TexUnifiedCubeArrayU32FloatGrad:
325   case NVPTXISD::Tld4UnifiedR2DFloatFloat:
326   case NVPTXISD::Tld4UnifiedG2DFloatFloat:
327   case NVPTXISD::Tld4UnifiedB2DFloatFloat:
328   case NVPTXISD::Tld4UnifiedA2DFloatFloat:
329   case NVPTXISD::Tld4UnifiedR2DS64Float:
330   case NVPTXISD::Tld4UnifiedG2DS64Float:
331   case NVPTXISD::Tld4UnifiedB2DS64Float:
332   case NVPTXISD::Tld4UnifiedA2DS64Float:
333   case NVPTXISD::Tld4UnifiedR2DU64Float:
334   case NVPTXISD::Tld4UnifiedG2DU64Float:
335   case NVPTXISD::Tld4UnifiedB2DU64Float:
336   case NVPTXISD::Tld4UnifiedA2DU64Float:
337     if (tryTextureIntrinsic(N))
338       return;
339     break;
340   case NVPTXISD::Suld1DI8Clamp:
341   case NVPTXISD::Suld1DI16Clamp:
342   case NVPTXISD::Suld1DI32Clamp:
343   case NVPTXISD::Suld1DI64Clamp:
344   case NVPTXISD::Suld1DV2I8Clamp:
345   case NVPTXISD::Suld1DV2I16Clamp:
346   case NVPTXISD::Suld1DV2I32Clamp:
347   case NVPTXISD::Suld1DV2I64Clamp:
348   case NVPTXISD::Suld1DV4I8Clamp:
349   case NVPTXISD::Suld1DV4I16Clamp:
350   case NVPTXISD::Suld1DV4I32Clamp:
351   case NVPTXISD::Suld1DArrayI8Clamp:
352   case NVPTXISD::Suld1DArrayI16Clamp:
353   case NVPTXISD::Suld1DArrayI32Clamp:
354   case NVPTXISD::Suld1DArrayI64Clamp:
355   case NVPTXISD::Suld1DArrayV2I8Clamp:
356   case NVPTXISD::Suld1DArrayV2I16Clamp:
357   case NVPTXISD::Suld1DArrayV2I32Clamp:
358   case NVPTXISD::Suld1DArrayV2I64Clamp:
359   case NVPTXISD::Suld1DArrayV4I8Clamp:
360   case NVPTXISD::Suld1DArrayV4I16Clamp:
361   case NVPTXISD::Suld1DArrayV4I32Clamp:
362   case NVPTXISD::Suld2DI8Clamp:
363   case NVPTXISD::Suld2DI16Clamp:
364   case NVPTXISD::Suld2DI32Clamp:
365   case NVPTXISD::Suld2DI64Clamp:
366   case NVPTXISD::Suld2DV2I8Clamp:
367   case NVPTXISD::Suld2DV2I16Clamp:
368   case NVPTXISD::Suld2DV2I32Clamp:
369   case NVPTXISD::Suld2DV2I64Clamp:
370   case NVPTXISD::Suld2DV4I8Clamp:
371   case NVPTXISD::Suld2DV4I16Clamp:
372   case NVPTXISD::Suld2DV4I32Clamp:
373   case NVPTXISD::Suld2DArrayI8Clamp:
374   case NVPTXISD::Suld2DArrayI16Clamp:
375   case NVPTXISD::Suld2DArrayI32Clamp:
376   case NVPTXISD::Suld2DArrayI64Clamp:
377   case NVPTXISD::Suld2DArrayV2I8Clamp:
378   case NVPTXISD::Suld2DArrayV2I16Clamp:
379   case NVPTXISD::Suld2DArrayV2I32Clamp:
380   case NVPTXISD::Suld2DArrayV2I64Clamp:
381   case NVPTXISD::Suld2DArrayV4I8Clamp:
382   case NVPTXISD::Suld2DArrayV4I16Clamp:
383   case NVPTXISD::Suld2DArrayV4I32Clamp:
384   case NVPTXISD::Suld3DI8Clamp:
385   case NVPTXISD::Suld3DI16Clamp:
386   case NVPTXISD::Suld3DI32Clamp:
387   case NVPTXISD::Suld3DI64Clamp:
388   case NVPTXISD::Suld3DV2I8Clamp:
389   case NVPTXISD::Suld3DV2I16Clamp:
390   case NVPTXISD::Suld3DV2I32Clamp:
391   case NVPTXISD::Suld3DV2I64Clamp:
392   case NVPTXISD::Suld3DV4I8Clamp:
393   case NVPTXISD::Suld3DV4I16Clamp:
394   case NVPTXISD::Suld3DV4I32Clamp:
395   case NVPTXISD::Suld1DI8Trap:
396   case NVPTXISD::Suld1DI16Trap:
397   case NVPTXISD::Suld1DI32Trap:
398   case NVPTXISD::Suld1DI64Trap:
399   case NVPTXISD::Suld1DV2I8Trap:
400   case NVPTXISD::Suld1DV2I16Trap:
401   case NVPTXISD::Suld1DV2I32Trap:
402   case NVPTXISD::Suld1DV2I64Trap:
403   case NVPTXISD::Suld1DV4I8Trap:
404   case NVPTXISD::Suld1DV4I16Trap:
405   case NVPTXISD::Suld1DV4I32Trap:
406   case NVPTXISD::Suld1DArrayI8Trap:
407   case NVPTXISD::Suld1DArrayI16Trap:
408   case NVPTXISD::Suld1DArrayI32Trap:
409   case NVPTXISD::Suld1DArrayI64Trap:
410   case NVPTXISD::Suld1DArrayV2I8Trap:
411   case NVPTXISD::Suld1DArrayV2I16Trap:
412   case NVPTXISD::Suld1DArrayV2I32Trap:
413   case NVPTXISD::Suld1DArrayV2I64Trap:
414   case NVPTXISD::Suld1DArrayV4I8Trap:
415   case NVPTXISD::Suld1DArrayV4I16Trap:
416   case NVPTXISD::Suld1DArrayV4I32Trap:
417   case NVPTXISD::Suld2DI8Trap:
418   case NVPTXISD::Suld2DI16Trap:
419   case NVPTXISD::Suld2DI32Trap:
420   case NVPTXISD::Suld2DI64Trap:
421   case NVPTXISD::Suld2DV2I8Trap:
422   case NVPTXISD::Suld2DV2I16Trap:
423   case NVPTXISD::Suld2DV2I32Trap:
424   case NVPTXISD::Suld2DV2I64Trap:
425   case NVPTXISD::Suld2DV4I8Trap:
426   case NVPTXISD::Suld2DV4I16Trap:
427   case NVPTXISD::Suld2DV4I32Trap:
428   case NVPTXISD::Suld2DArrayI8Trap:
429   case NVPTXISD::Suld2DArrayI16Trap:
430   case NVPTXISD::Suld2DArrayI32Trap:
431   case NVPTXISD::Suld2DArrayI64Trap:
432   case NVPTXISD::Suld2DArrayV2I8Trap:
433   case NVPTXISD::Suld2DArrayV2I16Trap:
434   case NVPTXISD::Suld2DArrayV2I32Trap:
435   case NVPTXISD::Suld2DArrayV2I64Trap:
436   case NVPTXISD::Suld2DArrayV4I8Trap:
437   case NVPTXISD::Suld2DArrayV4I16Trap:
438   case NVPTXISD::Suld2DArrayV4I32Trap:
439   case NVPTXISD::Suld3DI8Trap:
440   case NVPTXISD::Suld3DI16Trap:
441   case NVPTXISD::Suld3DI32Trap:
442   case NVPTXISD::Suld3DI64Trap:
443   case NVPTXISD::Suld3DV2I8Trap:
444   case NVPTXISD::Suld3DV2I16Trap:
445   case NVPTXISD::Suld3DV2I32Trap:
446   case NVPTXISD::Suld3DV2I64Trap:
447   case NVPTXISD::Suld3DV4I8Trap:
448   case NVPTXISD::Suld3DV4I16Trap:
449   case NVPTXISD::Suld3DV4I32Trap:
450   case NVPTXISD::Suld1DI8Zero:
451   case NVPTXISD::Suld1DI16Zero:
452   case NVPTXISD::Suld1DI32Zero:
453   case NVPTXISD::Suld1DI64Zero:
454   case NVPTXISD::Suld1DV2I8Zero:
455   case NVPTXISD::Suld1DV2I16Zero:
456   case NVPTXISD::Suld1DV2I32Zero:
457   case NVPTXISD::Suld1DV2I64Zero:
458   case NVPTXISD::Suld1DV4I8Zero:
459   case NVPTXISD::Suld1DV4I16Zero:
460   case NVPTXISD::Suld1DV4I32Zero:
461   case NVPTXISD::Suld1DArrayI8Zero:
462   case NVPTXISD::Suld1DArrayI16Zero:
463   case NVPTXISD::Suld1DArrayI32Zero:
464   case NVPTXISD::Suld1DArrayI64Zero:
465   case NVPTXISD::Suld1DArrayV2I8Zero:
466   case NVPTXISD::Suld1DArrayV2I16Zero:
467   case NVPTXISD::Suld1DArrayV2I32Zero:
468   case NVPTXISD::Suld1DArrayV2I64Zero:
469   case NVPTXISD::Suld1DArrayV4I8Zero:
470   case NVPTXISD::Suld1DArrayV4I16Zero:
471   case NVPTXISD::Suld1DArrayV4I32Zero:
472   case NVPTXISD::Suld2DI8Zero:
473   case NVPTXISD::Suld2DI16Zero:
474   case NVPTXISD::Suld2DI32Zero:
475   case NVPTXISD::Suld2DI64Zero:
476   case NVPTXISD::Suld2DV2I8Zero:
477   case NVPTXISD::Suld2DV2I16Zero:
478   case NVPTXISD::Suld2DV2I32Zero:
479   case NVPTXISD::Suld2DV2I64Zero:
480   case NVPTXISD::Suld2DV4I8Zero:
481   case NVPTXISD::Suld2DV4I16Zero:
482   case NVPTXISD::Suld2DV4I32Zero:
483   case NVPTXISD::Suld2DArrayI8Zero:
484   case NVPTXISD::Suld2DArrayI16Zero:
485   case NVPTXISD::Suld2DArrayI32Zero:
486   case NVPTXISD::Suld2DArrayI64Zero:
487   case NVPTXISD::Suld2DArrayV2I8Zero:
488   case NVPTXISD::Suld2DArrayV2I16Zero:
489   case NVPTXISD::Suld2DArrayV2I32Zero:
490   case NVPTXISD::Suld2DArrayV2I64Zero:
491   case NVPTXISD::Suld2DArrayV4I8Zero:
492   case NVPTXISD::Suld2DArrayV4I16Zero:
493   case NVPTXISD::Suld2DArrayV4I32Zero:
494   case NVPTXISD::Suld3DI8Zero:
495   case NVPTXISD::Suld3DI16Zero:
496   case NVPTXISD::Suld3DI32Zero:
497   case NVPTXISD::Suld3DI64Zero:
498   case NVPTXISD::Suld3DV2I8Zero:
499   case NVPTXISD::Suld3DV2I16Zero:
500   case NVPTXISD::Suld3DV2I32Zero:
501   case NVPTXISD::Suld3DV2I64Zero:
502   case NVPTXISD::Suld3DV4I8Zero:
503   case NVPTXISD::Suld3DV4I16Zero:
504   case NVPTXISD::Suld3DV4I32Zero:
505     if (trySurfaceIntrinsic(N))
506       return;
507     break;
508   case ISD::AND:
509   case ISD::SRA:
510   case ISD::SRL:
511     // Try to select BFE
512     if (tryBFE(N))
513       return;
514     break;
515   case ISD::ADDRSPACECAST:
516     SelectAddrSpaceCast(N);
517     return;
518   case ISD::ConstantFP:
519     if (tryConstantFP(N))
520       return;
521     break;
522   case ISD::CopyToReg: {
523     if (N->getOperand(1).getValueType() == MVT::i128) {
524       SelectV2I64toI128(N);
525       return;
526     }
527     break;
528   }
529   case ISD::CopyFromReg: {
530     if (N->getOperand(1).getValueType() == MVT::i128) {
531       SelectI128toV2I64(N);
532       return;
533     }
534     break;
535   }
536   default:
537     break;
538   }
539   SelectCode(N);
540 }
541 
tryIntrinsicChain(SDNode * N)542 bool NVPTXDAGToDAGISel::tryIntrinsicChain(SDNode *N) {
543   unsigned IID = N->getConstantOperandVal(1);
544   switch (IID) {
545   default:
546     return false;
547   case Intrinsic::nvvm_ldg_global_f:
548   case Intrinsic::nvvm_ldg_global_i:
549   case Intrinsic::nvvm_ldg_global_p:
550   case Intrinsic::nvvm_ldu_global_f:
551   case Intrinsic::nvvm_ldu_global_i:
552   case Intrinsic::nvvm_ldu_global_p:
553     return tryLDGLDU(N);
554   }
555 }
556 
557 // There's no way to specify FP16 and BF16 immediates in .(b)f16 ops, so we
558 // have to load them into an .(b)f16 register first.
tryConstantFP(SDNode * N)559 bool NVPTXDAGToDAGISel::tryConstantFP(SDNode *N) {
560   if (N->getValueType(0) != MVT::f16 && N->getValueType(0) != MVT::bf16)
561     return false;
562   SDValue Val = CurDAG->getTargetConstantFP(
563       cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), N->getValueType(0));
564   SDNode *LoadConstF16 = CurDAG->getMachineNode(
565       (N->getValueType(0) == MVT::f16 ? NVPTX::LOAD_CONST_F16
566                                       : NVPTX::LOAD_CONST_BF16),
567       SDLoc(N), N->getValueType(0), Val);
568   ReplaceNode(N, LoadConstF16);
569   return true;
570 }
571 
572 // Map ISD:CONDCODE value to appropriate CmpMode expected by
573 // NVPTXInstPrinter::printCmpMode()
getPTXCmpMode(const CondCodeSDNode & CondCode,bool FTZ)574 static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
575   using NVPTX::PTXCmpMode::CmpMode;
576   unsigned PTXCmpMode = [](ISD::CondCode CC) {
577     switch (CC) {
578     default:
579       llvm_unreachable("Unexpected condition code.");
580     case ISD::SETOEQ:
581       return CmpMode::EQ;
582     case ISD::SETOGT:
583       return CmpMode::GT;
584     case ISD::SETOGE:
585       return CmpMode::GE;
586     case ISD::SETOLT:
587       return CmpMode::LT;
588     case ISD::SETOLE:
589       return CmpMode::LE;
590     case ISD::SETONE:
591       return CmpMode::NE;
592     case ISD::SETO:
593       return CmpMode::NUM;
594     case ISD::SETUO:
595       return CmpMode::NotANumber;
596     case ISD::SETUEQ:
597       return CmpMode::EQU;
598     case ISD::SETUGT:
599       return CmpMode::GTU;
600     case ISD::SETUGE:
601       return CmpMode::GEU;
602     case ISD::SETULT:
603       return CmpMode::LTU;
604     case ISD::SETULE:
605       return CmpMode::LEU;
606     case ISD::SETUNE:
607       return CmpMode::NEU;
608     case ISD::SETEQ:
609       return CmpMode::EQ;
610     case ISD::SETGT:
611       return CmpMode::GT;
612     case ISD::SETGE:
613       return CmpMode::GE;
614     case ISD::SETLT:
615       return CmpMode::LT;
616     case ISD::SETLE:
617       return CmpMode::LE;
618     case ISD::SETNE:
619       return CmpMode::NE;
620     }
621   }(CondCode.get());
622 
623   if (FTZ)
624     PTXCmpMode |= NVPTX::PTXCmpMode::FTZ_FLAG;
625 
626   return PTXCmpMode;
627 }
628 
SelectSETP_F16X2(SDNode * N)629 bool NVPTXDAGToDAGISel::SelectSETP_F16X2(SDNode *N) {
630   unsigned PTXCmpMode =
631       getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
632   SDLoc DL(N);
633   SDNode *SetP = CurDAG->getMachineNode(
634       NVPTX::SETP_f16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
635       N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
636   ReplaceNode(N, SetP);
637   return true;
638 }
639 
SelectSETP_BF16X2(SDNode * N)640 bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
641   unsigned PTXCmpMode =
642       getPTXCmpMode(*cast<CondCodeSDNode>(N->getOperand(2)), useF32FTZ());
643   SDLoc DL(N);
644   SDNode *SetP = CurDAG->getMachineNode(
645       NVPTX::SETP_bf16x2rr, DL, MVT::i1, MVT::i1, N->getOperand(0),
646       N->getOperand(1), CurDAG->getTargetConstant(PTXCmpMode, DL, MVT::i32));
647   ReplaceNode(N, SetP);
648   return true;
649 }
650 
651 // Find all instances of extract_vector_elt that use this v2f16 vector
652 // and coalesce them into a scattering move instruction.
tryEXTRACT_VECTOR_ELEMENT(SDNode * N)653 bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
654   SDValue Vector = N->getOperand(0);
655 
656   // We only care about 16x2 as it's the only real vector type we
657   // need to deal with.
658   MVT VT = Vector.getSimpleValueType();
659   if (!Isv2x16VT(VT))
660     return false;
661   // Find and record all uses of this vector that extract element 0 or 1.
662   SmallVector<SDNode *, 4> E0, E1;
663   for (auto *U : Vector.getNode()->uses()) {
664     if (U->getOpcode() != ISD::EXTRACT_VECTOR_ELT)
665       continue;
666     if (U->getOperand(0) != Vector)
667       continue;
668     if (const ConstantSDNode *IdxConst =
669             dyn_cast<ConstantSDNode>(U->getOperand(1))) {
670       if (IdxConst->getZExtValue() == 0)
671         E0.push_back(U);
672       else if (IdxConst->getZExtValue() == 1)
673         E1.push_back(U);
674       else
675         llvm_unreachable("Invalid vector index.");
676     }
677   }
678 
679   // There's no point scattering f16x2 if we only ever access one
680   // element of it.
681   if (E0.empty() || E1.empty())
682     return false;
683 
684   // Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
685   // into f16,f16 SplitF16x2(V)
686   MVT EltVT = VT.getVectorElementType();
687   SDNode *ScatterOp =
688       CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
689   for (auto *Node : E0)
690     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
691   for (auto *Node : E1)
692     ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 1));
693 
694   return true;
695 }
696 
getCodeAddrSpace(MemSDNode * N)697 static unsigned int getCodeAddrSpace(MemSDNode *N) {
698   const Value *Src = N->getMemOperand()->getValue();
699 
700   if (!Src)
701     return NVPTX::PTXLdStInstCode::GENERIC;
702 
703   if (auto *PT = dyn_cast<PointerType>(Src->getType())) {
704     switch (PT->getAddressSpace()) {
705     case llvm::ADDRESS_SPACE_LOCAL: return NVPTX::PTXLdStInstCode::LOCAL;
706     case llvm::ADDRESS_SPACE_GLOBAL: return NVPTX::PTXLdStInstCode::GLOBAL;
707     case llvm::ADDRESS_SPACE_SHARED: return NVPTX::PTXLdStInstCode::SHARED;
708     case llvm::ADDRESS_SPACE_GENERIC: return NVPTX::PTXLdStInstCode::GENERIC;
709     case llvm::ADDRESS_SPACE_PARAM: return NVPTX::PTXLdStInstCode::PARAM;
710     case llvm::ADDRESS_SPACE_CONST: return NVPTX::PTXLdStInstCode::CONSTANT;
711     default: break;
712     }
713   }
714   return NVPTX::PTXLdStInstCode::GENERIC;
715 }
716 
canLowerToLDG(MemSDNode * N,const NVPTXSubtarget & Subtarget,unsigned CodeAddrSpace,MachineFunction * F)717 static bool canLowerToLDG(MemSDNode *N, const NVPTXSubtarget &Subtarget,
718                           unsigned CodeAddrSpace, MachineFunction *F) {
719   // We use ldg (i.e. ld.global.nc) for invariant loads from the global address
720   // space.
721   //
722   // We have two ways of identifying invariant loads: Loads may be explicitly
723   // marked as invariant, or we may infer them to be invariant.
724   //
725   // We currently infer invariance for loads from
726   //  - constant global variables, and
727   //  - kernel function pointer params that are noalias (i.e. __restrict) and
728   //    never written to.
729   //
730   // TODO: Perform a more powerful invariance analysis (ideally IPO, and ideally
731   // not during the SelectionDAG phase).
732   //
733   // TODO: Infer invariance only at -O2.  We still want to use ldg at -O0 for
734   // explicitly invariant loads because these are how clang tells us to use ldg
735   // when the user uses a builtin.
736   if (!Subtarget.hasLDG() || CodeAddrSpace != NVPTX::PTXLdStInstCode::GLOBAL)
737     return false;
738 
739   if (N->isInvariant())
740     return true;
741 
742   bool IsKernelFn = isKernelFunction(F->getFunction());
743 
744   // We use getUnderlyingObjects() here instead of getUnderlyingObject() mainly
745   // because the former looks through phi nodes while the latter does not. We
746   // need to look through phi nodes to handle pointer induction variables.
747   SmallVector<const Value *, 8> Objs;
748   getUnderlyingObjects(N->getMemOperand()->getValue(), Objs);
749 
750   return all_of(Objs, [&](const Value *V) {
751     if (auto *A = dyn_cast<const Argument>(V))
752       return IsKernelFn && A->onlyReadsMemory() && A->hasNoAliasAttr();
753     if (auto *GV = dyn_cast<const GlobalVariable>(V))
754       return GV->isConstant();
755     return false;
756   });
757 }
758 
tryIntrinsicNoChain(SDNode * N)759 bool NVPTXDAGToDAGISel::tryIntrinsicNoChain(SDNode *N) {
760   unsigned IID = N->getConstantOperandVal(0);
761   switch (IID) {
762   default:
763     return false;
764   case Intrinsic::nvvm_texsurf_handle_internal:
765     SelectTexSurfHandle(N);
766     return true;
767   }
768 }
769 
SelectTexSurfHandle(SDNode * N)770 void NVPTXDAGToDAGISel::SelectTexSurfHandle(SDNode *N) {
771   // Op 0 is the intrinsic ID
772   SDValue Wrapper = N->getOperand(1);
773   SDValue GlobalVal = Wrapper.getOperand(0);
774   ReplaceNode(N, CurDAG->getMachineNode(NVPTX::texsurf_handles, SDLoc(N),
775                                         MVT::i64, GlobalVal));
776 }
777 
SelectAddrSpaceCast(SDNode * N)778 void NVPTXDAGToDAGISel::SelectAddrSpaceCast(SDNode *N) {
779   SDValue Src = N->getOperand(0);
780   AddrSpaceCastSDNode *CastN = cast<AddrSpaceCastSDNode>(N);
781   unsigned SrcAddrSpace = CastN->getSrcAddressSpace();
782   unsigned DstAddrSpace = CastN->getDestAddressSpace();
783   assert(SrcAddrSpace != DstAddrSpace &&
784          "addrspacecast must be between different address spaces");
785 
786   if (DstAddrSpace == ADDRESS_SPACE_GENERIC) {
787     // Specific to generic
788     unsigned Opc;
789     switch (SrcAddrSpace) {
790     default: report_fatal_error("Bad address space in addrspacecast");
791     case ADDRESS_SPACE_GLOBAL:
792       Opc = TM.is64Bit() ? NVPTX::cvta_global_64 : NVPTX::cvta_global;
793       break;
794     case ADDRESS_SPACE_SHARED:
795       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
796                                 ? NVPTX::cvta_shared_6432
797                                 : NVPTX::cvta_shared_64)
798                          : NVPTX::cvta_shared;
799       break;
800     case ADDRESS_SPACE_CONST:
801       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
802                                 ? NVPTX::cvta_const_6432
803                                 : NVPTX::cvta_const_64)
804                          : NVPTX::cvta_const;
805       break;
806     case ADDRESS_SPACE_LOCAL:
807       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(SrcAddrSpace) == 32
808                                 ? NVPTX::cvta_local_6432
809                                 : NVPTX::cvta_local_64)
810                          : NVPTX::cvta_local;
811       break;
812     }
813     ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
814                                           Src));
815     return;
816   } else {
817     // Generic to specific
818     if (SrcAddrSpace != 0)
819       report_fatal_error("Cannot cast between two non-generic address spaces");
820     unsigned Opc;
821     switch (DstAddrSpace) {
822     default: report_fatal_error("Bad address space in addrspacecast");
823     case ADDRESS_SPACE_GLOBAL:
824       Opc = TM.is64Bit() ? NVPTX::cvta_to_global_64 : NVPTX::cvta_to_global;
825       break;
826     case ADDRESS_SPACE_SHARED:
827       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
828                                 ? NVPTX::cvta_to_shared_3264
829                                 : NVPTX::cvta_to_shared_64)
830                          : NVPTX::cvta_to_shared;
831       break;
832     case ADDRESS_SPACE_CONST:
833       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
834                                 ? NVPTX::cvta_to_const_3264
835                                 : NVPTX::cvta_to_const_64)
836                          : NVPTX::cvta_to_const;
837       break;
838     case ADDRESS_SPACE_LOCAL:
839       Opc = TM.is64Bit() ? (TM.getPointerSizeInBits(DstAddrSpace) == 32
840                                 ? NVPTX::cvta_to_local_3264
841                                 : NVPTX::cvta_to_local_64)
842                          : NVPTX::cvta_to_local;
843       break;
844     case ADDRESS_SPACE_PARAM:
845       Opc = TM.is64Bit() ? NVPTX::nvvm_ptr_gen_to_param_64
846                          : NVPTX::nvvm_ptr_gen_to_param;
847       break;
848     }
849     ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getValueType(0),
850                                           Src));
851     return;
852   }
853 }
854 
855 // Helper function template to reduce amount of boilerplate code for
856 // opcode selection.
857 static std::optional<unsigned>
pickOpcodeForVT(MVT::SimpleValueType VT,unsigned Opcode_i8,unsigned Opcode_i16,unsigned Opcode_i32,std::optional<unsigned> Opcode_i64,unsigned Opcode_f32,std::optional<unsigned> Opcode_f64)858 pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
859                 unsigned Opcode_i16, unsigned Opcode_i32,
860                 std::optional<unsigned> Opcode_i64, unsigned Opcode_f32,
861                 std::optional<unsigned> Opcode_f64) {
862   switch (VT) {
863   case MVT::i1:
864   case MVT::i8:
865     return Opcode_i8;
866   case MVT::i16:
867     return Opcode_i16;
868   case MVT::i32:
869     return Opcode_i32;
870   case MVT::i64:
871     return Opcode_i64;
872   case MVT::f16:
873   case MVT::bf16:
874     return Opcode_i16;
875   case MVT::v2f16:
876   case MVT::v2bf16:
877   case MVT::v2i16:
878   case MVT::v4i8:
879     return Opcode_i32;
880   case MVT::f32:
881     return Opcode_f32;
882   case MVT::f64:
883     return Opcode_f64;
884   default:
885     return std::nullopt;
886   }
887 }
888 
getLdStRegType(EVT VT)889 static int getLdStRegType(EVT VT) {
890   if (VT.isFloatingPoint())
891     switch (VT.getSimpleVT().SimpleTy) {
892     case MVT::f16:
893     case MVT::bf16:
894     case MVT::v2f16:
895     case MVT::v2bf16:
896       return NVPTX::PTXLdStInstCode::Untyped;
897     default:
898       return NVPTX::PTXLdStInstCode::Float;
899     }
900   else
901     return NVPTX::PTXLdStInstCode::Unsigned;
902 }
903 
tryLoad(SDNode * N)904 bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
905   SDLoc dl(N);
906   MemSDNode *LD = cast<MemSDNode>(N);
907   assert(LD->readMem() && "Expected load");
908   LoadSDNode *PlainLoad = dyn_cast<LoadSDNode>(N);
909   EVT LoadedVT = LD->getMemoryVT();
910   SDNode *NVPTXLD = nullptr;
911 
912   // do not support pre/post inc/dec
913   if (PlainLoad && PlainLoad->isIndexed())
914     return false;
915 
916   if (!LoadedVT.isSimple())
917     return false;
918 
919   AtomicOrdering Ordering = LD->getSuccessOrdering();
920   // In order to lower atomic loads with stronger guarantees we would need to
921   // use load.acquire or insert fences. However these features were only added
922   // with PTX ISA 6.0 / sm_70.
923   // TODO: Check if we can actually use the new instructions and implement them.
924   if (isStrongerThanMonotonic(Ordering))
925     return false;
926 
927   // Address Space Setting
928   unsigned int CodeAddrSpace = getCodeAddrSpace(LD);
929   if (canLowerToLDG(LD, *Subtarget, CodeAddrSpace, MF)) {
930     return tryLDGLDU(N);
931   }
932 
933   unsigned int PointerSize =
934       CurDAG->getDataLayout().getPointerSizeInBits(LD->getAddressSpace());
935 
936   // Volatile Setting
937   // - .volatile is only available for .global and .shared
938   // - .volatile has the same memory synchronization semantics as .relaxed.sys
939   bool isVolatile = LD->isVolatile() || Ordering == AtomicOrdering::Monotonic;
940   if (CodeAddrSpace != NVPTX::PTXLdStInstCode::GLOBAL &&
941       CodeAddrSpace != NVPTX::PTXLdStInstCode::SHARED &&
942       CodeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
943     isVolatile = false;
944 
945   // Type Setting: fromType + fromTypeWidth
946   //
947   // Sign   : ISD::SEXTLOAD
948   // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
949   //          type is integer
950   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
951   MVT SimpleVT = LoadedVT.getSimpleVT();
952   MVT ScalarVT = SimpleVT.getScalarType();
953   // Read at least 8 bits (predicates are stored as 8-bit values)
954   unsigned fromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
955   unsigned int fromType;
956 
957   // Vector Setting
958   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
959   if (SimpleVT.isVector()) {
960     assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
961            "Unexpected vector type");
962     // v2f16/v2bf16/v2i16 is loaded using ld.b32
963     fromTypeWidth = 32;
964   }
965 
966   if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))
967     fromType = NVPTX::PTXLdStInstCode::Signed;
968   else
969     fromType = getLdStRegType(ScalarVT);
970 
971   // Create the machine instruction DAG
972   SDValue Chain = N->getOperand(0);
973   SDValue N1 = N->getOperand(1);
974   SDValue Addr;
975   SDValue Offset, Base;
976   std::optional<unsigned> Opcode;
977   MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
978 
979   if (SelectDirectAddr(N1, Addr)) {
980     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_avar, NVPTX::LD_i16_avar,
981                              NVPTX::LD_i32_avar, NVPTX::LD_i64_avar,
982                              NVPTX::LD_f32_avar, NVPTX::LD_f64_avar);
983     if (!Opcode)
984       return false;
985     SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
986                       getI32Imm(vecType, dl), getI32Imm(fromType, dl),
987                       getI32Imm(fromTypeWidth, dl), Addr, Chain };
988     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
989   } else if (PointerSize == 64 ? SelectADDRsi64(N1.getNode(), N1, Base, Offset)
990                                : SelectADDRsi(N1.getNode(), N1, Base, Offset)) {
991     Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_asi, NVPTX::LD_i16_asi,
992                              NVPTX::LD_i32_asi, NVPTX::LD_i64_asi,
993                              NVPTX::LD_f32_asi, NVPTX::LD_f64_asi);
994     if (!Opcode)
995       return false;
996     SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
997                       getI32Imm(vecType, dl), getI32Imm(fromType, dl),
998                       getI32Imm(fromTypeWidth, dl), Base, Offset, Chain };
999     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
1000   } else if (PointerSize == 64 ? SelectADDRri64(N1.getNode(), N1, Base, Offset)
1001                                : SelectADDRri(N1.getNode(), N1, Base, Offset)) {
1002     if (PointerSize == 64)
1003       Opcode =
1004           pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari_64, NVPTX::LD_i16_ari_64,
1005                           NVPTX::LD_i32_ari_64, NVPTX::LD_i64_ari_64,
1006                           NVPTX::LD_f32_ari_64, NVPTX::LD_f64_ari_64);
1007     else
1008       Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_ari, NVPTX::LD_i16_ari,
1009                                NVPTX::LD_i32_ari, NVPTX::LD_i64_ari,
1010                                NVPTX::LD_f32_ari, NVPTX::LD_f64_ari);
1011     if (!Opcode)
1012       return false;
1013     SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
1014                       getI32Imm(vecType, dl), getI32Imm(fromType, dl),
1015                       getI32Imm(fromTypeWidth, dl), Base, Offset, Chain };
1016     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
1017   } else {
1018     if (PointerSize == 64)
1019       Opcode =
1020           pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg_64, NVPTX::LD_i16_areg_64,
1021                           NVPTX::LD_i32_areg_64, NVPTX::LD_i64_areg_64,
1022                           NVPTX::LD_f32_areg_64, NVPTX::LD_f64_areg_64);
1023     else
1024       Opcode = pickOpcodeForVT(TargetVT, NVPTX::LD_i8_areg, NVPTX::LD_i16_areg,
1025                                NVPTX::LD_i32_areg, NVPTX::LD_i64_areg,
1026                                NVPTX::LD_f32_areg, NVPTX::LD_f64_areg);
1027     if (!Opcode)
1028       return false;
1029     SDValue Ops[] = { getI32Imm(isVolatile, dl), getI32Imm(CodeAddrSpace, dl),
1030                       getI32Imm(vecType, dl), getI32Imm(fromType, dl),
1031                       getI32Imm(fromTypeWidth, dl), N1, Chain };
1032     NVPTXLD = CurDAG->getMachineNode(*Opcode, dl, TargetVT, MVT::Other, Ops);
1033   }
1034 
1035   if (!NVPTXLD)
1036     return false;
1037 
1038   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1039   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXLD), {MemRef});
1040 
1041   ReplaceNode(N, NVPTXLD);
1042   return true;
1043 }
1044 
tryLoadVector(SDNode * N)1045 bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
1046 
1047   SDValue Chain = N->getOperand(0);
1048   SDValue Op1 = N->getOperand(1);
1049   SDValue Addr, Offset, Base;
1050   std::optional<unsigned> Opcode;
1051   SDLoc DL(N);
1052   SDNode *LD;
1053   MemSDNode *MemSD = cast<MemSDNode>(N);
1054   EVT LoadedVT = MemSD->getMemoryVT();
1055 
1056   if (!LoadedVT.isSimple())
1057     return false;
1058 
1059   // Address Space Setting
1060   unsigned int CodeAddrSpace = getCodeAddrSpace(MemSD);
1061   if (canLowerToLDG(MemSD, *Subtarget, CodeAddrSpace, MF)) {
1062     return tryLDGLDU(N);
1063   }
1064 
1065   unsigned int PointerSize =
1066       CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
1067 
1068   // Volatile Setting
1069   // - .volatile is only availalble for .global and .shared
1070   bool IsVolatile = MemSD->isVolatile();
1071   if (CodeAddrSpace != NVPTX::PTXLdStInstCode::GLOBAL &&
1072       CodeAddrSpace != NVPTX::PTXLdStInstCode::SHARED &&
1073       CodeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
1074     IsVolatile = false;
1075 
1076   // Vector Setting
1077   MVT SimpleVT = LoadedVT.getSimpleVT();
1078 
1079   // Type Setting: fromType + fromTypeWidth
1080   //
1081   // Sign   : ISD::SEXTLOAD
1082   // Unsign : ISD::ZEXTLOAD, ISD::NON_EXTLOAD or ISD::EXTLOAD and the
1083   //          type is integer
1084   // Float  : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
1085   MVT ScalarVT = SimpleVT.getScalarType();
1086   // Read at least 8 bits (predicates are stored as 8-bit values)
1087   unsigned FromTypeWidth = std::max(8U, (unsigned)ScalarVT.getSizeInBits());
1088   unsigned int FromType;
1089   // The last operand holds the original LoadSDNode::getExtensionType() value
1090   unsigned ExtensionType = cast<ConstantSDNode>(
1091       N->getOperand(N->getNumOperands() - 1))->getZExtValue();
1092   if (ExtensionType == ISD::SEXTLOAD)
1093     FromType = NVPTX::PTXLdStInstCode::Signed;
1094   else
1095     FromType = getLdStRegType(ScalarVT);
1096 
1097   unsigned VecType;
1098 
1099   switch (N->getOpcode()) {
1100   case NVPTXISD::LoadV2:
1101     VecType = NVPTX::PTXLdStInstCode::V2;
1102     break;
1103   case NVPTXISD::LoadV4:
1104     VecType = NVPTX::PTXLdStInstCode::V4;
1105     break;
1106   default:
1107     return false;
1108   }
1109 
1110   EVT EltVT = N->getValueType(0);
1111 
1112   // v8x16 is a special case. PTX doesn't have ld.v8.16
1113   // instruction. Instead, we split the vector into v2x16 chunks and
1114   // load them with ld.v4.b32.
1115   if (Isv2x16VT(EltVT)) {
1116     assert(N->getOpcode() == NVPTXISD::LoadV4 && "Unexpected load opcode.");
1117     EltVT = MVT::i32;
1118     FromType = NVPTX::PTXLdStInstCode::Untyped;
1119     FromTypeWidth = 32;
1120   }
1121 
1122   if (SelectDirectAddr(Op1, Addr)) {
1123     switch (N->getOpcode()) {
1124     default:
1125       return false;
1126     case NVPTXISD::LoadV2:
1127       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1128                                NVPTX::LDV_i8_v2_avar, NVPTX::LDV_i16_v2_avar,
1129                                NVPTX::LDV_i32_v2_avar, NVPTX::LDV_i64_v2_avar,
1130                                NVPTX::LDV_f32_v2_avar, NVPTX::LDV_f64_v2_avar);
1131       break;
1132     case NVPTXISD::LoadV4:
1133       Opcode =
1134           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_avar,
1135                           NVPTX::LDV_i16_v4_avar, NVPTX::LDV_i32_v4_avar,
1136                           std::nullopt, NVPTX::LDV_f32_v4_avar, std::nullopt);
1137       break;
1138     }
1139     if (!Opcode)
1140       return false;
1141     SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
1142                       getI32Imm(VecType, DL), getI32Imm(FromType, DL),
1143                       getI32Imm(FromTypeWidth, DL), Addr, Chain };
1144     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
1145   } else if (PointerSize == 64
1146                  ? SelectADDRsi64(Op1.getNode(), Op1, Base, Offset)
1147                  : SelectADDRsi(Op1.getNode(), Op1, Base, Offset)) {
1148     switch (N->getOpcode()) {
1149     default:
1150       return false;
1151     case NVPTXISD::LoadV2:
1152       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1153                                NVPTX::LDV_i8_v2_asi, NVPTX::LDV_i16_v2_asi,
1154                                NVPTX::LDV_i32_v2_asi, NVPTX::LDV_i64_v2_asi,
1155                                NVPTX::LDV_f32_v2_asi, NVPTX::LDV_f64_v2_asi);
1156       break;
1157     case NVPTXISD::LoadV4:
1158       Opcode =
1159           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_asi,
1160                           NVPTX::LDV_i16_v4_asi, NVPTX::LDV_i32_v4_asi,
1161                           std::nullopt, NVPTX::LDV_f32_v4_asi, std::nullopt);
1162       break;
1163     }
1164     if (!Opcode)
1165       return false;
1166     SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
1167                       getI32Imm(VecType, DL), getI32Imm(FromType, DL),
1168                       getI32Imm(FromTypeWidth, DL), Base, Offset, Chain };
1169     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
1170   } else if (PointerSize == 64
1171                  ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
1172                  : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
1173     if (PointerSize == 64) {
1174       switch (N->getOpcode()) {
1175       default:
1176         return false;
1177       case NVPTXISD::LoadV2:
1178         Opcode =
1179             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1180                             NVPTX::LDV_i8_v2_ari_64, NVPTX::LDV_i16_v2_ari_64,
1181                             NVPTX::LDV_i32_v2_ari_64, NVPTX::LDV_i64_v2_ari_64,
1182                             NVPTX::LDV_f32_v2_ari_64, NVPTX::LDV_f64_v2_ari_64);
1183         break;
1184       case NVPTXISD::LoadV4:
1185         Opcode = pickOpcodeForVT(
1186             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari_64,
1187             NVPTX::LDV_i16_v4_ari_64, NVPTX::LDV_i32_v4_ari_64, std::nullopt,
1188             NVPTX::LDV_f32_v4_ari_64, std::nullopt);
1189         break;
1190       }
1191     } else {
1192       switch (N->getOpcode()) {
1193       default:
1194         return false;
1195       case NVPTXISD::LoadV2:
1196         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1197                                  NVPTX::LDV_i8_v2_ari, NVPTX::LDV_i16_v2_ari,
1198                                  NVPTX::LDV_i32_v2_ari, NVPTX::LDV_i64_v2_ari,
1199                                  NVPTX::LDV_f32_v2_ari, NVPTX::LDV_f64_v2_ari);
1200         break;
1201       case NVPTXISD::LoadV4:
1202         Opcode =
1203             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_ari,
1204                             NVPTX::LDV_i16_v4_ari, NVPTX::LDV_i32_v4_ari,
1205                             std::nullopt, NVPTX::LDV_f32_v4_ari, std::nullopt);
1206         break;
1207       }
1208     }
1209     if (!Opcode)
1210       return false;
1211     SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
1212                       getI32Imm(VecType, DL), getI32Imm(FromType, DL),
1213                       getI32Imm(FromTypeWidth, DL), Base, Offset, Chain };
1214 
1215     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
1216   } else {
1217     if (PointerSize == 64) {
1218       switch (N->getOpcode()) {
1219       default:
1220         return false;
1221       case NVPTXISD::LoadV2:
1222         Opcode = pickOpcodeForVT(
1223             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg_64,
1224             NVPTX::LDV_i16_v2_areg_64, NVPTX::LDV_i32_v2_areg_64,
1225             NVPTX::LDV_i64_v2_areg_64, NVPTX::LDV_f32_v2_areg_64,
1226             NVPTX::LDV_f64_v2_areg_64);
1227         break;
1228       case NVPTXISD::LoadV4:
1229         Opcode = pickOpcodeForVT(
1230             EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg_64,
1231             NVPTX::LDV_i16_v4_areg_64, NVPTX::LDV_i32_v4_areg_64, std::nullopt,
1232             NVPTX::LDV_f32_v4_areg_64, std::nullopt);
1233         break;
1234       }
1235     } else {
1236       switch (N->getOpcode()) {
1237       default:
1238         return false;
1239       case NVPTXISD::LoadV2:
1240         Opcode =
1241             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v2_areg,
1242                             NVPTX::LDV_i16_v2_areg, NVPTX::LDV_i32_v2_areg,
1243                             NVPTX::LDV_i64_v2_areg, NVPTX::LDV_f32_v2_areg,
1244                             NVPTX::LDV_f64_v2_areg);
1245         break;
1246       case NVPTXISD::LoadV4:
1247         Opcode =
1248             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::LDV_i8_v4_areg,
1249                             NVPTX::LDV_i16_v4_areg, NVPTX::LDV_i32_v4_areg,
1250                             std::nullopt, NVPTX::LDV_f32_v4_areg, std::nullopt);
1251         break;
1252       }
1253     }
1254     if (!Opcode)
1255       return false;
1256     SDValue Ops[] = { getI32Imm(IsVolatile, DL), getI32Imm(CodeAddrSpace, DL),
1257                       getI32Imm(VecType, DL), getI32Imm(FromType, DL),
1258                       getI32Imm(FromTypeWidth, DL), Op1, Chain };
1259     LD = CurDAG->getMachineNode(*Opcode, DL, N->getVTList(), Ops);
1260   }
1261 
1262   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1263   CurDAG->setNodeMemRefs(cast<MachineSDNode>(LD), {MemRef});
1264 
1265   ReplaceNode(N, LD);
1266   return true;
1267 }
1268 
tryLDGLDU(SDNode * N)1269 bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1270 
1271   SDValue Chain = N->getOperand(0);
1272   SDValue Op1;
1273   MemSDNode *Mem;
1274   bool IsLDG = true;
1275 
1276   // If this is an LDG intrinsic, the address is the third operand. If its an
1277   // LDG/LDU SD node (from custom vector handling), then its the second operand
1278   if (N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1279     Op1 = N->getOperand(2);
1280     Mem = cast<MemIntrinsicSDNode>(N);
1281     unsigned IID = N->getConstantOperandVal(1);
1282     switch (IID) {
1283     default:
1284       return false;
1285     case Intrinsic::nvvm_ldg_global_f:
1286     case Intrinsic::nvvm_ldg_global_i:
1287     case Intrinsic::nvvm_ldg_global_p:
1288       IsLDG = true;
1289       break;
1290     case Intrinsic::nvvm_ldu_global_f:
1291     case Intrinsic::nvvm_ldu_global_i:
1292     case Intrinsic::nvvm_ldu_global_p:
1293       IsLDG = false;
1294       break;
1295     }
1296   } else {
1297     Op1 = N->getOperand(1);
1298     Mem = cast<MemSDNode>(N);
1299   }
1300 
1301   std::optional<unsigned> Opcode;
1302   SDLoc DL(N);
1303   SDNode *LD;
1304   SDValue Base, Offset, Addr;
1305   EVT OrigType = N->getValueType(0);
1306 
1307   EVT EltVT = Mem->getMemoryVT();
1308   unsigned NumElts = 1;
1309   if (EltVT.isVector()) {
1310     NumElts = EltVT.getVectorNumElements();
1311     EltVT = EltVT.getVectorElementType();
1312     // vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
1313     if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
1314         (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
1315         (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
1316       assert(NumElts % 2 == 0 && "Vector must have even number of elements");
1317       EltVT = OrigType;
1318       NumElts /= 2;
1319     } else if (OrigType == MVT::v4i8) {
1320       EltVT = OrigType;
1321       NumElts = 1;
1322     }
1323   }
1324 
1325   // Build the "promoted" result VTList for the load. If we are really loading
1326   // i8s, then the return type will be promoted to i16 since we do not expose
1327   // 8-bit registers in NVPTX.
1328   EVT NodeVT = (EltVT == MVT::i8) ? MVT::i16 : EltVT;
1329   SmallVector<EVT, 5> InstVTs;
1330   for (unsigned i = 0; i != NumElts; ++i) {
1331     InstVTs.push_back(NodeVT);
1332   }
1333   InstVTs.push_back(MVT::Other);
1334   SDVTList InstVTList = CurDAG->getVTList(InstVTs);
1335 
1336   if (SelectDirectAddr(Op1, Addr)) {
1337     switch (N->getOpcode()) {
1338     default:
1339       return false;
1340     case ISD::LOAD:
1341     case ISD::INTRINSIC_W_CHAIN:
1342       if (IsLDG)
1343         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1344                                  NVPTX::INT_PTX_LDG_GLOBAL_i8avar,
1345                                  NVPTX::INT_PTX_LDG_GLOBAL_i16avar,
1346                                  NVPTX::INT_PTX_LDG_GLOBAL_i32avar,
1347                                  NVPTX::INT_PTX_LDG_GLOBAL_i64avar,
1348                                  NVPTX::INT_PTX_LDG_GLOBAL_f32avar,
1349                                  NVPTX::INT_PTX_LDG_GLOBAL_f64avar);
1350       else
1351         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1352                                  NVPTX::INT_PTX_LDU_GLOBAL_i8avar,
1353                                  NVPTX::INT_PTX_LDU_GLOBAL_i16avar,
1354                                  NVPTX::INT_PTX_LDU_GLOBAL_i32avar,
1355                                  NVPTX::INT_PTX_LDU_GLOBAL_i64avar,
1356                                  NVPTX::INT_PTX_LDU_GLOBAL_f32avar,
1357                                  NVPTX::INT_PTX_LDU_GLOBAL_f64avar);
1358       break;
1359     case NVPTXISD::LoadV2:
1360     case NVPTXISD::LDGV2:
1361       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1362                                NVPTX::INT_PTX_LDG_G_v2i8_ELE_avar,
1363                                NVPTX::INT_PTX_LDG_G_v2i16_ELE_avar,
1364                                NVPTX::INT_PTX_LDG_G_v2i32_ELE_avar,
1365                                NVPTX::INT_PTX_LDG_G_v2i64_ELE_avar,
1366                                NVPTX::INT_PTX_LDG_G_v2f32_ELE_avar,
1367                                NVPTX::INT_PTX_LDG_G_v2f64_ELE_avar);
1368       break;
1369     case NVPTXISD::LDUV2:
1370       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1371                                NVPTX::INT_PTX_LDU_G_v2i8_ELE_avar,
1372                                NVPTX::INT_PTX_LDU_G_v2i16_ELE_avar,
1373                                NVPTX::INT_PTX_LDU_G_v2i32_ELE_avar,
1374                                NVPTX::INT_PTX_LDU_G_v2i64_ELE_avar,
1375                                NVPTX::INT_PTX_LDU_G_v2f32_ELE_avar,
1376                                NVPTX::INT_PTX_LDU_G_v2f64_ELE_avar);
1377       break;
1378     case NVPTXISD::LoadV4:
1379     case NVPTXISD::LDGV4:
1380       Opcode = pickOpcodeForVT(
1381           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_avar,
1382           NVPTX::INT_PTX_LDG_G_v4i16_ELE_avar,
1383           NVPTX::INT_PTX_LDG_G_v4i32_ELE_avar, std::nullopt,
1384           NVPTX::INT_PTX_LDG_G_v4f32_ELE_avar, std::nullopt);
1385       break;
1386     case NVPTXISD::LDUV4:
1387       Opcode = pickOpcodeForVT(
1388           EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_avar,
1389           NVPTX::INT_PTX_LDU_G_v4i16_ELE_avar,
1390           NVPTX::INT_PTX_LDU_G_v4i32_ELE_avar, std::nullopt,
1391           NVPTX::INT_PTX_LDU_G_v4f32_ELE_avar, std::nullopt);
1392       break;
1393     }
1394     if (!Opcode)
1395       return false;
1396     SDValue Ops[] = { Addr, Chain };
1397     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1398   } else if (TM.is64Bit() ? SelectADDRri64(Op1.getNode(), Op1, Base, Offset)
1399                           : SelectADDRri(Op1.getNode(), Op1, Base, Offset)) {
1400     if (TM.is64Bit()) {
1401       switch (N->getOpcode()) {
1402       default:
1403         return false;
1404       case ISD::LOAD:
1405       case ISD::INTRINSIC_W_CHAIN:
1406         if (IsLDG)
1407           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1408                                        NVPTX::INT_PTX_LDG_GLOBAL_i8ari64,
1409                                        NVPTX::INT_PTX_LDG_GLOBAL_i16ari64,
1410                                        NVPTX::INT_PTX_LDG_GLOBAL_i32ari64,
1411                                        NVPTX::INT_PTX_LDG_GLOBAL_i64ari64,
1412                                        NVPTX::INT_PTX_LDG_GLOBAL_f32ari64,
1413                                        NVPTX::INT_PTX_LDG_GLOBAL_f64ari64);
1414         else
1415           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1416                                        NVPTX::INT_PTX_LDU_GLOBAL_i8ari64,
1417                                        NVPTX::INT_PTX_LDU_GLOBAL_i16ari64,
1418                                        NVPTX::INT_PTX_LDU_GLOBAL_i32ari64,
1419                                        NVPTX::INT_PTX_LDU_GLOBAL_i64ari64,
1420                                        NVPTX::INT_PTX_LDU_GLOBAL_f32ari64,
1421                                        NVPTX::INT_PTX_LDU_GLOBAL_f64ari64);
1422         break;
1423       case NVPTXISD::LoadV2:
1424       case NVPTXISD::LDGV2:
1425         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1426                                      NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari64,
1427                                      NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari64,
1428                                      NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari64,
1429                                      NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari64,
1430                                      NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari64,
1431                                      NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari64);
1432         break;
1433       case NVPTXISD::LDUV2:
1434         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1435                                      NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari64,
1436                                      NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari64,
1437                                      NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari64,
1438                                      NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari64,
1439                                      NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari64,
1440                                      NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari64);
1441         break;
1442       case NVPTXISD::LoadV4:
1443       case NVPTXISD::LDGV4:
1444         Opcode = pickOpcodeForVT(
1445             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari64,
1446             NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari64,
1447             NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari64, std::nullopt,
1448             NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari64, std::nullopt);
1449         break;
1450       case NVPTXISD::LDUV4:
1451         Opcode = pickOpcodeForVT(
1452             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari64,
1453             NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari64,
1454             NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari64, std::nullopt,
1455             NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari64, std::nullopt);
1456         break;
1457       }
1458     } else {
1459       switch (N->getOpcode()) {
1460       default:
1461         return false;
1462       case ISD::LOAD:
1463       case ISD::INTRINSIC_W_CHAIN:
1464         if (IsLDG)
1465           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1466                                    NVPTX::INT_PTX_LDG_GLOBAL_i8ari,
1467                                    NVPTX::INT_PTX_LDG_GLOBAL_i16ari,
1468                                    NVPTX::INT_PTX_LDG_GLOBAL_i32ari,
1469                                    NVPTX::INT_PTX_LDG_GLOBAL_i64ari,
1470                                    NVPTX::INT_PTX_LDG_GLOBAL_f32ari,
1471                                    NVPTX::INT_PTX_LDG_GLOBAL_f64ari);
1472         else
1473           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1474                                    NVPTX::INT_PTX_LDU_GLOBAL_i8ari,
1475                                    NVPTX::INT_PTX_LDU_GLOBAL_i16ari,
1476                                    NVPTX::INT_PTX_LDU_GLOBAL_i32ari,
1477                                    NVPTX::INT_PTX_LDU_GLOBAL_i64ari,
1478                                    NVPTX::INT_PTX_LDU_GLOBAL_f32ari,
1479                                    NVPTX::INT_PTX_LDU_GLOBAL_f64ari);
1480         break;
1481       case NVPTXISD::LoadV2:
1482       case NVPTXISD::LDGV2:
1483         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1484                                  NVPTX::INT_PTX_LDG_G_v2i8_ELE_ari32,
1485                                  NVPTX::INT_PTX_LDG_G_v2i16_ELE_ari32,
1486                                  NVPTX::INT_PTX_LDG_G_v2i32_ELE_ari32,
1487                                  NVPTX::INT_PTX_LDG_G_v2i64_ELE_ari32,
1488                                  NVPTX::INT_PTX_LDG_G_v2f32_ELE_ari32,
1489                                  NVPTX::INT_PTX_LDG_G_v2f64_ELE_ari32);
1490         break;
1491       case NVPTXISD::LDUV2:
1492         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1493                                  NVPTX::INT_PTX_LDU_G_v2i8_ELE_ari32,
1494                                  NVPTX::INT_PTX_LDU_G_v2i16_ELE_ari32,
1495                                  NVPTX::INT_PTX_LDU_G_v2i32_ELE_ari32,
1496                                  NVPTX::INT_PTX_LDU_G_v2i64_ELE_ari32,
1497                                  NVPTX::INT_PTX_LDU_G_v2f32_ELE_ari32,
1498                                  NVPTX::INT_PTX_LDU_G_v2f64_ELE_ari32);
1499         break;
1500       case NVPTXISD::LoadV4:
1501       case NVPTXISD::LDGV4:
1502         Opcode = pickOpcodeForVT(
1503             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_ari32,
1504             NVPTX::INT_PTX_LDG_G_v4i16_ELE_ari32,
1505             NVPTX::INT_PTX_LDG_G_v4i32_ELE_ari32, std::nullopt,
1506             NVPTX::INT_PTX_LDG_G_v4f32_ELE_ari32, std::nullopt);
1507         break;
1508       case NVPTXISD::LDUV4:
1509         Opcode = pickOpcodeForVT(
1510             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_ari32,
1511             NVPTX::INT_PTX_LDU_G_v4i16_ELE_ari32,
1512             NVPTX::INT_PTX_LDU_G_v4i32_ELE_ari32, std::nullopt,
1513             NVPTX::INT_PTX_LDU_G_v4f32_ELE_ari32, std::nullopt);
1514         break;
1515       }
1516     }
1517     if (!Opcode)
1518       return false;
1519     SDValue Ops[] = {Base, Offset, Chain};
1520     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1521   } else {
1522     if (TM.is64Bit()) {
1523       switch (N->getOpcode()) {
1524       default:
1525         return false;
1526       case ISD::LOAD:
1527       case ISD::INTRINSIC_W_CHAIN:
1528         if (IsLDG)
1529           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1530                                        NVPTX::INT_PTX_LDG_GLOBAL_i8areg64,
1531                                        NVPTX::INT_PTX_LDG_GLOBAL_i16areg64,
1532                                        NVPTX::INT_PTX_LDG_GLOBAL_i32areg64,
1533                                        NVPTX::INT_PTX_LDG_GLOBAL_i64areg64,
1534                                        NVPTX::INT_PTX_LDG_GLOBAL_f32areg64,
1535                                        NVPTX::INT_PTX_LDG_GLOBAL_f64areg64);
1536         else
1537           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1538                                        NVPTX::INT_PTX_LDU_GLOBAL_i8areg64,
1539                                        NVPTX::INT_PTX_LDU_GLOBAL_i16areg64,
1540                                        NVPTX::INT_PTX_LDU_GLOBAL_i32areg64,
1541                                        NVPTX::INT_PTX_LDU_GLOBAL_i64areg64,
1542                                        NVPTX::INT_PTX_LDU_GLOBAL_f32areg64,
1543                                        NVPTX::INT_PTX_LDU_GLOBAL_f64areg64);
1544         break;
1545       case NVPTXISD::LoadV2:
1546       case NVPTXISD::LDGV2:
1547         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1548                                      NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg64,
1549                                      NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg64,
1550                                      NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg64,
1551                                      NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg64,
1552                                      NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg64,
1553                                      NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg64);
1554         break;
1555       case NVPTXISD::LDUV2:
1556         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1557                                      NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg64,
1558                                      NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg64,
1559                                      NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg64,
1560                                      NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg64,
1561                                      NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg64,
1562                                      NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg64);
1563         break;
1564       case NVPTXISD::LoadV4:
1565       case NVPTXISD::LDGV4:
1566         Opcode = pickOpcodeForVT(
1567             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg64,
1568             NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg64,
1569             NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg64, std::nullopt,
1570             NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg64, std::nullopt);
1571         break;
1572       case NVPTXISD::LDUV4:
1573         Opcode = pickOpcodeForVT(
1574             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg64,
1575             NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg64,
1576             NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg64, std::nullopt,
1577             NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg64, std::nullopt);
1578         break;
1579       }
1580     } else {
1581       switch (N->getOpcode()) {
1582       default:
1583         return false;
1584       case ISD::LOAD:
1585       case ISD::INTRINSIC_W_CHAIN:
1586         if (IsLDG)
1587           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1588                                    NVPTX::INT_PTX_LDG_GLOBAL_i8areg,
1589                                    NVPTX::INT_PTX_LDG_GLOBAL_i16areg,
1590                                    NVPTX::INT_PTX_LDG_GLOBAL_i32areg,
1591                                    NVPTX::INT_PTX_LDG_GLOBAL_i64areg,
1592                                    NVPTX::INT_PTX_LDG_GLOBAL_f32areg,
1593                                    NVPTX::INT_PTX_LDG_GLOBAL_f64areg);
1594         else
1595           Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1596                                    NVPTX::INT_PTX_LDU_GLOBAL_i8areg,
1597                                    NVPTX::INT_PTX_LDU_GLOBAL_i16areg,
1598                                    NVPTX::INT_PTX_LDU_GLOBAL_i32areg,
1599                                    NVPTX::INT_PTX_LDU_GLOBAL_i64areg,
1600                                    NVPTX::INT_PTX_LDU_GLOBAL_f32areg,
1601                                    NVPTX::INT_PTX_LDU_GLOBAL_f64areg);
1602         break;
1603       case NVPTXISD::LoadV2:
1604       case NVPTXISD::LDGV2:
1605         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1606                                  NVPTX::INT_PTX_LDG_G_v2i8_ELE_areg32,
1607                                  NVPTX::INT_PTX_LDG_G_v2i16_ELE_areg32,
1608                                  NVPTX::INT_PTX_LDG_G_v2i32_ELE_areg32,
1609                                  NVPTX::INT_PTX_LDG_G_v2i64_ELE_areg32,
1610                                  NVPTX::INT_PTX_LDG_G_v2f32_ELE_areg32,
1611                                  NVPTX::INT_PTX_LDG_G_v2f64_ELE_areg32);
1612         break;
1613       case NVPTXISD::LDUV2:
1614         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1615                                  NVPTX::INT_PTX_LDU_G_v2i8_ELE_areg32,
1616                                  NVPTX::INT_PTX_LDU_G_v2i16_ELE_areg32,
1617                                  NVPTX::INT_PTX_LDU_G_v2i32_ELE_areg32,
1618                                  NVPTX::INT_PTX_LDU_G_v2i64_ELE_areg32,
1619                                  NVPTX::INT_PTX_LDU_G_v2f32_ELE_areg32,
1620                                  NVPTX::INT_PTX_LDU_G_v2f64_ELE_areg32);
1621         break;
1622       case NVPTXISD::LoadV4:
1623       case NVPTXISD::LDGV4:
1624         Opcode = pickOpcodeForVT(
1625             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE_areg32,
1626             NVPTX::INT_PTX_LDG_G_v4i16_ELE_areg32,
1627             NVPTX::INT_PTX_LDG_G_v4i32_ELE_areg32, std::nullopt,
1628             NVPTX::INT_PTX_LDG_G_v4f32_ELE_areg32, std::nullopt);
1629         break;
1630       case NVPTXISD::LDUV4:
1631         Opcode = pickOpcodeForVT(
1632             EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE_areg32,
1633             NVPTX::INT_PTX_LDU_G_v4i16_ELE_areg32,
1634             NVPTX::INT_PTX_LDU_G_v4i32_ELE_areg32, std::nullopt,
1635             NVPTX::INT_PTX_LDU_G_v4f32_ELE_areg32, std::nullopt);
1636         break;
1637       }
1638     }
1639     if (!Opcode)
1640       return false;
1641     SDValue Ops[] = { Op1, Chain };
1642     LD = CurDAG->getMachineNode(*Opcode, DL, InstVTList, Ops);
1643   }
1644 
1645   // For automatic generation of LDG (through SelectLoad[Vector], not the
1646   // intrinsics), we may have an extending load like:
1647   //
1648   //   i32,ch = load<LD1[%data1(addrspace=1)], zext from i8> t0, t7, undef:i64
1649   //
1650   // In this case, the matching logic above will select a load for the original
1651   // memory type (in this case, i8) and our types will not match (the node needs
1652   // to return an i32 in this case). Our LDG/LDU nodes do not support the
1653   // concept of sign-/zero-extension, so emulate it here by adding an explicit
1654   // CVT instruction. Ptxas should clean up any redundancies here.
1655 
1656   LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
1657 
1658   if (OrigType != EltVT &&
1659       (LdNode || (OrigType.isFloatingPoint() && EltVT.isFloatingPoint()))) {
1660     // We have an extending-load. The instruction we selected operates on the
1661     // smaller type, but the SDNode we are replacing has the larger type. We
1662     // need to emit a CVT to make the types match.
1663     unsigned CvtOpc =
1664         GetConvertOpcode(OrigType.getSimpleVT(), EltVT.getSimpleVT(), LdNode);
1665 
1666     // For each output value, apply the manual sign/zero-extension and make sure
1667     // all users of the load go through that CVT.
1668     for (unsigned i = 0; i != NumElts; ++i) {
1669       SDValue Res(LD, i);
1670       SDValue OrigVal(N, i);
1671 
1672       SDNode *CvtNode =
1673         CurDAG->getMachineNode(CvtOpc, DL, OrigType, Res,
1674                                CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE,
1675                                                          DL, MVT::i32));
1676       ReplaceUses(OrigVal, SDValue(CvtNode, 0));
1677     }
1678   }
1679 
1680   ReplaceNode(N, LD);
1681   return true;
1682 }
1683 
tryStore(SDNode * N)1684 bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
1685   SDLoc dl(N);
1686   MemSDNode *ST = cast<MemSDNode>(N);
1687   assert(ST->writeMem() && "Expected store");
1688   StoreSDNode *PlainStore = dyn_cast<StoreSDNode>(N);
1689   AtomicSDNode *AtomicStore = dyn_cast<AtomicSDNode>(N);
1690   assert((PlainStore || AtomicStore) && "Expected store");
1691   EVT StoreVT = ST->getMemoryVT();
1692   SDNode *NVPTXST = nullptr;
1693 
1694   // do not support pre/post inc/dec
1695   if (PlainStore && PlainStore->isIndexed())
1696     return false;
1697 
1698   if (!StoreVT.isSimple())
1699     return false;
1700 
1701   AtomicOrdering Ordering = ST->getSuccessOrdering();
1702   // In order to lower atomic loads with stronger guarantees we would need to
1703   // use store.release or insert fences. However these features were only added
1704   // with PTX ISA 6.0 / sm_70.
1705   // TODO: Check if we can actually use the new instructions and implement them.
1706   if (isStrongerThanMonotonic(Ordering))
1707     return false;
1708 
1709   // Address Space Setting
1710   unsigned int CodeAddrSpace = getCodeAddrSpace(ST);
1711   unsigned int PointerSize =
1712       CurDAG->getDataLayout().getPointerSizeInBits(ST->getAddressSpace());
1713 
1714   // Volatile Setting
1715   // - .volatile is only available for .global and .shared
1716   // - .volatile has the same memory synchronization semantics as .relaxed.sys
1717   bool isVolatile = ST->isVolatile() || Ordering == AtomicOrdering::Monotonic;
1718   if (CodeAddrSpace != NVPTX::PTXLdStInstCode::GLOBAL &&
1719       CodeAddrSpace != NVPTX::PTXLdStInstCode::SHARED &&
1720       CodeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
1721     isVolatile = false;
1722 
1723   // Vector Setting
1724   MVT SimpleVT = StoreVT.getSimpleVT();
1725   unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
1726 
1727   // Type Setting: toType + toTypeWidth
1728   // - for integer type, always use 'u'
1729   //
1730   MVT ScalarVT = SimpleVT.getScalarType();
1731   unsigned toTypeWidth = ScalarVT.getSizeInBits();
1732   if (SimpleVT.isVector()) {
1733     assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
1734            "Unexpected vector type");
1735     // v2x16 is stored using st.b32
1736     toTypeWidth = 32;
1737   }
1738 
1739   unsigned int toType = getLdStRegType(ScalarVT);
1740 
1741   // Create the machine instruction DAG
1742   SDValue Chain = ST->getChain();
1743   SDValue Value = PlainStore ? PlainStore->getValue() : AtomicStore->getVal();
1744   SDValue BasePtr = ST->getBasePtr();
1745   SDValue Addr;
1746   SDValue Offset, Base;
1747   std::optional<unsigned> Opcode;
1748   MVT::SimpleValueType SourceVT =
1749       Value.getNode()->getSimpleValueType(0).SimpleTy;
1750 
1751   if (SelectDirectAddr(BasePtr, Addr)) {
1752     Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_avar, NVPTX::ST_i16_avar,
1753                              NVPTX::ST_i32_avar, NVPTX::ST_i64_avar,
1754                              NVPTX::ST_f32_avar, NVPTX::ST_f64_avar);
1755     if (!Opcode)
1756       return false;
1757     SDValue Ops[] = {Value,
1758                      getI32Imm(isVolatile, dl),
1759                      getI32Imm(CodeAddrSpace, dl),
1760                      getI32Imm(vecType, dl),
1761                      getI32Imm(toType, dl),
1762                      getI32Imm(toTypeWidth, dl),
1763                      Addr,
1764                      Chain};
1765     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
1766   } else if (PointerSize == 64
1767                  ? SelectADDRsi64(BasePtr.getNode(), BasePtr, Base, Offset)
1768                  : SelectADDRsi(BasePtr.getNode(), BasePtr, Base, Offset)) {
1769     Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_asi, NVPTX::ST_i16_asi,
1770                              NVPTX::ST_i32_asi, NVPTX::ST_i64_asi,
1771                              NVPTX::ST_f32_asi, NVPTX::ST_f64_asi);
1772     if (!Opcode)
1773       return false;
1774     SDValue Ops[] = {Value,
1775                      getI32Imm(isVolatile, dl),
1776                      getI32Imm(CodeAddrSpace, dl),
1777                      getI32Imm(vecType, dl),
1778                      getI32Imm(toType, dl),
1779                      getI32Imm(toTypeWidth, dl),
1780                      Base,
1781                      Offset,
1782                      Chain};
1783     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
1784   } else if (PointerSize == 64
1785                  ? SelectADDRri64(BasePtr.getNode(), BasePtr, Base, Offset)
1786                  : SelectADDRri(BasePtr.getNode(), BasePtr, Base, Offset)) {
1787     if (PointerSize == 64)
1788       Opcode =
1789           pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari_64, NVPTX::ST_i16_ari_64,
1790                           NVPTX::ST_i32_ari_64, NVPTX::ST_i64_ari_64,
1791                           NVPTX::ST_f32_ari_64, NVPTX::ST_f64_ari_64);
1792     else
1793       Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_ari, NVPTX::ST_i16_ari,
1794                                NVPTX::ST_i32_ari, NVPTX::ST_i64_ari,
1795                                NVPTX::ST_f32_ari, NVPTX::ST_f64_ari);
1796     if (!Opcode)
1797       return false;
1798 
1799     SDValue Ops[] = {Value,
1800                      getI32Imm(isVolatile, dl),
1801                      getI32Imm(CodeAddrSpace, dl),
1802                      getI32Imm(vecType, dl),
1803                      getI32Imm(toType, dl),
1804                      getI32Imm(toTypeWidth, dl),
1805                      Base,
1806                      Offset,
1807                      Chain};
1808     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
1809   } else {
1810     if (PointerSize == 64)
1811       Opcode =
1812           pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg_64, NVPTX::ST_i16_areg_64,
1813                           NVPTX::ST_i32_areg_64, NVPTX::ST_i64_areg_64,
1814                           NVPTX::ST_f32_areg_64, NVPTX::ST_f64_areg_64);
1815     else
1816       Opcode = pickOpcodeForVT(SourceVT, NVPTX::ST_i8_areg, NVPTX::ST_i16_areg,
1817                                NVPTX::ST_i32_areg, NVPTX::ST_i64_areg,
1818                                NVPTX::ST_f32_areg, NVPTX::ST_f64_areg);
1819     if (!Opcode)
1820       return false;
1821     SDValue Ops[] = {Value,
1822                      getI32Imm(isVolatile, dl),
1823                      getI32Imm(CodeAddrSpace, dl),
1824                      getI32Imm(vecType, dl),
1825                      getI32Imm(toType, dl),
1826                      getI32Imm(toTypeWidth, dl),
1827                      BasePtr,
1828                      Chain};
1829     NVPTXST = CurDAG->getMachineNode(*Opcode, dl, MVT::Other, Ops);
1830   }
1831 
1832   if (!NVPTXST)
1833     return false;
1834 
1835   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
1836   CurDAG->setNodeMemRefs(cast<MachineSDNode>(NVPTXST), {MemRef});
1837   ReplaceNode(N, NVPTXST);
1838   return true;
1839 }
1840 
tryStoreVector(SDNode * N)1841 bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
1842   SDValue Chain = N->getOperand(0);
1843   SDValue Op1 = N->getOperand(1);
1844   SDValue Addr, Offset, Base;
1845   std::optional<unsigned> Opcode;
1846   SDLoc DL(N);
1847   SDNode *ST;
1848   EVT EltVT = Op1.getValueType();
1849   MemSDNode *MemSD = cast<MemSDNode>(N);
1850   EVT StoreVT = MemSD->getMemoryVT();
1851 
1852   // Address Space Setting
1853   unsigned CodeAddrSpace = getCodeAddrSpace(MemSD);
1854   if (CodeAddrSpace == NVPTX::PTXLdStInstCode::CONSTANT) {
1855     report_fatal_error("Cannot store to pointer that points to constant "
1856                        "memory space");
1857   }
1858   unsigned int PointerSize =
1859       CurDAG->getDataLayout().getPointerSizeInBits(MemSD->getAddressSpace());
1860 
1861   // Volatile Setting
1862   // - .volatile is only availalble for .global and .shared
1863   bool IsVolatile = MemSD->isVolatile();
1864   if (CodeAddrSpace != NVPTX::PTXLdStInstCode::GLOBAL &&
1865       CodeAddrSpace != NVPTX::PTXLdStInstCode::SHARED &&
1866       CodeAddrSpace != NVPTX::PTXLdStInstCode::GENERIC)
1867     IsVolatile = false;
1868 
1869   // Type Setting: toType + toTypeWidth
1870   // - for integer type, always use 'u'
1871   assert(StoreVT.isSimple() && "Store value is not simple");
1872   MVT ScalarVT = StoreVT.getSimpleVT().getScalarType();
1873   unsigned ToTypeWidth = ScalarVT.getSizeInBits();
1874   unsigned ToType = getLdStRegType(ScalarVT);
1875 
1876   SmallVector<SDValue, 12> StOps;
1877   SDValue N2;
1878   unsigned VecType;
1879 
1880   switch (N->getOpcode()) {
1881   case NVPTXISD::StoreV2:
1882     VecType = NVPTX::PTXLdStInstCode::V2;
1883     StOps.push_back(N->getOperand(1));
1884     StOps.push_back(N->getOperand(2));
1885     N2 = N->getOperand(3);
1886     break;
1887   case NVPTXISD::StoreV4:
1888     VecType = NVPTX::PTXLdStInstCode::V4;
1889     StOps.push_back(N->getOperand(1));
1890     StOps.push_back(N->getOperand(2));
1891     StOps.push_back(N->getOperand(3));
1892     StOps.push_back(N->getOperand(4));
1893     N2 = N->getOperand(5);
1894     break;
1895   default:
1896     return false;
1897   }
1898 
1899   // v8x16 is a special case. PTX doesn't have st.v8.x16
1900   // instruction. Instead, we split the vector into v2x16 chunks and
1901   // store them with st.v4.b32.
1902   if (Isv2x16VT(EltVT)) {
1903     assert(N->getOpcode() == NVPTXISD::StoreV4 && "Unexpected load opcode.");
1904     EltVT = MVT::i32;
1905     ToType = NVPTX::PTXLdStInstCode::Untyped;
1906     ToTypeWidth = 32;
1907   }
1908 
1909   StOps.push_back(getI32Imm(IsVolatile, DL));
1910   StOps.push_back(getI32Imm(CodeAddrSpace, DL));
1911   StOps.push_back(getI32Imm(VecType, DL));
1912   StOps.push_back(getI32Imm(ToType, DL));
1913   StOps.push_back(getI32Imm(ToTypeWidth, DL));
1914 
1915   if (SelectDirectAddr(N2, Addr)) {
1916     switch (N->getOpcode()) {
1917     default:
1918       return false;
1919     case NVPTXISD::StoreV2:
1920       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1921                                NVPTX::STV_i8_v2_avar, NVPTX::STV_i16_v2_avar,
1922                                NVPTX::STV_i32_v2_avar, NVPTX::STV_i64_v2_avar,
1923                                NVPTX::STV_f32_v2_avar, NVPTX::STV_f64_v2_avar);
1924       break;
1925     case NVPTXISD::StoreV4:
1926       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1927                                NVPTX::STV_i8_v4_avar, NVPTX::STV_i16_v4_avar,
1928                                NVPTX::STV_i32_v4_avar, std::nullopt,
1929                                NVPTX::STV_f32_v4_avar, std::nullopt);
1930       break;
1931     }
1932     StOps.push_back(Addr);
1933   } else if (PointerSize == 64 ? SelectADDRsi64(N2.getNode(), N2, Base, Offset)
1934                                : SelectADDRsi(N2.getNode(), N2, Base, Offset)) {
1935     switch (N->getOpcode()) {
1936     default:
1937       return false;
1938     case NVPTXISD::StoreV2:
1939       Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1940                                NVPTX::STV_i8_v2_asi, NVPTX::STV_i16_v2_asi,
1941                                NVPTX::STV_i32_v2_asi, NVPTX::STV_i64_v2_asi,
1942                                NVPTX::STV_f32_v2_asi, NVPTX::STV_f64_v2_asi);
1943       break;
1944     case NVPTXISD::StoreV4:
1945       Opcode =
1946           pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_asi,
1947                           NVPTX::STV_i16_v4_asi, NVPTX::STV_i32_v4_asi,
1948                           std::nullopt, NVPTX::STV_f32_v4_asi, std::nullopt);
1949       break;
1950     }
1951     StOps.push_back(Base);
1952     StOps.push_back(Offset);
1953   } else if (PointerSize == 64 ? SelectADDRri64(N2.getNode(), N2, Base, Offset)
1954                                : SelectADDRri(N2.getNode(), N2, Base, Offset)) {
1955     if (PointerSize == 64) {
1956       switch (N->getOpcode()) {
1957       default:
1958         return false;
1959       case NVPTXISD::StoreV2:
1960         Opcode =
1961             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1962                             NVPTX::STV_i8_v2_ari_64, NVPTX::STV_i16_v2_ari_64,
1963                             NVPTX::STV_i32_v2_ari_64, NVPTX::STV_i64_v2_ari_64,
1964                             NVPTX::STV_f32_v2_ari_64, NVPTX::STV_f64_v2_ari_64);
1965         break;
1966       case NVPTXISD::StoreV4:
1967         Opcode = pickOpcodeForVT(
1968             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_ari_64,
1969             NVPTX::STV_i16_v4_ari_64, NVPTX::STV_i32_v4_ari_64, std::nullopt,
1970             NVPTX::STV_f32_v4_ari_64, std::nullopt);
1971         break;
1972       }
1973     } else {
1974       switch (N->getOpcode()) {
1975       default:
1976         return false;
1977       case NVPTXISD::StoreV2:
1978         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1979                                  NVPTX::STV_i8_v2_ari, NVPTX::STV_i16_v2_ari,
1980                                  NVPTX::STV_i32_v2_ari, NVPTX::STV_i64_v2_ari,
1981                                  NVPTX::STV_f32_v2_ari, NVPTX::STV_f64_v2_ari);
1982         break;
1983       case NVPTXISD::StoreV4:
1984         Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy,
1985                                  NVPTX::STV_i8_v4_ari, NVPTX::STV_i16_v4_ari,
1986                                  NVPTX::STV_i32_v4_ari, std::nullopt,
1987                                  NVPTX::STV_f32_v4_ari, std::nullopt);
1988         break;
1989       }
1990     }
1991     StOps.push_back(Base);
1992     StOps.push_back(Offset);
1993   } else {
1994     if (PointerSize == 64) {
1995       switch (N->getOpcode()) {
1996       default:
1997         return false;
1998       case NVPTXISD::StoreV2:
1999         Opcode = pickOpcodeForVT(
2000             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg_64,
2001             NVPTX::STV_i16_v2_areg_64, NVPTX::STV_i32_v2_areg_64,
2002             NVPTX::STV_i64_v2_areg_64, NVPTX::STV_f32_v2_areg_64,
2003             NVPTX::STV_f64_v2_areg_64);
2004         break;
2005       case NVPTXISD::StoreV4:
2006         Opcode = pickOpcodeForVT(
2007             EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg_64,
2008             NVPTX::STV_i16_v4_areg_64, NVPTX::STV_i32_v4_areg_64, std::nullopt,
2009             NVPTX::STV_f32_v4_areg_64, std::nullopt);
2010         break;
2011       }
2012     } else {
2013       switch (N->getOpcode()) {
2014       default:
2015         return false;
2016       case NVPTXISD::StoreV2:
2017         Opcode =
2018             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v2_areg,
2019                             NVPTX::STV_i16_v2_areg, NVPTX::STV_i32_v2_areg,
2020                             NVPTX::STV_i64_v2_areg, NVPTX::STV_f32_v2_areg,
2021                             NVPTX::STV_f64_v2_areg);
2022         break;
2023       case NVPTXISD::StoreV4:
2024         Opcode =
2025             pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, NVPTX::STV_i8_v4_areg,
2026                             NVPTX::STV_i16_v4_areg, NVPTX::STV_i32_v4_areg,
2027                             std::nullopt, NVPTX::STV_f32_v4_areg, std::nullopt);
2028         break;
2029       }
2030     }
2031     StOps.push_back(N2);
2032   }
2033 
2034   if (!Opcode)
2035     return false;
2036 
2037   StOps.push_back(Chain);
2038 
2039   ST = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, StOps);
2040 
2041   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
2042   CurDAG->setNodeMemRefs(cast<MachineSDNode>(ST), {MemRef});
2043 
2044   ReplaceNode(N, ST);
2045   return true;
2046 }
2047 
tryLoadParam(SDNode * Node)2048 bool NVPTXDAGToDAGISel::tryLoadParam(SDNode *Node) {
2049   SDValue Chain = Node->getOperand(0);
2050   SDValue Offset = Node->getOperand(2);
2051   SDValue Glue = Node->getOperand(3);
2052   SDLoc DL(Node);
2053   MemSDNode *Mem = cast<MemSDNode>(Node);
2054 
2055   unsigned VecSize;
2056   switch (Node->getOpcode()) {
2057   default:
2058     return false;
2059   case NVPTXISD::LoadParam:
2060     VecSize = 1;
2061     break;
2062   case NVPTXISD::LoadParamV2:
2063     VecSize = 2;
2064     break;
2065   case NVPTXISD::LoadParamV4:
2066     VecSize = 4;
2067     break;
2068   }
2069 
2070   EVT EltVT = Node->getValueType(0);
2071   EVT MemVT = Mem->getMemoryVT();
2072 
2073   std::optional<unsigned> Opcode;
2074 
2075   switch (VecSize) {
2076   default:
2077     return false;
2078   case 1:
2079     Opcode = pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy,
2080                              NVPTX::LoadParamMemI8, NVPTX::LoadParamMemI16,
2081                              NVPTX::LoadParamMemI32, NVPTX::LoadParamMemI64,
2082                              NVPTX::LoadParamMemF32, NVPTX::LoadParamMemF64);
2083     break;
2084   case 2:
2085     Opcode =
2086         pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV2I8,
2087                         NVPTX::LoadParamMemV2I16, NVPTX::LoadParamMemV2I32,
2088                         NVPTX::LoadParamMemV2I64, NVPTX::LoadParamMemV2F32,
2089                         NVPTX::LoadParamMemV2F64);
2090     break;
2091   case 4:
2092     Opcode =
2093         pickOpcodeForVT(MemVT.getSimpleVT().SimpleTy, NVPTX::LoadParamMemV4I8,
2094                         NVPTX::LoadParamMemV4I16, NVPTX::LoadParamMemV4I32,
2095                         std::nullopt, NVPTX::LoadParamMemV4F32, std::nullopt);
2096     break;
2097   }
2098   if (!Opcode)
2099     return false;
2100 
2101   SDVTList VTs;
2102   if (VecSize == 1) {
2103     VTs = CurDAG->getVTList(EltVT, MVT::Other, MVT::Glue);
2104   } else if (VecSize == 2) {
2105     VTs = CurDAG->getVTList(EltVT, EltVT, MVT::Other, MVT::Glue);
2106   } else {
2107     EVT EVTs[] = { EltVT, EltVT, EltVT, EltVT, MVT::Other, MVT::Glue };
2108     VTs = CurDAG->getVTList(EVTs);
2109   }
2110 
2111   unsigned OffsetVal = Offset->getAsZExtVal();
2112 
2113   SmallVector<SDValue, 2> Ops;
2114   Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
2115   Ops.push_back(Chain);
2116   Ops.push_back(Glue);
2117 
2118   ReplaceNode(Node, CurDAG->getMachineNode(*Opcode, DL, VTs, Ops));
2119   return true;
2120 }
2121 
tryStoreRetval(SDNode * N)2122 bool NVPTXDAGToDAGISel::tryStoreRetval(SDNode *N) {
2123   SDLoc DL(N);
2124   SDValue Chain = N->getOperand(0);
2125   SDValue Offset = N->getOperand(1);
2126   unsigned OffsetVal = Offset->getAsZExtVal();
2127   MemSDNode *Mem = cast<MemSDNode>(N);
2128 
2129   // How many elements do we have?
2130   unsigned NumElts = 1;
2131   switch (N->getOpcode()) {
2132   default:
2133     return false;
2134   case NVPTXISD::StoreRetval:
2135     NumElts = 1;
2136     break;
2137   case NVPTXISD::StoreRetvalV2:
2138     NumElts = 2;
2139     break;
2140   case NVPTXISD::StoreRetvalV4:
2141     NumElts = 4;
2142     break;
2143   }
2144 
2145   // Build vector of operands
2146   SmallVector<SDValue, 6> Ops;
2147   for (unsigned i = 0; i < NumElts; ++i)
2148     Ops.push_back(N->getOperand(i + 2));
2149   Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
2150   Ops.push_back(Chain);
2151 
2152   // Determine target opcode
2153   // If we have an i1, use an 8-bit store. The lowering code in
2154   // NVPTXISelLowering will have already emitted an upcast.
2155   std::optional<unsigned> Opcode = 0;
2156   switch (NumElts) {
2157   default:
2158     return false;
2159   case 1:
2160     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2161                              NVPTX::StoreRetvalI8, NVPTX::StoreRetvalI16,
2162                              NVPTX::StoreRetvalI32, NVPTX::StoreRetvalI64,
2163                              NVPTX::StoreRetvalF32, NVPTX::StoreRetvalF64);
2164     if (Opcode == NVPTX::StoreRetvalI8) {
2165       // Fine tune the opcode depending on the size of the operand.
2166       // This helps to avoid creating redundant COPY instructions in
2167       // InstrEmitter::AddRegisterOperand().
2168       switch (Ops[0].getSimpleValueType().SimpleTy) {
2169       default:
2170         break;
2171       case MVT::i32:
2172         Opcode = NVPTX::StoreRetvalI8TruncI32;
2173         break;
2174       case MVT::i64:
2175         Opcode = NVPTX::StoreRetvalI8TruncI64;
2176         break;
2177       }
2178     }
2179     break;
2180   case 2:
2181     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2182                              NVPTX::StoreRetvalV2I8, NVPTX::StoreRetvalV2I16,
2183                              NVPTX::StoreRetvalV2I32, NVPTX::StoreRetvalV2I64,
2184                              NVPTX::StoreRetvalV2F32, NVPTX::StoreRetvalV2F64);
2185     break;
2186   case 4:
2187     Opcode = pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2188                              NVPTX::StoreRetvalV4I8, NVPTX::StoreRetvalV4I16,
2189                              NVPTX::StoreRetvalV4I32, std::nullopt,
2190                              NVPTX::StoreRetvalV4F32, std::nullopt);
2191     break;
2192   }
2193   if (!Opcode)
2194     return false;
2195 
2196   SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, MVT::Other, Ops);
2197   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
2198   CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
2199 
2200   ReplaceNode(N, Ret);
2201   return true;
2202 }
2203 
// Helpers for constructing opcode (ex: NVPTX::StoreParamV4F32_iiri).
//
// Each element position contributes one letter to the opcode suffix: 'i' when
// isimm[k] is true (immediate operand) and 'r' otherwise (register operand).
// The macros expand to a tree of conditional expressions that picks the
// opcode at run time from the isimm array. Every expansion is fully
// parenthesized, and so is every argument that is not token-pasted, so the
// macros can safely be used inside larger expressions.
#define getOpcV2H(ty, opKind0, opKind1)                                        \
  NVPTX::StoreParamV2##ty##_##opKind0##opKind1

#define getOpcV2H1(ty, opKind0, isImm1)                                        \
  ((isImm1) ? getOpcV2H(ty, opKind0, i) : getOpcV2H(ty, opKind0, r))

#define getOpcodeForVectorStParamV2(ty, isimm)                                 \
  ((isimm)[0] ? getOpcV2H1(ty, i, (isimm)[1])                                  \
              : getOpcV2H1(ty, r, (isimm)[1]))

#define getOpcV4H(ty, opKind0, opKind1, opKind2, opKind3)                      \
  NVPTX::StoreParamV4##ty##_##opKind0##opKind1##opKind2##opKind3

#define getOpcV4H3(ty, opKind0, opKind1, opKind2, isImm3)                      \
  ((isImm3) ? getOpcV4H(ty, opKind0, opKind1, opKind2, i)                      \
            : getOpcV4H(ty, opKind0, opKind1, opKind2, r))

#define getOpcV4H2(ty, opKind0, opKind1, isImm2, isImm3)                       \
  ((isImm2) ? getOpcV4H3(ty, opKind0, opKind1, i, isImm3)                      \
            : getOpcV4H3(ty, opKind0, opKind1, r, isImm3))

#define getOpcV4H1(ty, opKind0, isImm1, isImm2, isImm3)                        \
  ((isImm1) ? getOpcV4H2(ty, opKind0, i, isImm2, isImm3)                       \
            : getOpcV4H2(ty, opKind0, r, isImm2, isImm3))

#define getOpcodeForVectorStParamV4(ty, isimm)                                 \
  ((isimm)[0] ? getOpcV4H1(ty, i, (isimm)[1], (isimm)[2], (isimm)[3])          \
              : getOpcV4H1(ty, r, (isimm)[1], (isimm)[2], (isimm)[3]))

// Select the V2 opcode when n == 2, otherwise the V4 opcode.
#define getOpcodeForVectorStParam(n, ty, isimm)                                \
  (((n) == 2) ? getOpcodeForVectorStParamV2(ty, isimm)                         \
              : getOpcodeForVectorStParamV4(ty, isimm))
2236 
pickOpcodeForVectorStParam(SmallVector<SDValue,8> & Ops,unsigned NumElts,MVT::SimpleValueType MemTy,SelectionDAG * CurDAG,SDLoc DL)2237 static unsigned pickOpcodeForVectorStParam(SmallVector<SDValue, 8> &Ops,
2238                                            unsigned NumElts,
2239                                            MVT::SimpleValueType MemTy,
2240                                            SelectionDAG *CurDAG, SDLoc DL) {
2241   // Determine which inputs are registers and immediates make new operators
2242   // with constant values
2243   SmallVector<bool, 4> IsImm(NumElts, false);
2244   for (unsigned i = 0; i < NumElts; i++) {
2245     IsImm[i] = (isa<ConstantSDNode>(Ops[i]) || isa<ConstantFPSDNode>(Ops[i]));
2246     if (IsImm[i]) {
2247       SDValue Imm = Ops[i];
2248       if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2249         const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2250         const ConstantFP *CF = ConstImm->getConstantFPValue();
2251         Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2252       } else {
2253         const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2254         const ConstantInt *CI = ConstImm->getConstantIntValue();
2255         Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2256       }
2257       Ops[i] = Imm;
2258     }
2259   }
2260 
2261   // Get opcode for MemTy, size, and register/immediate operand ordering
2262   switch (MemTy) {
2263   case MVT::i8:
2264     return getOpcodeForVectorStParam(NumElts, I8, IsImm);
2265   case MVT::i16:
2266     return getOpcodeForVectorStParam(NumElts, I16, IsImm);
2267   case MVT::i32:
2268     return getOpcodeForVectorStParam(NumElts, I32, IsImm);
2269   case MVT::i64:
2270     assert(NumElts == 2 && "MVT too large for NumElts > 2");
2271     return getOpcodeForVectorStParamV2(I64, IsImm);
2272   case MVT::f32:
2273     return getOpcodeForVectorStParam(NumElts, F32, IsImm);
2274   case MVT::f64:
2275     assert(NumElts == 2 && "MVT too large for NumElts > 2");
2276     return getOpcodeForVectorStParamV2(F64, IsImm);
2277 
2278   // These cases don't support immediates, just use the all register version
2279   // and generate moves.
2280   case MVT::i1:
2281     return (NumElts == 2) ? NVPTX::StoreParamV2I8_rr
2282                           : NVPTX::StoreParamV4I8_rrrr;
2283   case MVT::f16:
2284   case MVT::bf16:
2285     return (NumElts == 2) ? NVPTX::StoreParamV2I16_rr
2286                           : NVPTX::StoreParamV4I16_rrrr;
2287   case MVT::v2f16:
2288   case MVT::v2bf16:
2289   case MVT::v2i16:
2290   case MVT::v4i8:
2291     return (NumElts == 2) ? NVPTX::StoreParamV2I32_rr
2292                           : NVPTX::StoreParamV4I32_rrrr;
2293   default:
2294     llvm_unreachable("Cannot select st.param for unknown MemTy");
2295   }
2296 }
2297 
tryStoreParam(SDNode * N)2298 bool NVPTXDAGToDAGISel::tryStoreParam(SDNode *N) {
2299   SDLoc DL(N);
2300   SDValue Chain = N->getOperand(0);
2301   SDValue Param = N->getOperand(1);
2302   unsigned ParamVal = Param->getAsZExtVal();
2303   SDValue Offset = N->getOperand(2);
2304   unsigned OffsetVal = Offset->getAsZExtVal();
2305   MemSDNode *Mem = cast<MemSDNode>(N);
2306   SDValue Glue = N->getOperand(N->getNumOperands() - 1);
2307 
2308   // How many elements do we have?
2309   unsigned NumElts;
2310   switch (N->getOpcode()) {
2311   default:
2312     llvm_unreachable("Unexpected opcode");
2313   case NVPTXISD::StoreParamU32:
2314   case NVPTXISD::StoreParamS32:
2315   case NVPTXISD::StoreParam:
2316     NumElts = 1;
2317     break;
2318   case NVPTXISD::StoreParamV2:
2319     NumElts = 2;
2320     break;
2321   case NVPTXISD::StoreParamV4:
2322     NumElts = 4;
2323     break;
2324   }
2325 
2326   // Build vector of operands
2327   SmallVector<SDValue, 8> Ops;
2328   for (unsigned i = 0; i < NumElts; ++i)
2329     Ops.push_back(N->getOperand(i + 3));
2330   Ops.push_back(CurDAG->getTargetConstant(ParamVal, DL, MVT::i32));
2331   Ops.push_back(CurDAG->getTargetConstant(OffsetVal, DL, MVT::i32));
2332   Ops.push_back(Chain);
2333   Ops.push_back(Glue);
2334 
2335   // Determine target opcode
2336   // If we have an i1, use an 8-bit store. The lowering code in
2337   // NVPTXISelLowering will have already emitted an upcast.
2338   std::optional<unsigned> Opcode;
2339   switch (N->getOpcode()) {
2340   default:
2341     switch (NumElts) {
2342     default:
2343       llvm_unreachable("Unexpected NumElts");
2344     case 1: {
2345       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2346       SDValue Imm = Ops[0];
2347       if (MemTy != MVT::f16 && MemTy != MVT::v2f16 &&
2348           (isa<ConstantSDNode>(Imm) || isa<ConstantFPSDNode>(Imm))) {
2349         // Convert immediate to target constant
2350         if (MemTy == MVT::f32 || MemTy == MVT::f64) {
2351           const ConstantFPSDNode *ConstImm = cast<ConstantFPSDNode>(Imm);
2352           const ConstantFP *CF = ConstImm->getConstantFPValue();
2353           Imm = CurDAG->getTargetConstantFP(*CF, DL, Imm->getValueType(0));
2354         } else {
2355           const ConstantSDNode *ConstImm = cast<ConstantSDNode>(Imm);
2356           const ConstantInt *CI = ConstImm->getConstantIntValue();
2357           Imm = CurDAG->getTargetConstant(*CI, DL, Imm->getValueType(0));
2358         }
2359         Ops[0] = Imm;
2360         // Use immediate version of store param
2361         Opcode = pickOpcodeForVT(MemTy, NVPTX::StoreParamI8_i,
2362                                  NVPTX::StoreParamI16_i, NVPTX::StoreParamI32_i,
2363                                  NVPTX::StoreParamI64_i, NVPTX::StoreParamF32_i,
2364                                  NVPTX::StoreParamF64_i);
2365       } else
2366         Opcode =
2367             pickOpcodeForVT(Mem->getMemoryVT().getSimpleVT().SimpleTy,
2368                             NVPTX::StoreParamI8_r, NVPTX::StoreParamI16_r,
2369                             NVPTX::StoreParamI32_r, NVPTX::StoreParamI64_r,
2370                             NVPTX::StoreParamF32_r, NVPTX::StoreParamF64_r);
2371       if (Opcode == NVPTX::StoreParamI8_r) {
2372         // Fine tune the opcode depending on the size of the operand.
2373         // This helps to avoid creating redundant COPY instructions in
2374         // InstrEmitter::AddRegisterOperand().
2375         switch (Ops[0].getSimpleValueType().SimpleTy) {
2376         default:
2377           break;
2378         case MVT::i32:
2379           Opcode = NVPTX::StoreParamI8TruncI32_r;
2380           break;
2381         case MVT::i64:
2382           Opcode = NVPTX::StoreParamI8TruncI64_r;
2383           break;
2384         }
2385       }
2386       break;
2387     }
2388     case 2:
2389     case 4: {
2390       MVT::SimpleValueType MemTy = Mem->getMemoryVT().getSimpleVT().SimpleTy;
2391       Opcode = pickOpcodeForVectorStParam(Ops, NumElts, MemTy, CurDAG, DL);
2392       break;
2393     }
2394     }
2395     break;
2396   // Special case: if we have a sign-extend/zero-extend node, insert the
2397   // conversion instruction first, and use that as the value operand to
2398   // the selected StoreParam node.
2399   case NVPTXISD::StoreParamU32: {
2400     Opcode = NVPTX::StoreParamI32_r;
2401     SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
2402                                                 MVT::i32);
2403     SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_u32_u16, DL,
2404                                          MVT::i32, Ops[0], CvtNone);
2405     Ops[0] = SDValue(Cvt, 0);
2406     break;
2407   }
2408   case NVPTXISD::StoreParamS32: {
2409     Opcode = NVPTX::StoreParamI32_r;
2410     SDValue CvtNone = CurDAG->getTargetConstant(NVPTX::PTXCvtMode::NONE, DL,
2411                                                 MVT::i32);
2412     SDNode *Cvt = CurDAG->getMachineNode(NVPTX::CVT_s32_s16, DL,
2413                                          MVT::i32, Ops[0], CvtNone);
2414     Ops[0] = SDValue(Cvt, 0);
2415     break;
2416   }
2417   }
2418 
2419   SDVTList RetVTs = CurDAG->getVTList(MVT::Other, MVT::Glue);
2420   SDNode *Ret = CurDAG->getMachineNode(*Opcode, DL, RetVTs, Ops);
2421   MachineMemOperand *MemRef = cast<MemSDNode>(N)->getMemOperand();
2422   CurDAG->setNodeMemRefs(cast<MachineSDNode>(Ret), {MemRef});
2423 
2424   ReplaceNode(N, Ret);
2425   return true;
2426 }
2427 
tryTextureIntrinsic(SDNode * N)2428 bool NVPTXDAGToDAGISel::tryTextureIntrinsic(SDNode *N) {
2429   unsigned Opc = 0;
2430 
2431   switch (N->getOpcode()) {
2432   default: return false;
2433   case NVPTXISD::Tex1DFloatS32:
2434     Opc = NVPTX::TEX_1D_F32_S32_RR;
2435     break;
2436   case NVPTXISD::Tex1DFloatFloat:
2437     Opc = NVPTX::TEX_1D_F32_F32_RR;
2438     break;
2439   case NVPTXISD::Tex1DFloatFloatLevel:
2440     Opc = NVPTX::TEX_1D_F32_F32_LEVEL_RR;
2441     break;
2442   case NVPTXISD::Tex1DFloatFloatGrad:
2443     Opc = NVPTX::TEX_1D_F32_F32_GRAD_RR;
2444     break;
2445   case NVPTXISD::Tex1DS32S32:
2446     Opc = NVPTX::TEX_1D_S32_S32_RR;
2447     break;
2448   case NVPTXISD::Tex1DS32Float:
2449     Opc = NVPTX::TEX_1D_S32_F32_RR;
2450     break;
2451   case NVPTXISD::Tex1DS32FloatLevel:
2452     Opc = NVPTX::TEX_1D_S32_F32_LEVEL_RR;
2453     break;
2454   case NVPTXISD::Tex1DS32FloatGrad:
2455     Opc = NVPTX::TEX_1D_S32_F32_GRAD_RR;
2456     break;
2457   case NVPTXISD::Tex1DU32S32:
2458     Opc = NVPTX::TEX_1D_U32_S32_RR;
2459     break;
2460   case NVPTXISD::Tex1DU32Float:
2461     Opc = NVPTX::TEX_1D_U32_F32_RR;
2462     break;
2463   case NVPTXISD::Tex1DU32FloatLevel:
2464     Opc = NVPTX::TEX_1D_U32_F32_LEVEL_RR;
2465     break;
2466   case NVPTXISD::Tex1DU32FloatGrad:
2467     Opc = NVPTX::TEX_1D_U32_F32_GRAD_RR;
2468     break;
2469   case NVPTXISD::Tex1DArrayFloatS32:
2470     Opc = NVPTX::TEX_1D_ARRAY_F32_S32_RR;
2471     break;
2472   case NVPTXISD::Tex1DArrayFloatFloat:
2473     Opc = NVPTX::TEX_1D_ARRAY_F32_F32_RR;
2474     break;
2475   case NVPTXISD::Tex1DArrayFloatFloatLevel:
2476     Opc = NVPTX::TEX_1D_ARRAY_F32_F32_LEVEL_RR;
2477     break;
2478   case NVPTXISD::Tex1DArrayFloatFloatGrad:
2479     Opc = NVPTX::TEX_1D_ARRAY_F32_F32_GRAD_RR;
2480     break;
2481   case NVPTXISD::Tex1DArrayS32S32:
2482     Opc = NVPTX::TEX_1D_ARRAY_S32_S32_RR;
2483     break;
2484   case NVPTXISD::Tex1DArrayS32Float:
2485     Opc = NVPTX::TEX_1D_ARRAY_S32_F32_RR;
2486     break;
2487   case NVPTXISD::Tex1DArrayS32FloatLevel:
2488     Opc = NVPTX::TEX_1D_ARRAY_S32_F32_LEVEL_RR;
2489     break;
2490   case NVPTXISD::Tex1DArrayS32FloatGrad:
2491     Opc = NVPTX::TEX_1D_ARRAY_S32_F32_GRAD_RR;
2492     break;
2493   case NVPTXISD::Tex1DArrayU32S32:
2494     Opc = NVPTX::TEX_1D_ARRAY_U32_S32_RR;
2495     break;
2496   case NVPTXISD::Tex1DArrayU32Float:
2497     Opc = NVPTX::TEX_1D_ARRAY_U32_F32_RR;
2498     break;
2499   case NVPTXISD::Tex1DArrayU32FloatLevel:
2500     Opc = NVPTX::TEX_1D_ARRAY_U32_F32_LEVEL_RR;
2501     break;
2502   case NVPTXISD::Tex1DArrayU32FloatGrad:
2503     Opc = NVPTX::TEX_1D_ARRAY_U32_F32_GRAD_RR;
2504     break;
2505   case NVPTXISD::Tex2DFloatS32:
2506     Opc = NVPTX::TEX_2D_F32_S32_RR;
2507     break;
2508   case NVPTXISD::Tex2DFloatFloat:
2509     Opc = NVPTX::TEX_2D_F32_F32_RR;
2510     break;
2511   case NVPTXISD::Tex2DFloatFloatLevel:
2512     Opc = NVPTX::TEX_2D_F32_F32_LEVEL_RR;
2513     break;
2514   case NVPTXISD::Tex2DFloatFloatGrad:
2515     Opc = NVPTX::TEX_2D_F32_F32_GRAD_RR;
2516     break;
2517   case NVPTXISD::Tex2DS32S32:
2518     Opc = NVPTX::TEX_2D_S32_S32_RR;
2519     break;
2520   case NVPTXISD::Tex2DS32Float:
2521     Opc = NVPTX::TEX_2D_S32_F32_RR;
2522     break;
2523   case NVPTXISD::Tex2DS32FloatLevel:
2524     Opc = NVPTX::TEX_2D_S32_F32_LEVEL_RR;
2525     break;
2526   case NVPTXISD::Tex2DS32FloatGrad:
2527     Opc = NVPTX::TEX_2D_S32_F32_GRAD_RR;
2528     break;
2529   case NVPTXISD::Tex2DU32S32:
2530     Opc = NVPTX::TEX_2D_U32_S32_RR;
2531     break;
2532   case NVPTXISD::Tex2DU32Float:
2533     Opc = NVPTX::TEX_2D_U32_F32_RR;
2534     break;
2535   case NVPTXISD::Tex2DU32FloatLevel:
2536     Opc = NVPTX::TEX_2D_U32_F32_LEVEL_RR;
2537     break;
2538   case NVPTXISD::Tex2DU32FloatGrad:
2539     Opc = NVPTX::TEX_2D_U32_F32_GRAD_RR;
2540     break;
2541   case NVPTXISD::Tex2DArrayFloatS32:
2542     Opc = NVPTX::TEX_2D_ARRAY_F32_S32_RR;
2543     break;
2544   case NVPTXISD::Tex2DArrayFloatFloat:
2545     Opc = NVPTX::TEX_2D_ARRAY_F32_F32_RR;
2546     break;
2547   case NVPTXISD::Tex2DArrayFloatFloatLevel:
2548     Opc = NVPTX::TEX_2D_ARRAY_F32_F32_LEVEL_RR;
2549     break;
2550   case NVPTXISD::Tex2DArrayFloatFloatGrad:
2551     Opc = NVPTX::TEX_2D_ARRAY_F32_F32_GRAD_RR;
2552     break;
2553   case NVPTXISD::Tex2DArrayS32S32:
2554     Opc = NVPTX::TEX_2D_ARRAY_S32_S32_RR;
2555     break;
2556   case NVPTXISD::Tex2DArrayS32Float:
2557     Opc = NVPTX::TEX_2D_ARRAY_S32_F32_RR;
2558     break;
2559   case NVPTXISD::Tex2DArrayS32FloatLevel:
2560     Opc = NVPTX::TEX_2D_ARRAY_S32_F32_LEVEL_RR;
2561     break;
2562   case NVPTXISD::Tex2DArrayS32FloatGrad:
2563     Opc = NVPTX::TEX_2D_ARRAY_S32_F32_GRAD_RR;
2564     break;
2565   case NVPTXISD::Tex2DArrayU32S32:
2566     Opc = NVPTX::TEX_2D_ARRAY_U32_S32_RR;
2567     break;
2568   case NVPTXISD::Tex2DArrayU32Float:
2569     Opc = NVPTX::TEX_2D_ARRAY_U32_F32_RR;
2570     break;
2571   case NVPTXISD::Tex2DArrayU32FloatLevel:
2572     Opc = NVPTX::TEX_2D_ARRAY_U32_F32_LEVEL_RR;
2573     break;
2574   case NVPTXISD::Tex2DArrayU32FloatGrad:
2575     Opc = NVPTX::TEX_2D_ARRAY_U32_F32_GRAD_RR;
2576     break;
2577   case NVPTXISD::Tex3DFloatS32:
2578     Opc = NVPTX::TEX_3D_F32_S32_RR;
2579     break;
2580   case NVPTXISD::Tex3DFloatFloat:
2581     Opc = NVPTX::TEX_3D_F32_F32_RR;
2582     break;
2583   case NVPTXISD::Tex3DFloatFloatLevel:
2584     Opc = NVPTX::TEX_3D_F32_F32_LEVEL_RR;
2585     break;
2586   case NVPTXISD::Tex3DFloatFloatGrad:
2587     Opc = NVPTX::TEX_3D_F32_F32_GRAD_RR;
2588     break;
2589   case NVPTXISD::Tex3DS32S32:
2590     Opc = NVPTX::TEX_3D_S32_S32_RR;
2591     break;
2592   case NVPTXISD::Tex3DS32Float:
2593     Opc = NVPTX::TEX_3D_S32_F32_RR;
2594     break;
2595   case NVPTXISD::Tex3DS32FloatLevel:
2596     Opc = NVPTX::TEX_3D_S32_F32_LEVEL_RR;
2597     break;
2598   case NVPTXISD::Tex3DS32FloatGrad:
2599     Opc = NVPTX::TEX_3D_S32_F32_GRAD_RR;
2600     break;
2601   case NVPTXISD::Tex3DU32S32:
2602     Opc = NVPTX::TEX_3D_U32_S32_RR;
2603     break;
2604   case NVPTXISD::Tex3DU32Float:
2605     Opc = NVPTX::TEX_3D_U32_F32_RR;
2606     break;
2607   case NVPTXISD::Tex3DU32FloatLevel:
2608     Opc = NVPTX::TEX_3D_U32_F32_LEVEL_RR;
2609     break;
2610   case NVPTXISD::Tex3DU32FloatGrad:
2611     Opc = NVPTX::TEX_3D_U32_F32_GRAD_RR;
2612     break;
2613   case NVPTXISD::TexCubeFloatFloat:
2614     Opc = NVPTX::TEX_CUBE_F32_F32_RR;
2615     break;
2616   case NVPTXISD::TexCubeFloatFloatLevel:
2617     Opc = NVPTX::TEX_CUBE_F32_F32_LEVEL_RR;
2618     break;
2619   case NVPTXISD::TexCubeS32Float:
2620     Opc = NVPTX::TEX_CUBE_S32_F32_RR;
2621     break;
2622   case NVPTXISD::TexCubeS32FloatLevel:
2623     Opc = NVPTX::TEX_CUBE_S32_F32_LEVEL_RR;
2624     break;
2625   case NVPTXISD::TexCubeU32Float:
2626     Opc = NVPTX::TEX_CUBE_U32_F32_RR;
2627     break;
2628   case NVPTXISD::TexCubeU32FloatLevel:
2629     Opc = NVPTX::TEX_CUBE_U32_F32_LEVEL_RR;
2630     break;
2631   case NVPTXISD::TexCubeArrayFloatFloat:
2632     Opc = NVPTX::TEX_CUBE_ARRAY_F32_F32_RR;
2633     break;
2634   case NVPTXISD::TexCubeArrayFloatFloatLevel:
2635     Opc = NVPTX::TEX_CUBE_ARRAY_F32_F32_LEVEL_RR;
2636     break;
2637   case NVPTXISD::TexCubeArrayS32Float:
2638     Opc = NVPTX::TEX_CUBE_ARRAY_S32_F32_RR;
2639     break;
2640   case NVPTXISD::TexCubeArrayS32FloatLevel:
2641     Opc = NVPTX::TEX_CUBE_ARRAY_S32_F32_LEVEL_RR;
2642     break;
2643   case NVPTXISD::TexCubeArrayU32Float:
2644     Opc = NVPTX::TEX_CUBE_ARRAY_U32_F32_RR;
2645     break;
2646   case NVPTXISD::TexCubeArrayU32FloatLevel:
2647     Opc = NVPTX::TEX_CUBE_ARRAY_U32_F32_LEVEL_RR;
2648     break;
2649   case NVPTXISD::Tld4R2DFloatFloat:
2650     Opc = NVPTX::TLD4_R_2D_F32_F32_RR;
2651     break;
2652   case NVPTXISD::Tld4G2DFloatFloat:
2653     Opc = NVPTX::TLD4_G_2D_F32_F32_RR;
2654     break;
2655   case NVPTXISD::Tld4B2DFloatFloat:
2656     Opc = NVPTX::TLD4_B_2D_F32_F32_RR;
2657     break;
2658   case NVPTXISD::Tld4A2DFloatFloat:
2659     Opc = NVPTX::TLD4_A_2D_F32_F32_RR;
2660     break;
2661   case NVPTXISD::Tld4R2DS64Float:
2662     Opc = NVPTX::TLD4_R_2D_S32_F32_RR;
2663     break;
2664   case NVPTXISD::Tld4G2DS64Float:
2665     Opc = NVPTX::TLD4_G_2D_S32_F32_RR;
2666     break;
2667   case NVPTXISD::Tld4B2DS64Float:
2668     Opc = NVPTX::TLD4_B_2D_S32_F32_RR;
2669     break;
2670   case NVPTXISD::Tld4A2DS64Float:
2671     Opc = NVPTX::TLD4_A_2D_S32_F32_RR;
2672     break;
2673   case NVPTXISD::Tld4R2DU64Float:
2674     Opc = NVPTX::TLD4_R_2D_U32_F32_RR;
2675     break;
2676   case NVPTXISD::Tld4G2DU64Float:
2677     Opc = NVPTX::TLD4_G_2D_U32_F32_RR;
2678     break;
2679   case NVPTXISD::Tld4B2DU64Float:
2680     Opc = NVPTX::TLD4_B_2D_U32_F32_RR;
2681     break;
2682   case NVPTXISD::Tld4A2DU64Float:
2683     Opc = NVPTX::TLD4_A_2D_U32_F32_RR;
2684     break;
2685   case NVPTXISD::TexUnified1DFloatS32:
2686     Opc = NVPTX::TEX_UNIFIED_1D_F32_S32_R;
2687     break;
2688   case NVPTXISD::TexUnified1DFloatFloat:
2689     Opc = NVPTX::TEX_UNIFIED_1D_F32_F32_R;
2690     break;
2691   case NVPTXISD::TexUnified1DFloatFloatLevel:
2692     Opc = NVPTX::TEX_UNIFIED_1D_F32_F32_LEVEL_R;
2693     break;
2694   case NVPTXISD::TexUnified1DFloatFloatGrad:
2695     Opc = NVPTX::TEX_UNIFIED_1D_F32_F32_GRAD_R;
2696     break;
2697   case NVPTXISD::TexUnified1DS32S32:
2698     Opc = NVPTX::TEX_UNIFIED_1D_S32_S32_R;
2699     break;
2700   case NVPTXISD::TexUnified1DS32Float:
2701     Opc = NVPTX::TEX_UNIFIED_1D_S32_F32_R;
2702     break;
2703   case NVPTXISD::TexUnified1DS32FloatLevel:
2704     Opc = NVPTX::TEX_UNIFIED_1D_S32_F32_LEVEL_R;
2705     break;
2706   case NVPTXISD::TexUnified1DS32FloatGrad:
2707     Opc = NVPTX::TEX_UNIFIED_1D_S32_F32_GRAD_R;
2708     break;
2709   case NVPTXISD::TexUnified1DU32S32:
2710     Opc = NVPTX::TEX_UNIFIED_1D_U32_S32_R;
2711     break;
2712   case NVPTXISD::TexUnified1DU32Float:
2713     Opc = NVPTX::TEX_UNIFIED_1D_U32_F32_R;
2714     break;
2715   case NVPTXISD::TexUnified1DU32FloatLevel:
2716     Opc = NVPTX::TEX_UNIFIED_1D_U32_F32_LEVEL_R;
2717     break;
2718   case NVPTXISD::TexUnified1DU32FloatGrad:
2719     Opc = NVPTX::TEX_UNIFIED_1D_U32_F32_GRAD_R;
2720     break;
2721   case NVPTXISD::TexUnified1DArrayFloatS32:
2722     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_F32_S32_R;
2723     break;
2724   case NVPTXISD::TexUnified1DArrayFloatFloat:
2725     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_F32_F32_R;
2726     break;
2727   case NVPTXISD::TexUnified1DArrayFloatFloatLevel:
2728     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_F32_F32_LEVEL_R;
2729     break;
2730   case NVPTXISD::TexUnified1DArrayFloatFloatGrad:
2731     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_F32_F32_GRAD_R;
2732     break;
2733   case NVPTXISD::TexUnified1DArrayS32S32:
2734     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_S32_S32_R;
2735     break;
2736   case NVPTXISD::TexUnified1DArrayS32Float:
2737     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_S32_F32_R;
2738     break;
2739   case NVPTXISD::TexUnified1DArrayS32FloatLevel:
2740     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_S32_F32_LEVEL_R;
2741     break;
2742   case NVPTXISD::TexUnified1DArrayS32FloatGrad:
2743     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_S32_F32_GRAD_R;
2744     break;
2745   case NVPTXISD::TexUnified1DArrayU32S32:
2746     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_U32_S32_R;
2747     break;
2748   case NVPTXISD::TexUnified1DArrayU32Float:
2749     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_U32_F32_R;
2750     break;
2751   case NVPTXISD::TexUnified1DArrayU32FloatLevel:
2752     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_U32_F32_LEVEL_R;
2753     break;
2754   case NVPTXISD::TexUnified1DArrayU32FloatGrad:
2755     Opc = NVPTX::TEX_UNIFIED_1D_ARRAY_U32_F32_GRAD_R;
2756     break;
2757   case NVPTXISD::TexUnified2DFloatS32:
2758     Opc = NVPTX::TEX_UNIFIED_2D_F32_S32_R;
2759     break;
2760   case NVPTXISD::TexUnified2DFloatFloat:
2761     Opc = NVPTX::TEX_UNIFIED_2D_F32_F32_R;
2762     break;
2763   case NVPTXISD::TexUnified2DFloatFloatLevel:
2764     Opc = NVPTX::TEX_UNIFIED_2D_F32_F32_LEVEL_R;
2765     break;
2766   case NVPTXISD::TexUnified2DFloatFloatGrad:
2767     Opc = NVPTX::TEX_UNIFIED_2D_F32_F32_GRAD_R;
2768     break;
2769   case NVPTXISD::TexUnified2DS32S32:
2770     Opc = NVPTX::TEX_UNIFIED_2D_S32_S32_R;
2771     break;
2772   case NVPTXISD::TexUnified2DS32Float:
2773     Opc = NVPTX::TEX_UNIFIED_2D_S32_F32_R;
2774     break;
2775   case NVPTXISD::TexUnified2DS32FloatLevel:
2776     Opc = NVPTX::TEX_UNIFIED_2D_S32_F32_LEVEL_R;
2777     break;
2778   case NVPTXISD::TexUnified2DS32FloatGrad:
2779     Opc = NVPTX::TEX_UNIFIED_2D_S32_F32_GRAD_R;
2780     break;
2781   case NVPTXISD::TexUnified2DU32S32:
2782     Opc = NVPTX::TEX_UNIFIED_2D_U32_S32_R;
2783     break;
2784   case NVPTXISD::TexUnified2DU32Float:
2785     Opc = NVPTX::TEX_UNIFIED_2D_U32_F32_R;
2786     break;
2787   case NVPTXISD::TexUnified2DU32FloatLevel:
2788     Opc = NVPTX::TEX_UNIFIED_2D_U32_F32_LEVEL_R;
2789     break;
2790   case NVPTXISD::TexUnified2DU32FloatGrad:
2791     Opc = NVPTX::TEX_UNIFIED_2D_U32_F32_GRAD_R;
2792     break;
2793   case NVPTXISD::TexUnified2DArrayFloatS32:
2794     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_F32_S32_R;
2795     break;
2796   case NVPTXISD::TexUnified2DArrayFloatFloat:
2797     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_F32_F32_R;
2798     break;
2799   case NVPTXISD::TexUnified2DArrayFloatFloatLevel:
2800     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_F32_F32_LEVEL_R;
2801     break;
2802   case NVPTXISD::TexUnified2DArrayFloatFloatGrad:
2803     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_F32_F32_GRAD_R;
2804     break;
2805   case NVPTXISD::TexUnified2DArrayS32S32:
2806     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_S32_S32_R;
2807     break;
2808   case NVPTXISD::TexUnified2DArrayS32Float:
2809     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_S32_F32_R;
2810     break;
2811   case NVPTXISD::TexUnified2DArrayS32FloatLevel:
2812     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_S32_F32_LEVEL_R;
2813     break;
2814   case NVPTXISD::TexUnified2DArrayS32FloatGrad:
2815     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_S32_F32_GRAD_R;
2816     break;
2817   case NVPTXISD::TexUnified2DArrayU32S32:
2818     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_U32_S32_R;
2819     break;
2820   case NVPTXISD::TexUnified2DArrayU32Float:
2821     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_U32_F32_R;
2822     break;
2823   case NVPTXISD::TexUnified2DArrayU32FloatLevel:
2824     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_U32_F32_LEVEL_R;
2825     break;
2826   case NVPTXISD::TexUnified2DArrayU32FloatGrad:
2827     Opc = NVPTX::TEX_UNIFIED_2D_ARRAY_U32_F32_GRAD_R;
2828     break;
2829   case NVPTXISD::TexUnified3DFloatS32:
2830     Opc = NVPTX::TEX_UNIFIED_3D_F32_S32_R;
2831     break;
2832   case NVPTXISD::TexUnified3DFloatFloat:
2833     Opc = NVPTX::TEX_UNIFIED_3D_F32_F32_R;
2834     break;
2835   case NVPTXISD::TexUnified3DFloatFloatLevel:
2836     Opc = NVPTX::TEX_UNIFIED_3D_F32_F32_LEVEL_R;
2837     break;
2838   case NVPTXISD::TexUnified3DFloatFloatGrad:
2839     Opc = NVPTX::TEX_UNIFIED_3D_F32_F32_GRAD_R;
2840     break;
2841   case NVPTXISD::TexUnified3DS32S32:
2842     Opc = NVPTX::TEX_UNIFIED_3D_S32_S32_R;
2843     break;
2844   case NVPTXISD::TexUnified3DS32Float:
2845     Opc = NVPTX::TEX_UNIFIED_3D_S32_F32_R;
2846     break;
2847   case NVPTXISD::TexUnified3DS32FloatLevel:
2848     Opc = NVPTX::TEX_UNIFIED_3D_S32_F32_LEVEL_R;
2849     break;
2850   case NVPTXISD::TexUnified3DS32FloatGrad:
2851     Opc = NVPTX::TEX_UNIFIED_3D_S32_F32_GRAD_R;
2852     break;
2853   case NVPTXISD::TexUnified3DU32S32:
2854     Opc = NVPTX::TEX_UNIFIED_3D_U32_S32_R;
2855     break;
2856   case NVPTXISD::TexUnified3DU32Float:
2857     Opc = NVPTX::TEX_UNIFIED_3D_U32_F32_R;
2858     break;
2859   case NVPTXISD::TexUnified3DU32FloatLevel:
2860     Opc = NVPTX::TEX_UNIFIED_3D_U32_F32_LEVEL_R;
2861     break;
2862   case NVPTXISD::TexUnified3DU32FloatGrad:
2863     Opc = NVPTX::TEX_UNIFIED_3D_U32_F32_GRAD_R;
2864     break;
2865   case NVPTXISD::TexUnifiedCubeFloatFloat:
2866     Opc = NVPTX::TEX_UNIFIED_CUBE_F32_F32_R;
2867     break;
2868   case NVPTXISD::TexUnifiedCubeFloatFloatLevel:
2869     Opc = NVPTX::TEX_UNIFIED_CUBE_F32_F32_LEVEL_R;
2870     break;
2871   case NVPTXISD::TexUnifiedCubeS32Float:
2872     Opc = NVPTX::TEX_UNIFIED_CUBE_S32_F32_R;
2873     break;
2874   case NVPTXISD::TexUnifiedCubeS32FloatLevel:
2875     Opc = NVPTX::TEX_UNIFIED_CUBE_S32_F32_LEVEL_R;
2876     break;
2877   case NVPTXISD::TexUnifiedCubeU32Float:
2878     Opc = NVPTX::TEX_UNIFIED_CUBE_U32_F32_R;
2879     break;
2880   case NVPTXISD::TexUnifiedCubeU32FloatLevel:
2881     Opc = NVPTX::TEX_UNIFIED_CUBE_U32_F32_LEVEL_R;
2882     break;
2883   case NVPTXISD::TexUnifiedCubeArrayFloatFloat:
2884     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_F32_F32_R;
2885     break;
2886   case NVPTXISD::TexUnifiedCubeArrayFloatFloatLevel:
2887     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_F32_F32_LEVEL_R;
2888     break;
2889   case NVPTXISD::TexUnifiedCubeArrayS32Float:
2890     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_S32_F32_R;
2891     break;
2892   case NVPTXISD::TexUnifiedCubeArrayS32FloatLevel:
2893     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_S32_F32_LEVEL_R;
2894     break;
2895   case NVPTXISD::TexUnifiedCubeArrayU32Float:
2896     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_U32_F32_R;
2897     break;
2898   case NVPTXISD::TexUnifiedCubeArrayU32FloatLevel:
2899     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_U32_F32_LEVEL_R;
2900     break;
2901   case NVPTXISD::Tld4UnifiedR2DFloatFloat:
2902     Opc = NVPTX::TLD4_UNIFIED_R_2D_F32_F32_R;
2903     break;
2904   case NVPTXISD::Tld4UnifiedG2DFloatFloat:
2905     Opc = NVPTX::TLD4_UNIFIED_G_2D_F32_F32_R;
2906     break;
2907   case NVPTXISD::Tld4UnifiedB2DFloatFloat:
2908     Opc = NVPTX::TLD4_UNIFIED_B_2D_F32_F32_R;
2909     break;
2910   case NVPTXISD::Tld4UnifiedA2DFloatFloat:
2911     Opc = NVPTX::TLD4_UNIFIED_A_2D_F32_F32_R;
2912     break;
2913   case NVPTXISD::Tld4UnifiedR2DS64Float:
2914     Opc = NVPTX::TLD4_UNIFIED_R_2D_S32_F32_R;
2915     break;
2916   case NVPTXISD::Tld4UnifiedG2DS64Float:
2917     Opc = NVPTX::TLD4_UNIFIED_G_2D_S32_F32_R;
2918     break;
2919   case NVPTXISD::Tld4UnifiedB2DS64Float:
2920     Opc = NVPTX::TLD4_UNIFIED_B_2D_S32_F32_R;
2921     break;
2922   case NVPTXISD::Tld4UnifiedA2DS64Float:
2923     Opc = NVPTX::TLD4_UNIFIED_A_2D_S32_F32_R;
2924     break;
2925   case NVPTXISD::Tld4UnifiedR2DU64Float:
2926     Opc = NVPTX::TLD4_UNIFIED_R_2D_U32_F32_R;
2927     break;
2928   case NVPTXISD::Tld4UnifiedG2DU64Float:
2929     Opc = NVPTX::TLD4_UNIFIED_G_2D_U32_F32_R;
2930     break;
2931   case NVPTXISD::Tld4UnifiedB2DU64Float:
2932     Opc = NVPTX::TLD4_UNIFIED_B_2D_U32_F32_R;
2933     break;
2934   case NVPTXISD::Tld4UnifiedA2DU64Float:
2935     Opc = NVPTX::TLD4_UNIFIED_A_2D_U32_F32_R;
2936     break;
2937   case NVPTXISD::TexUnifiedCubeFloatFloatGrad:
2938     Opc = NVPTX::TEX_UNIFIED_CUBE_F32_F32_GRAD_R;
2939     break;
2940   case NVPTXISD::TexUnifiedCubeS32FloatGrad:
2941     Opc = NVPTX::TEX_UNIFIED_CUBE_S32_F32_GRAD_R;
2942     break;
2943   case NVPTXISD::TexUnifiedCubeU32FloatGrad:
2944     Opc = NVPTX::TEX_UNIFIED_CUBE_U32_F32_GRAD_R;
2945     break;
2946   case NVPTXISD::TexUnifiedCubeArrayFloatFloatGrad:
2947     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_F32_F32_GRAD_R;
2948     break;
2949   case NVPTXISD::TexUnifiedCubeArrayS32FloatGrad:
2950     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_S32_F32_GRAD_R;
2951     break;
2952   case NVPTXISD::TexUnifiedCubeArrayU32FloatGrad:
2953     Opc = NVPTX::TEX_UNIFIED_CUBE_ARRAY_U32_F32_GRAD_R;
2954     break;
2955   }
2956 
2957   // Copy over operands
2958   SmallVector<SDValue, 8> Ops(drop_begin(N->ops()));
2959   Ops.push_back(N->getOperand(0)); // Move chain to the back.
2960 
2961   ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops));
2962   return true;
2963 }
2964 
trySurfaceIntrinsic(SDNode * N)2965 bool NVPTXDAGToDAGISel::trySurfaceIntrinsic(SDNode *N) {
2966   unsigned Opc = 0;
2967   switch (N->getOpcode()) {
2968   default: return false;
2969   case NVPTXISD::Suld1DI8Clamp:
2970     Opc = NVPTX::SULD_1D_I8_CLAMP_R;
2971     break;
2972   case NVPTXISD::Suld1DI16Clamp:
2973     Opc = NVPTX::SULD_1D_I16_CLAMP_R;
2974     break;
2975   case NVPTXISD::Suld1DI32Clamp:
2976     Opc = NVPTX::SULD_1D_I32_CLAMP_R;
2977     break;
2978   case NVPTXISD::Suld1DI64Clamp:
2979     Opc = NVPTX::SULD_1D_I64_CLAMP_R;
2980     break;
2981   case NVPTXISD::Suld1DV2I8Clamp:
2982     Opc = NVPTX::SULD_1D_V2I8_CLAMP_R;
2983     break;
2984   case NVPTXISD::Suld1DV2I16Clamp:
2985     Opc = NVPTX::SULD_1D_V2I16_CLAMP_R;
2986     break;
2987   case NVPTXISD::Suld1DV2I32Clamp:
2988     Opc = NVPTX::SULD_1D_V2I32_CLAMP_R;
2989     break;
2990   case NVPTXISD::Suld1DV2I64Clamp:
2991     Opc = NVPTX::SULD_1D_V2I64_CLAMP_R;
2992     break;
2993   case NVPTXISD::Suld1DV4I8Clamp:
2994     Opc = NVPTX::SULD_1D_V4I8_CLAMP_R;
2995     break;
2996   case NVPTXISD::Suld1DV4I16Clamp:
2997     Opc = NVPTX::SULD_1D_V4I16_CLAMP_R;
2998     break;
2999   case NVPTXISD::Suld1DV4I32Clamp:
3000     Opc = NVPTX::SULD_1D_V4I32_CLAMP_R;
3001     break;
3002   case NVPTXISD::Suld1DArrayI8Clamp:
3003     Opc = NVPTX::SULD_1D_ARRAY_I8_CLAMP_R;
3004     break;
3005   case NVPTXISD::Suld1DArrayI16Clamp:
3006     Opc = NVPTX::SULD_1D_ARRAY_I16_CLAMP_R;
3007     break;
3008   case NVPTXISD::Suld1DArrayI32Clamp:
3009     Opc = NVPTX::SULD_1D_ARRAY_I32_CLAMP_R;
3010     break;
3011   case NVPTXISD::Suld1DArrayI64Clamp:
3012     Opc = NVPTX::SULD_1D_ARRAY_I64_CLAMP_R;
3013     break;
3014   case NVPTXISD::Suld1DArrayV2I8Clamp:
3015     Opc = NVPTX::SULD_1D_ARRAY_V2I8_CLAMP_R;
3016     break;
3017   case NVPTXISD::Suld1DArrayV2I16Clamp:
3018     Opc = NVPTX::SULD_1D_ARRAY_V2I16_CLAMP_R;
3019     break;
3020   case NVPTXISD::Suld1DArrayV2I32Clamp:
3021     Opc = NVPTX::SULD_1D_ARRAY_V2I32_CLAMP_R;
3022     break;
3023   case NVPTXISD::Suld1DArrayV2I64Clamp:
3024     Opc = NVPTX::SULD_1D_ARRAY_V2I64_CLAMP_R;
3025     break;
3026   case NVPTXISD::Suld1DArrayV4I8Clamp:
3027     Opc = NVPTX::SULD_1D_ARRAY_V4I8_CLAMP_R;
3028     break;
3029   case NVPTXISD::Suld1DArrayV4I16Clamp:
3030     Opc = NVPTX::SULD_1D_ARRAY_V4I16_CLAMP_R;
3031     break;
3032   case NVPTXISD::Suld1DArrayV4I32Clamp:
3033     Opc = NVPTX::SULD_1D_ARRAY_V4I32_CLAMP_R;
3034     break;
3035   case NVPTXISD::Suld2DI8Clamp:
3036     Opc = NVPTX::SULD_2D_I8_CLAMP_R;
3037     break;
3038   case NVPTXISD::Suld2DI16Clamp:
3039     Opc = NVPTX::SULD_2D_I16_CLAMP_R;
3040     break;
3041   case NVPTXISD::Suld2DI32Clamp:
3042     Opc = NVPTX::SULD_2D_I32_CLAMP_R;
3043     break;
3044   case NVPTXISD::Suld2DI64Clamp:
3045     Opc = NVPTX::SULD_2D_I64_CLAMP_R;
3046     break;
3047   case NVPTXISD::Suld2DV2I8Clamp:
3048     Opc = NVPTX::SULD_2D_V2I8_CLAMP_R;
3049     break;
3050   case NVPTXISD::Suld2DV2I16Clamp:
3051     Opc = NVPTX::SULD_2D_V2I16_CLAMP_R;
3052     break;
3053   case NVPTXISD::Suld2DV2I32Clamp:
3054     Opc = NVPTX::SULD_2D_V2I32_CLAMP_R;
3055     break;
3056   case NVPTXISD::Suld2DV2I64Clamp:
3057     Opc = NVPTX::SULD_2D_V2I64_CLAMP_R;
3058     break;
3059   case NVPTXISD::Suld2DV4I8Clamp:
3060     Opc = NVPTX::SULD_2D_V4I8_CLAMP_R;
3061     break;
3062   case NVPTXISD::Suld2DV4I16Clamp:
3063     Opc = NVPTX::SULD_2D_V4I16_CLAMP_R;
3064     break;
3065   case NVPTXISD::Suld2DV4I32Clamp:
3066     Opc = NVPTX::SULD_2D_V4I32_CLAMP_R;
3067     break;
3068   case NVPTXISD::Suld2DArrayI8Clamp:
3069     Opc = NVPTX::SULD_2D_ARRAY_I8_CLAMP_R;
3070     break;
3071   case NVPTXISD::Suld2DArrayI16Clamp:
3072     Opc = NVPTX::SULD_2D_ARRAY_I16_CLAMP_R;
3073     break;
3074   case NVPTXISD::Suld2DArrayI32Clamp:
3075     Opc = NVPTX::SULD_2D_ARRAY_I32_CLAMP_R;
3076     break;
3077   case NVPTXISD::Suld2DArrayI64Clamp:
3078     Opc = NVPTX::SULD_2D_ARRAY_I64_CLAMP_R;
3079     break;
3080   case NVPTXISD::Suld2DArrayV2I8Clamp:
3081     Opc = NVPTX::SULD_2D_ARRAY_V2I8_CLAMP_R;
3082     break;
3083   case NVPTXISD::Suld2DArrayV2I16Clamp:
3084     Opc = NVPTX::SULD_2D_ARRAY_V2I16_CLAMP_R;
3085     break;
3086   case NVPTXISD::Suld2DArrayV2I32Clamp:
3087     Opc = NVPTX::SULD_2D_ARRAY_V2I32_CLAMP_R;
3088     break;
3089   case NVPTXISD::Suld2DArrayV2I64Clamp:
3090     Opc = NVPTX::SULD_2D_ARRAY_V2I64_CLAMP_R;
3091     break;
3092   case NVPTXISD::Suld2DArrayV4I8Clamp:
3093     Opc = NVPTX::SULD_2D_ARRAY_V4I8_CLAMP_R;
3094     break;
3095   case NVPTXISD::Suld2DArrayV4I16Clamp:
3096     Opc = NVPTX::SULD_2D_ARRAY_V4I16_CLAMP_R;
3097     break;
3098   case NVPTXISD::Suld2DArrayV4I32Clamp:
3099     Opc = NVPTX::SULD_2D_ARRAY_V4I32_CLAMP_R;
3100     break;
3101   case NVPTXISD::Suld3DI8Clamp:
3102     Opc = NVPTX::SULD_3D_I8_CLAMP_R;
3103     break;
3104   case NVPTXISD::Suld3DI16Clamp:
3105     Opc = NVPTX::SULD_3D_I16_CLAMP_R;
3106     break;
3107   case NVPTXISD::Suld3DI32Clamp:
3108     Opc = NVPTX::SULD_3D_I32_CLAMP_R;
3109     break;
3110   case NVPTXISD::Suld3DI64Clamp:
3111     Opc = NVPTX::SULD_3D_I64_CLAMP_R;
3112     break;
3113   case NVPTXISD::Suld3DV2I8Clamp:
3114     Opc = NVPTX::SULD_3D_V2I8_CLAMP_R;
3115     break;
3116   case NVPTXISD::Suld3DV2I16Clamp:
3117     Opc = NVPTX::SULD_3D_V2I16_CLAMP_R;
3118     break;
3119   case NVPTXISD::Suld3DV2I32Clamp:
3120     Opc = NVPTX::SULD_3D_V2I32_CLAMP_R;
3121     break;
3122   case NVPTXISD::Suld3DV2I64Clamp:
3123     Opc = NVPTX::SULD_3D_V2I64_CLAMP_R;
3124     break;
3125   case NVPTXISD::Suld3DV4I8Clamp:
3126     Opc = NVPTX::SULD_3D_V4I8_CLAMP_R;
3127     break;
3128   case NVPTXISD::Suld3DV4I16Clamp:
3129     Opc = NVPTX::SULD_3D_V4I16_CLAMP_R;
3130     break;
3131   case NVPTXISD::Suld3DV4I32Clamp:
3132     Opc = NVPTX::SULD_3D_V4I32_CLAMP_R;
3133     break;
3134   case NVPTXISD::Suld1DI8Trap:
3135     Opc = NVPTX::SULD_1D_I8_TRAP_R;
3136     break;
3137   case NVPTXISD::Suld1DI16Trap:
3138     Opc = NVPTX::SULD_1D_I16_TRAP_R;
3139     break;
3140   case NVPTXISD::Suld1DI32Trap:
3141     Opc = NVPTX::SULD_1D_I32_TRAP_R;
3142     break;
3143   case NVPTXISD::Suld1DI64Trap:
3144     Opc = NVPTX::SULD_1D_I64_TRAP_R;
3145     break;
3146   case NVPTXISD::Suld1DV2I8Trap:
3147     Opc = NVPTX::SULD_1D_V2I8_TRAP_R;
3148     break;
3149   case NVPTXISD::Suld1DV2I16Trap:
3150     Opc = NVPTX::SULD_1D_V2I16_TRAP_R;
3151     break;
3152   case NVPTXISD::Suld1DV2I32Trap:
3153     Opc = NVPTX::SULD_1D_V2I32_TRAP_R;
3154     break;
3155   case NVPTXISD::Suld1DV2I64Trap:
3156     Opc = NVPTX::SULD_1D_V2I64_TRAP_R;
3157     break;
3158   case NVPTXISD::Suld1DV4I8Trap:
3159     Opc = NVPTX::SULD_1D_V4I8_TRAP_R;
3160     break;
3161   case NVPTXISD::Suld1DV4I16Trap:
3162     Opc = NVPTX::SULD_1D_V4I16_TRAP_R;
3163     break;
3164   case NVPTXISD::Suld1DV4I32Trap:
3165     Opc = NVPTX::SULD_1D_V4I32_TRAP_R;
3166     break;
3167   case NVPTXISD::Suld1DArrayI8Trap:
3168     Opc = NVPTX::SULD_1D_ARRAY_I8_TRAP_R;
3169     break;
3170   case NVPTXISD::Suld1DArrayI16Trap:
3171     Opc = NVPTX::SULD_1D_ARRAY_I16_TRAP_R;
3172     break;
3173   case NVPTXISD::Suld1DArrayI32Trap:
3174     Opc = NVPTX::SULD_1D_ARRAY_I32_TRAP_R;
3175     break;
3176   case NVPTXISD::Suld1DArrayI64Trap:
3177     Opc = NVPTX::SULD_1D_ARRAY_I64_TRAP_R;
3178     break;
3179   case NVPTXISD::Suld1DArrayV2I8Trap:
3180     Opc = NVPTX::SULD_1D_ARRAY_V2I8_TRAP_R;
3181     break;
3182   case NVPTXISD::Suld1DArrayV2I16Trap:
3183     Opc = NVPTX::SULD_1D_ARRAY_V2I16_TRAP_R;
3184     break;
3185   case NVPTXISD::Suld1DArrayV2I32Trap:
3186     Opc = NVPTX::SULD_1D_ARRAY_V2I32_TRAP_R;
3187     break;
3188   case NVPTXISD::Suld1DArrayV2I64Trap:
3189     Opc = NVPTX::SULD_1D_ARRAY_V2I64_TRAP_R;
3190     break;
3191   case NVPTXISD::Suld1DArrayV4I8Trap:
3192     Opc = NVPTX::SULD_1D_ARRAY_V4I8_TRAP_R;
3193     break;
3194   case NVPTXISD::Suld1DArrayV4I16Trap:
3195     Opc = NVPTX::SULD_1D_ARRAY_V4I16_TRAP_R;
3196     break;
3197   case NVPTXISD::Suld1DArrayV4I32Trap:
3198     Opc = NVPTX::SULD_1D_ARRAY_V4I32_TRAP_R;
3199     break;
3200   case NVPTXISD::Suld2DI8Trap:
3201     Opc = NVPTX::SULD_2D_I8_TRAP_R;
3202     break;
3203   case NVPTXISD::Suld2DI16Trap:
3204     Opc = NVPTX::SULD_2D_I16_TRAP_R;
3205     break;
3206   case NVPTXISD::Suld2DI32Trap:
3207     Opc = NVPTX::SULD_2D_I32_TRAP_R;
3208     break;
3209   case NVPTXISD::Suld2DI64Trap:
3210     Opc = NVPTX::SULD_2D_I64_TRAP_R;
3211     break;
3212   case NVPTXISD::Suld2DV2I8Trap:
3213     Opc = NVPTX::SULD_2D_V2I8_TRAP_R;
3214     break;
3215   case NVPTXISD::Suld2DV2I16Trap:
3216     Opc = NVPTX::SULD_2D_V2I16_TRAP_R;
3217     break;
3218   case NVPTXISD::Suld2DV2I32Trap:
3219     Opc = NVPTX::SULD_2D_V2I32_TRAP_R;
3220     break;
3221   case NVPTXISD::Suld2DV2I64Trap:
3222     Opc = NVPTX::SULD_2D_V2I64_TRAP_R;
3223     break;
3224   case NVPTXISD::Suld2DV4I8Trap:
3225     Opc = NVPTX::SULD_2D_V4I8_TRAP_R;
3226     break;
3227   case NVPTXISD::Suld2DV4I16Trap:
3228     Opc = NVPTX::SULD_2D_V4I16_TRAP_R;
3229     break;
3230   case NVPTXISD::Suld2DV4I32Trap:
3231     Opc = NVPTX::SULD_2D_V4I32_TRAP_R;
3232     break;
3233   case NVPTXISD::Suld2DArrayI8Trap:
3234     Opc = NVPTX::SULD_2D_ARRAY_I8_TRAP_R;
3235     break;
3236   case NVPTXISD::Suld2DArrayI16Trap:
3237     Opc = NVPTX::SULD_2D_ARRAY_I16_TRAP_R;
3238     break;
3239   case NVPTXISD::Suld2DArrayI32Trap:
3240     Opc = NVPTX::SULD_2D_ARRAY_I32_TRAP_R;
3241     break;
3242   case NVPTXISD::Suld2DArrayI64Trap:
3243     Opc = NVPTX::SULD_2D_ARRAY_I64_TRAP_R;
3244     break;
3245   case NVPTXISD::Suld2DArrayV2I8Trap:
3246     Opc = NVPTX::SULD_2D_ARRAY_V2I8_TRAP_R;
3247     break;
3248   case NVPTXISD::Suld2DArrayV2I16Trap:
3249     Opc = NVPTX::SULD_2D_ARRAY_V2I16_TRAP_R;
3250     break;
3251   case NVPTXISD::Suld2DArrayV2I32Trap:
3252     Opc = NVPTX::SULD_2D_ARRAY_V2I32_TRAP_R;
3253     break;
3254   case NVPTXISD::Suld2DArrayV2I64Trap:
3255     Opc = NVPTX::SULD_2D_ARRAY_V2I64_TRAP_R;
3256     break;
3257   case NVPTXISD::Suld2DArrayV4I8Trap:
3258     Opc = NVPTX::SULD_2D_ARRAY_V4I8_TRAP_R;
3259     break;
3260   case NVPTXISD::Suld2DArrayV4I16Trap:
3261     Opc = NVPTX::SULD_2D_ARRAY_V4I16_TRAP_R;
3262     break;
3263   case NVPTXISD::Suld2DArrayV4I32Trap:
3264     Opc = NVPTX::SULD_2D_ARRAY_V4I32_TRAP_R;
3265     break;
3266   case NVPTXISD::Suld3DI8Trap:
3267     Opc = NVPTX::SULD_3D_I8_TRAP_R;
3268     break;
3269   case NVPTXISD::Suld3DI16Trap:
3270     Opc = NVPTX::SULD_3D_I16_TRAP_R;
3271     break;
3272   case NVPTXISD::Suld3DI32Trap:
3273     Opc = NVPTX::SULD_3D_I32_TRAP_R;
3274     break;
3275   case NVPTXISD::Suld3DI64Trap:
3276     Opc = NVPTX::SULD_3D_I64_TRAP_R;
3277     break;
3278   case NVPTXISD::Suld3DV2I8Trap:
3279     Opc = NVPTX::SULD_3D_V2I8_TRAP_R;
3280     break;
3281   case NVPTXISD::Suld3DV2I16Trap:
3282     Opc = NVPTX::SULD_3D_V2I16_TRAP_R;
3283     break;
3284   case NVPTXISD::Suld3DV2I32Trap:
3285     Opc = NVPTX::SULD_3D_V2I32_TRAP_R;
3286     break;
3287   case NVPTXISD::Suld3DV2I64Trap:
3288     Opc = NVPTX::SULD_3D_V2I64_TRAP_R;
3289     break;
3290   case NVPTXISD::Suld3DV4I8Trap:
3291     Opc = NVPTX::SULD_3D_V4I8_TRAP_R;
3292     break;
3293   case NVPTXISD::Suld3DV4I16Trap:
3294     Opc = NVPTX::SULD_3D_V4I16_TRAP_R;
3295     break;
3296   case NVPTXISD::Suld3DV4I32Trap:
3297     Opc = NVPTX::SULD_3D_V4I32_TRAP_R;
3298     break;
3299   case NVPTXISD::Suld1DI8Zero:
3300     Opc = NVPTX::SULD_1D_I8_ZERO_R;
3301     break;
3302   case NVPTXISD::Suld1DI16Zero:
3303     Opc = NVPTX::SULD_1D_I16_ZERO_R;
3304     break;
3305   case NVPTXISD::Suld1DI32Zero:
3306     Opc = NVPTX::SULD_1D_I32_ZERO_R;
3307     break;
3308   case NVPTXISD::Suld1DI64Zero:
3309     Opc = NVPTX::SULD_1D_I64_ZERO_R;
3310     break;
3311   case NVPTXISD::Suld1DV2I8Zero:
3312     Opc = NVPTX::SULD_1D_V2I8_ZERO_R;
3313     break;
3314   case NVPTXISD::Suld1DV2I16Zero:
3315     Opc = NVPTX::SULD_1D_V2I16_ZERO_R;
3316     break;
3317   case NVPTXISD::Suld1DV2I32Zero:
3318     Opc = NVPTX::SULD_1D_V2I32_ZERO_R;
3319     break;
3320   case NVPTXISD::Suld1DV2I64Zero:
3321     Opc = NVPTX::SULD_1D_V2I64_ZERO_R;
3322     break;
3323   case NVPTXISD::Suld1DV4I8Zero:
3324     Opc = NVPTX::SULD_1D_V4I8_ZERO_R;
3325     break;
3326   case NVPTXISD::Suld1DV4I16Zero:
3327     Opc = NVPTX::SULD_1D_V4I16_ZERO_R;
3328     break;
3329   case NVPTXISD::Suld1DV4I32Zero:
3330     Opc = NVPTX::SULD_1D_V4I32_ZERO_R;
3331     break;
3332   case NVPTXISD::Suld1DArrayI8Zero:
3333     Opc = NVPTX::SULD_1D_ARRAY_I8_ZERO_R;
3334     break;
3335   case NVPTXISD::Suld1DArrayI16Zero:
3336     Opc = NVPTX::SULD_1D_ARRAY_I16_ZERO_R;
3337     break;
3338   case NVPTXISD::Suld1DArrayI32Zero:
3339     Opc = NVPTX::SULD_1D_ARRAY_I32_ZERO_R;
3340     break;
3341   case NVPTXISD::Suld1DArrayI64Zero:
3342     Opc = NVPTX::SULD_1D_ARRAY_I64_ZERO_R;
3343     break;
3344   case NVPTXISD::Suld1DArrayV2I8Zero:
3345     Opc = NVPTX::SULD_1D_ARRAY_V2I8_ZERO_R;
3346     break;
3347   case NVPTXISD::Suld1DArrayV2I16Zero:
3348     Opc = NVPTX::SULD_1D_ARRAY_V2I16_ZERO_R;
3349     break;
3350   case NVPTXISD::Suld1DArrayV2I32Zero:
3351     Opc = NVPTX::SULD_1D_ARRAY_V2I32_ZERO_R;
3352     break;
3353   case NVPTXISD::Suld1DArrayV2I64Zero:
3354     Opc = NVPTX::SULD_1D_ARRAY_V2I64_ZERO_R;
3355     break;
3356   case NVPTXISD::Suld1DArrayV4I8Zero:
3357     Opc = NVPTX::SULD_1D_ARRAY_V4I8_ZERO_R;
3358     break;
3359   case NVPTXISD::Suld1DArrayV4I16Zero:
3360     Opc = NVPTX::SULD_1D_ARRAY_V4I16_ZERO_R;
3361     break;
3362   case NVPTXISD::Suld1DArrayV4I32Zero:
3363     Opc = NVPTX::SULD_1D_ARRAY_V4I32_ZERO_R;
3364     break;
3365   case NVPTXISD::Suld2DI8Zero:
3366     Opc = NVPTX::SULD_2D_I8_ZERO_R;
3367     break;
3368   case NVPTXISD::Suld2DI16Zero:
3369     Opc = NVPTX::SULD_2D_I16_ZERO_R;
3370     break;
3371   case NVPTXISD::Suld2DI32Zero:
3372     Opc = NVPTX::SULD_2D_I32_ZERO_R;
3373     break;
3374   case NVPTXISD::Suld2DI64Zero:
3375     Opc = NVPTX::SULD_2D_I64_ZERO_R;
3376     break;
3377   case NVPTXISD::Suld2DV2I8Zero:
3378     Opc = NVPTX::SULD_2D_V2I8_ZERO_R;
3379     break;
3380   case NVPTXISD::Suld2DV2I16Zero:
3381     Opc = NVPTX::SULD_2D_V2I16_ZERO_R;
3382     break;
3383   case NVPTXISD::Suld2DV2I32Zero:
3384     Opc = NVPTX::SULD_2D_V2I32_ZERO_R;
3385     break;
3386   case NVPTXISD::Suld2DV2I64Zero:
3387     Opc = NVPTX::SULD_2D_V2I64_ZERO_R;
3388     break;
3389   case NVPTXISD::Suld2DV4I8Zero:
3390     Opc = NVPTX::SULD_2D_V4I8_ZERO_R;
3391     break;
3392   case NVPTXISD::Suld2DV4I16Zero:
3393     Opc = NVPTX::SULD_2D_V4I16_ZERO_R;
3394     break;
3395   case NVPTXISD::Suld2DV4I32Zero:
3396     Opc = NVPTX::SULD_2D_V4I32_ZERO_R;
3397     break;
3398   case NVPTXISD::Suld2DArrayI8Zero:
3399     Opc = NVPTX::SULD_2D_ARRAY_I8_ZERO_R;
3400     break;
3401   case NVPTXISD::Suld2DArrayI16Zero:
3402     Opc = NVPTX::SULD_2D_ARRAY_I16_ZERO_R;
3403     break;
3404   case NVPTXISD::Suld2DArrayI32Zero:
3405     Opc = NVPTX::SULD_2D_ARRAY_I32_ZERO_R;
3406     break;
3407   case NVPTXISD::Suld2DArrayI64Zero:
3408     Opc = NVPTX::SULD_2D_ARRAY_I64_ZERO_R;
3409     break;
3410   case NVPTXISD::Suld2DArrayV2I8Zero:
3411     Opc = NVPTX::SULD_2D_ARRAY_V2I8_ZERO_R;
3412     break;
3413   case NVPTXISD::Suld2DArrayV2I16Zero:
3414     Opc = NVPTX::SULD_2D_ARRAY_V2I16_ZERO_R;
3415     break;
3416   case NVPTXISD::Suld2DArrayV2I32Zero:
3417     Opc = NVPTX::SULD_2D_ARRAY_V2I32_ZERO_R;
3418     break;
3419   case NVPTXISD::Suld2DArrayV2I64Zero:
3420     Opc = NVPTX::SULD_2D_ARRAY_V2I64_ZERO_R;
3421     break;
3422   case NVPTXISD::Suld2DArrayV4I8Zero:
3423     Opc = NVPTX::SULD_2D_ARRAY_V4I8_ZERO_R;
3424     break;
3425   case NVPTXISD::Suld2DArrayV4I16Zero:
3426     Opc = NVPTX::SULD_2D_ARRAY_V4I16_ZERO_R;
3427     break;
3428   case NVPTXISD::Suld2DArrayV4I32Zero:
3429     Opc = NVPTX::SULD_2D_ARRAY_V4I32_ZERO_R;
3430     break;
3431   case NVPTXISD::Suld3DI8Zero:
3432     Opc = NVPTX::SULD_3D_I8_ZERO_R;
3433     break;
3434   case NVPTXISD::Suld3DI16Zero:
3435     Opc = NVPTX::SULD_3D_I16_ZERO_R;
3436     break;
3437   case NVPTXISD::Suld3DI32Zero:
3438     Opc = NVPTX::SULD_3D_I32_ZERO_R;
3439     break;
3440   case NVPTXISD::Suld3DI64Zero:
3441     Opc = NVPTX::SULD_3D_I64_ZERO_R;
3442     break;
3443   case NVPTXISD::Suld3DV2I8Zero:
3444     Opc = NVPTX::SULD_3D_V2I8_ZERO_R;
3445     break;
3446   case NVPTXISD::Suld3DV2I16Zero:
3447     Opc = NVPTX::SULD_3D_V2I16_ZERO_R;
3448     break;
3449   case NVPTXISD::Suld3DV2I32Zero:
3450     Opc = NVPTX::SULD_3D_V2I32_ZERO_R;
3451     break;
3452   case NVPTXISD::Suld3DV2I64Zero:
3453     Opc = NVPTX::SULD_3D_V2I64_ZERO_R;
3454     break;
3455   case NVPTXISD::Suld3DV4I8Zero:
3456     Opc = NVPTX::SULD_3D_V4I8_ZERO_R;
3457     break;
3458   case NVPTXISD::Suld3DV4I16Zero:
3459     Opc = NVPTX::SULD_3D_V4I16_ZERO_R;
3460     break;
3461   case NVPTXISD::Suld3DV4I32Zero:
3462     Opc = NVPTX::SULD_3D_V4I32_ZERO_R;
3463     break;
3464   }
3465 
3466   // Copy over operands
3467   SmallVector<SDValue, 8> Ops(drop_begin(N->ops()));
3468   Ops.push_back(N->getOperand(0)); // Move chain to the back.
3469 
3470   ReplaceNode(N, CurDAG->getMachineNode(Opc, SDLoc(N), N->getVTList(), Ops));
3471   return true;
3472 }
3473 
3474 
3475 /// SelectBFE - Look for instruction sequences that can be made more efficient
3476 /// by using the 'bfe' (bit-field extract) PTX instruction
tryBFE(SDNode * N)3477 bool NVPTXDAGToDAGISel::tryBFE(SDNode *N) {
3478   SDLoc DL(N);
3479   SDValue LHS = N->getOperand(0);
3480   SDValue RHS = N->getOperand(1);
3481   SDValue Len;
3482   SDValue Start;
3483   SDValue Val;
3484   bool IsSigned = false;
3485 
3486   if (N->getOpcode() == ISD::AND) {
3487     // Canonicalize the operands
3488     // We want 'and %val, %mask'
3489     if (isa<ConstantSDNode>(LHS) && !isa<ConstantSDNode>(RHS)) {
3490       std::swap(LHS, RHS);
3491     }
3492 
3493     ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(RHS);
3494     if (!Mask) {
3495       // We need a constant mask on the RHS of the AND
3496       return false;
3497     }
3498 
3499     // Extract the mask bits
3500     uint64_t MaskVal = Mask->getZExtValue();
3501     if (!isMask_64(MaskVal)) {
3502       // We *could* handle shifted masks here, but doing so would require an
3503       // 'and' operation to fix up the low-order bits so we would trade
3504       // shr+and for bfe+and, which has the same throughput
3505       return false;
3506     }
3507 
3508     // How many bits are in our mask?
3509     int64_t NumBits = countr_one(MaskVal);
3510     Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32);
3511 
3512     if (LHS.getOpcode() == ISD::SRL || LHS.getOpcode() == ISD::SRA) {
3513       // We have a 'srl/and' pair, extract the effective start bit and length
3514       Val = LHS.getNode()->getOperand(0);
3515       Start = LHS.getNode()->getOperand(1);
3516       ConstantSDNode *StartConst = dyn_cast<ConstantSDNode>(Start);
3517       if (StartConst) {
3518         uint64_t StartVal = StartConst->getZExtValue();
3519         // How many "good" bits do we have left?  "good" is defined here as bits
3520         // that exist in the original value, not shifted in.
3521         int64_t GoodBits = Start.getValueSizeInBits() - StartVal;
3522         if (NumBits > GoodBits) {
3523           // Do not handle the case where bits have been shifted in. In theory
3524           // we could handle this, but the cost is likely higher than just
3525           // emitting the srl/and pair.
3526           return false;
3527         }
3528         Start = CurDAG->getTargetConstant(StartVal, DL, MVT::i32);
3529       } else {
3530         // Do not handle the case where the shift amount (can be zero if no srl
3531         // was found) is not constant. We could handle this case, but it would
3532         // require run-time logic that would be more expensive than just
3533         // emitting the srl/and pair.
3534         return false;
3535       }
3536     } else {
3537       // Do not handle the case where the LHS of the and is not a shift. While
3538       // it would be trivial to handle this case, it would just transform
3539       // 'and' -> 'bfe', but 'and' has higher-throughput.
3540       return false;
3541     }
3542   } else if (N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) {
3543     if (LHS->getOpcode() == ISD::AND) {
3544       ConstantSDNode *ShiftCnst = dyn_cast<ConstantSDNode>(RHS);
3545       if (!ShiftCnst) {
3546         // Shift amount must be constant
3547         return false;
3548       }
3549 
3550       uint64_t ShiftAmt = ShiftCnst->getZExtValue();
3551 
3552       SDValue AndLHS = LHS->getOperand(0);
3553       SDValue AndRHS = LHS->getOperand(1);
3554 
3555       // Canonicalize the AND to have the mask on the RHS
3556       if (isa<ConstantSDNode>(AndLHS)) {
3557         std::swap(AndLHS, AndRHS);
3558       }
3559 
3560       ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(AndRHS);
3561       if (!MaskCnst) {
3562         // Mask must be constant
3563         return false;
3564       }
3565 
3566       uint64_t MaskVal = MaskCnst->getZExtValue();
3567       uint64_t NumZeros;
3568       uint64_t NumBits;
3569       if (isMask_64(MaskVal)) {
3570         NumZeros = 0;
3571         // The number of bits in the result bitfield will be the number of
3572         // trailing ones (the AND) minus the number of bits we shift off
3573         NumBits = llvm::countr_one(MaskVal) - ShiftAmt;
3574       } else if (isShiftedMask_64(MaskVal)) {
3575         NumZeros = llvm::countr_zero(MaskVal);
3576         unsigned NumOnes = llvm::countr_one(MaskVal >> NumZeros);
3577         // The number of bits in the result bitfield will be the number of
3578         // trailing zeros plus the number of set bits in the mask minus the
3579         // number of bits we shift off
3580         NumBits = NumZeros + NumOnes - ShiftAmt;
3581       } else {
3582         // This is not a mask we can handle
3583         return false;
3584       }
3585 
3586       if (ShiftAmt < NumZeros) {
3587         // Handling this case would require extra logic that would make this
3588         // transformation non-profitable
3589         return false;
3590       }
3591 
3592       Val = AndLHS;
3593       Start = CurDAG->getTargetConstant(ShiftAmt, DL, MVT::i32);
3594       Len = CurDAG->getTargetConstant(NumBits, DL, MVT::i32);
3595     } else if (LHS->getOpcode() == ISD::SHL) {
3596       // Here, we have a pattern like:
3597       //
3598       // (sra (shl val, NN), MM)
3599       // or
3600       // (srl (shl val, NN), MM)
3601       //
3602       // If MM >= NN, we can efficiently optimize this with bfe
3603       Val = LHS->getOperand(0);
3604 
3605       SDValue ShlRHS = LHS->getOperand(1);
3606       ConstantSDNode *ShlCnst = dyn_cast<ConstantSDNode>(ShlRHS);
3607       if (!ShlCnst) {
3608         // Shift amount must be constant
3609         return false;
3610       }
3611       uint64_t InnerShiftAmt = ShlCnst->getZExtValue();
3612 
3613       SDValue ShrRHS = RHS;
3614       ConstantSDNode *ShrCnst = dyn_cast<ConstantSDNode>(ShrRHS);
3615       if (!ShrCnst) {
3616         // Shift amount must be constant
3617         return false;
3618       }
3619       uint64_t OuterShiftAmt = ShrCnst->getZExtValue();
3620 
3621       // To avoid extra codegen and be profitable, we need Outer >= Inner
3622       if (OuterShiftAmt < InnerShiftAmt) {
3623         return false;
3624       }
3625 
3626       // If the outer shift is more than the type size, we have no bitfield to
3627       // extract (since we also check that the inner shift is <= the outer shift
3628       // then this also implies that the inner shift is < the type size)
3629       if (OuterShiftAmt >= Val.getValueSizeInBits()) {
3630         return false;
3631       }
3632 
3633       Start = CurDAG->getTargetConstant(OuterShiftAmt - InnerShiftAmt, DL,
3634                                         MVT::i32);
3635       Len = CurDAG->getTargetConstant(Val.getValueSizeInBits() - OuterShiftAmt,
3636                                       DL, MVT::i32);
3637 
3638       if (N->getOpcode() == ISD::SRA) {
3639         // If we have a arithmetic right shift, we need to use the signed bfe
3640         // variant
3641         IsSigned = true;
3642       }
3643     } else {
3644       // No can do...
3645       return false;
3646     }
3647   } else {
3648     // No can do...
3649     return false;
3650   }
3651 
3652 
3653   unsigned Opc;
3654   // For the BFE operations we form here from "and" and "srl", always use the
3655   // unsigned variants.
3656   if (Val.getValueType() == MVT::i32) {
3657     if (IsSigned) {
3658       Opc = NVPTX::BFE_S32rii;
3659     } else {
3660       Opc = NVPTX::BFE_U32rii;
3661     }
3662   } else if (Val.getValueType() == MVT::i64) {
3663     if (IsSigned) {
3664       Opc = NVPTX::BFE_S64rii;
3665     } else {
3666       Opc = NVPTX::BFE_U64rii;
3667     }
3668   } else {
3669     // We cannot handle this type
3670     return false;
3671   }
3672 
3673   SDValue Ops[] = {
3674     Val, Start, Len
3675   };
3676 
3677   ReplaceNode(N, CurDAG->getMachineNode(Opc, DL, N->getVTList(), Ops));
3678   return true;
3679 }
3680 
3681 // SelectDirectAddr - Match a direct address for DAG.
3682 // A direct address could be a globaladdress or externalsymbol.
SelectDirectAddr(SDValue N,SDValue & Address)3683 bool NVPTXDAGToDAGISel::SelectDirectAddr(SDValue N, SDValue &Address) {
3684   // Return true if TGA or ES.
3685   if (N.getOpcode() == ISD::TargetGlobalAddress ||
3686       N.getOpcode() == ISD::TargetExternalSymbol) {
3687     Address = N;
3688     return true;
3689   }
3690   if (N.getOpcode() == NVPTXISD::Wrapper) {
3691     Address = N.getOperand(0);
3692     return true;
3693   }
3694   // addrspacecast(MoveParam(arg_symbol) to addrspace(PARAM)) -> arg_symbol
3695   if (AddrSpaceCastSDNode *CastN = dyn_cast<AddrSpaceCastSDNode>(N)) {
3696     if (CastN->getSrcAddressSpace() == ADDRESS_SPACE_GENERIC &&
3697         CastN->getDestAddressSpace() == ADDRESS_SPACE_PARAM &&
3698         CastN->getOperand(0).getOpcode() == NVPTXISD::MoveParam)
3699       return SelectDirectAddr(CastN->getOperand(0).getOperand(0), Address);
3700   }
3701   return false;
3702 }
3703 
3704 // symbol+offset
SelectADDRsi_imp(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset,MVT mvt)3705 bool NVPTXDAGToDAGISel::SelectADDRsi_imp(
3706     SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) {
3707   if (Addr.getOpcode() == ISD::ADD) {
3708     if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
3709       SDValue base = Addr.getOperand(0);
3710       if (SelectDirectAddr(base, Base)) {
3711         Offset = CurDAG->getTargetConstant(CN->getZExtValue(), SDLoc(OpNode),
3712                                            mvt);
3713         return true;
3714       }
3715     }
3716   }
3717   return false;
3718 }
3719 
3720 // symbol+offset
SelectADDRsi(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset)3721 bool NVPTXDAGToDAGISel::SelectADDRsi(SDNode *OpNode, SDValue Addr,
3722                                      SDValue &Base, SDValue &Offset) {
3723   return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i32);
3724 }
3725 
3726 // symbol+offset
SelectADDRsi64(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset)3727 bool NVPTXDAGToDAGISel::SelectADDRsi64(SDNode *OpNode, SDValue Addr,
3728                                        SDValue &Base, SDValue &Offset) {
3729   return SelectADDRsi_imp(OpNode, Addr, Base, Offset, MVT::i64);
3730 }
3731 
3732 // register+offset
SelectADDRri_imp(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset,MVT mvt)3733 bool NVPTXDAGToDAGISel::SelectADDRri_imp(
3734     SDNode *OpNode, SDValue Addr, SDValue &Base, SDValue &Offset, MVT mvt) {
3735   if (FrameIndexSDNode *FIN = dyn_cast<FrameIndexSDNode>(Addr)) {
3736     Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), mvt);
3737     Offset = CurDAG->getTargetConstant(0, SDLoc(OpNode), mvt);
3738     return true;
3739   }
3740   if (Addr.getOpcode() == ISD::TargetExternalSymbol ||
3741       Addr.getOpcode() == ISD::TargetGlobalAddress)
3742     return false; // direct calls.
3743 
3744   if (Addr.getOpcode() == ISD::ADD) {
3745     if (SelectDirectAddr(Addr.getOperand(0), Addr)) {
3746       return false;
3747     }
3748     if (ConstantSDNode *CN = dyn_cast<ConstantSDNode>(Addr.getOperand(1))) {
3749       if (FrameIndexSDNode *FIN =
3750               dyn_cast<FrameIndexSDNode>(Addr.getOperand(0)))
3751         // Constant offset from frame ref.
3752         Base = CurDAG->getTargetFrameIndex(FIN->getIndex(), mvt);
3753       else
3754         Base = Addr.getOperand(0);
3755 
3756       // Offset must fit in a 32-bit signed int in PTX [register+offset] address
3757       // mode
3758       if (!CN->getAPIntValue().isSignedIntN(32))
3759         return false;
3760 
3761       Offset = CurDAG->getTargetConstant(CN->getSExtValue(), SDLoc(OpNode),
3762                                          MVT::i32);
3763       return true;
3764     }
3765   }
3766   return false;
3767 }
3768 
3769 // register+offset
SelectADDRri(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset)3770 bool NVPTXDAGToDAGISel::SelectADDRri(SDNode *OpNode, SDValue Addr,
3771                                      SDValue &Base, SDValue &Offset) {
3772   return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i32);
3773 }
3774 
3775 // register+offset
SelectADDRri64(SDNode * OpNode,SDValue Addr,SDValue & Base,SDValue & Offset)3776 bool NVPTXDAGToDAGISel::SelectADDRri64(SDNode *OpNode, SDValue Addr,
3777                                        SDValue &Base, SDValue &Offset) {
3778   return SelectADDRri_imp(OpNode, Addr, Base, Offset, MVT::i64);
3779 }
3780 
ChkMemSDNodeAddressSpace(SDNode * N,unsigned int spN) const3781 bool NVPTXDAGToDAGISel::ChkMemSDNodeAddressSpace(SDNode *N,
3782                                                  unsigned int spN) const {
3783   const Value *Src = nullptr;
3784   if (MemSDNode *mN = dyn_cast<MemSDNode>(N)) {
3785     if (spN == 0 && mN->getMemOperand()->getPseudoValue())
3786       return true;
3787     Src = mN->getMemOperand()->getValue();
3788   }
3789   if (!Src)
3790     return false;
3791   if (auto *PT = dyn_cast<PointerType>(Src->getType()))
3792     return (PT->getAddressSpace() == spN);
3793   return false;
3794 }
3795 
3796 /// SelectInlineAsmMemoryOperand - Implement addressing mode selection for
3797 /// inline asm expressions.
SelectInlineAsmMemoryOperand(const SDValue & Op,InlineAsm::ConstraintCode ConstraintID,std::vector<SDValue> & OutOps)3798 bool NVPTXDAGToDAGISel::SelectInlineAsmMemoryOperand(
3799     const SDValue &Op, InlineAsm::ConstraintCode ConstraintID,
3800     std::vector<SDValue> &OutOps) {
3801   SDValue Op0, Op1;
3802   switch (ConstraintID) {
3803   default:
3804     return true;
3805   case InlineAsm::ConstraintCode::m: // memory
3806     if (SelectDirectAddr(Op, Op0)) {
3807       OutOps.push_back(Op0);
3808       OutOps.push_back(CurDAG->getTargetConstant(0, SDLoc(Op), MVT::i32));
3809       return false;
3810     }
3811     if (SelectADDRri(Op.getNode(), Op, Op0, Op1)) {
3812       OutOps.push_back(Op0);
3813       OutOps.push_back(Op1);
3814       return false;
3815     }
3816     break;
3817   }
3818   return true;
3819 }
3820 
SelectV2I64toI128(SDNode * N)3821 void NVPTXDAGToDAGISel::SelectV2I64toI128(SDNode *N) {
3822   // Lower a CopyToReg with two 64-bit inputs
3823   // Dst:i128, lo:i64, hi:i64
3824   //
3825   // CopyToReg Dst, lo, hi;
3826   //
3827   // ==>
3828   //
3829   // tmp = V2I64toI128 {lo, hi};
3830   // CopyToReg Dst, tmp;
3831   SDValue Dst = N->getOperand(1);
3832   SDValue Lo = N->getOperand(2);
3833   SDValue Hi = N->getOperand(3);
3834 
3835   SDLoc DL(N);
3836   SDNode *Mov =
3837       CurDAG->getMachineNode(NVPTX::V2I64toI128, DL, MVT::i128, {Lo, Hi});
3838 
3839   SmallVector<SDValue, 4> NewOps(N->getNumOperands() - 1);
3840   NewOps[0] = N->getOperand(0);
3841   NewOps[1] = Dst;
3842   NewOps[2] = SDValue(Mov, 0);
3843   if (N->getNumOperands() == 5)
3844     NewOps[3] = N->getOperand(4);
3845   SDValue NewValue = CurDAG->getNode(ISD::CopyToReg, DL, SmallVector<EVT>(N->values()), NewOps);
3846 
3847   ReplaceNode(N, NewValue.getNode());
3848 }
3849 
SelectI128toV2I64(SDNode * N)3850 void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
3851   // Lower CopyFromReg from a 128-bit regs to two 64-bit regs
3852   // Dst:i128, Src:i128
3853   //
3854   // {lo, hi} = CopyFromReg Src
3855   //
3856   // ==>
3857   //
3858   // {lo, hi} = I128toV2I64 Src
3859   //
3860   SDValue Ch = N->getOperand(0);
3861   SDValue Src = N->getOperand(1);
3862   SDValue Glue = N->getOperand(2);
3863   SDLoc DL(N);
3864 
3865   // Add Glue and Ch to the operands and results to avoid break the execution
3866   // order
3867   SDNode *Mov = CurDAG->getMachineNode(
3868       NVPTX::I128toV2I64, DL,
3869       {MVT::i64, MVT::i64, Ch.getValueType(), Glue.getValueType()},
3870       {Src, Ch, Glue});
3871 
3872   ReplaceNode(N, Mov);
3873 }
3874 
3875 /// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
3876 /// conversion from \p SrcTy to \p DestTy.
GetConvertOpcode(MVT DestTy,MVT SrcTy,LoadSDNode * LdNode)3877 unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,
3878                                              LoadSDNode *LdNode) {
3879   bool IsSigned = LdNode && LdNode->getExtensionType() == ISD::SEXTLOAD;
3880   switch (SrcTy.SimpleTy) {
3881   default:
3882     llvm_unreachable("Unhandled source type");
3883   case MVT::i8:
3884     switch (DestTy.SimpleTy) {
3885     default:
3886       llvm_unreachable("Unhandled dest type");
3887     case MVT::i16:
3888       return IsSigned ? NVPTX::CVT_s16_s8 : NVPTX::CVT_u16_u8;
3889     case MVT::i32:
3890       return IsSigned ? NVPTX::CVT_s32_s8 : NVPTX::CVT_u32_u8;
3891     case MVT::i64:
3892       return IsSigned ? NVPTX::CVT_s64_s8 : NVPTX::CVT_u64_u8;
3893     }
3894   case MVT::i16:
3895     switch (DestTy.SimpleTy) {
3896     default:
3897       llvm_unreachable("Unhandled dest type");
3898     case MVT::i8:
3899       return IsSigned ? NVPTX::CVT_s8_s16 : NVPTX::CVT_u8_u16;
3900     case MVT::i32:
3901       return IsSigned ? NVPTX::CVT_s32_s16 : NVPTX::CVT_u32_u16;
3902     case MVT::i64:
3903       return IsSigned ? NVPTX::CVT_s64_s16 : NVPTX::CVT_u64_u16;
3904     }
3905   case MVT::i32:
3906     switch (DestTy.SimpleTy) {
3907     default:
3908       llvm_unreachable("Unhandled dest type");
3909     case MVT::i8:
3910       return IsSigned ? NVPTX::CVT_s8_s32 : NVPTX::CVT_u8_u32;
3911     case MVT::i16:
3912       return IsSigned ? NVPTX::CVT_s16_s32 : NVPTX::CVT_u16_u32;
3913     case MVT::i64:
3914       return IsSigned ? NVPTX::CVT_s64_s32 : NVPTX::CVT_u64_u32;
3915     }
3916   case MVT::i64:
3917     switch (DestTy.SimpleTy) {
3918     default:
3919       llvm_unreachable("Unhandled dest type");
3920     case MVT::i8:
3921       return IsSigned ? NVPTX::CVT_s8_s64 : NVPTX::CVT_u8_u64;
3922     case MVT::i16:
3923       return IsSigned ? NVPTX::CVT_s16_s64 : NVPTX::CVT_u16_u64;
3924     case MVT::i32:
3925       return IsSigned ? NVPTX::CVT_s32_s64 : NVPTX::CVT_u32_u64;
3926     }
3927   case MVT::f16:
3928     switch (DestTy.SimpleTy) {
3929     default:
3930       llvm_unreachable("Unhandled dest type");
3931     case MVT::f32:
3932       return NVPTX::CVT_f32_f16;
3933     case MVT::f64:
3934       return NVPTX::CVT_f64_f16;
3935     }
3936   }
3937 }
3938