xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/GISel/RISCVLegalizerInfo.cpp (revision 3ceba58a7509418b47b8fca2d2b6bbf088714e26)
1 //===-- RISCVLegalizerInfo.cpp ----------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements RISCVLegalizerInfo, the GlobalISel legalization rules for RISC-V.
10 /// \todo This should be generated by TableGen.
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVLegalizerInfo.h"
14 #include "MCTargetDesc/RISCVMatInt.h"
15 #include "RISCVMachineFunctionInfo.h"
16 #include "RISCVSubtarget.h"
17 #include "llvm/CodeGen/GlobalISel/GIMatchTableExecutor.h"
18 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h"
19 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
21 #include "llvm/CodeGen/MachineConstantPool.h"
22 #include "llvm/CodeGen/MachineRegisterInfo.h"
23 #include "llvm/CodeGen/TargetOpcodes.h"
24 #include "llvm/CodeGen/ValueTypes.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Type.h"
27 
28 using namespace llvm;
29 using namespace LegalityPredicates;
30 using namespace LegalizeMutations;
31 
32 // Is this type supported by scalar FP arithmetic operations given the current
33 // subtarget.
34 static LegalityPredicate typeIsScalarFPArith(unsigned TypeIdx,
35                                              const RISCVSubtarget &ST) {
36   return [=, &ST](const LegalityQuery &Query) {
37     return Query.Types[TypeIdx].isScalar() &&
38            ((ST.hasStdExtZfh() && Query.Types[TypeIdx].getSizeInBits() == 16) ||
39             (ST.hasStdExtF() && Query.Types[TypeIdx].getSizeInBits() == 32) ||
40             (ST.hasStdExtD() && Query.Types[TypeIdx].getSizeInBits() == 64));
41   };
42 }
43 
44 static LegalityPredicate
45 typeIsLegalIntOrFPVec(unsigned TypeIdx,
46                       std::initializer_list<LLT> IntOrFPVecTys,
47                       const RISCVSubtarget &ST) {
48   LegalityPredicate P = [=, &ST](const LegalityQuery &Query) {
49     return ST.hasVInstructions() &&
50            (Query.Types[TypeIdx].getScalarSizeInBits() != 64 ||
51             ST.hasVInstructionsI64()) &&
52            (Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 ||
53             ST.getELen() == 64);
54   };
55 
56   return all(typeInSet(TypeIdx, IntOrFPVecTys), P);
57 }
58 
59 static LegalityPredicate
60 typeIsLegalBoolVec(unsigned TypeIdx, std::initializer_list<LLT> BoolVecTys,
61                    const RISCVSubtarget &ST) {
62   LegalityPredicate P = [=, &ST](const LegalityQuery &Query) {
63     return ST.hasVInstructions() &&
64            (Query.Types[TypeIdx].getElementCount().getKnownMinValue() != 1 ||
65             ST.getELen() == 64);
66   };
67   return all(typeInSet(TypeIdx, BoolVecTys), P);
68 }
69 
// Builds the GlobalISel legalization rules for RISC-V.
//
// The rules depend on the subtarget: XLEN, and the M/Zmmul, Zbb/Zbkb,
// F/D/Zfh and vector extensions. Within each LegalizeRuleSet, rules are
// checked in the order they are added, so the order of the calls below is
// significant.
RISCVLegalizerInfo::RISCVLegalizerInfo(const RISCVSubtarget &ST)
    : STI(ST), XLen(STI.getXLen()), sXLen(LLT::scalar(XLen)) {
  // Scalar and pointer types used by the rules below. p0 is an XLEN-bit
  // pointer in address space 0.
  const LLT sDoubleXLen = LLT::scalar(2 * XLen);
  const LLT p0 = LLT::pointer(0, XLen);
  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);

  // Scalable vector types: nxv<N>s<B> is <vscale x N x sB>.
  const LLT nxv1s1 = LLT::scalable_vector(1, s1);
  const LLT nxv2s1 = LLT::scalable_vector(2, s1);
  const LLT nxv4s1 = LLT::scalable_vector(4, s1);
  const LLT nxv8s1 = LLT::scalable_vector(8, s1);
  const LLT nxv16s1 = LLT::scalable_vector(16, s1);
  const LLT nxv32s1 = LLT::scalable_vector(32, s1);
  const LLT nxv64s1 = LLT::scalable_vector(64, s1);

  const LLT nxv1s8 = LLT::scalable_vector(1, s8);
  const LLT nxv2s8 = LLT::scalable_vector(2, s8);
  const LLT nxv4s8 = LLT::scalable_vector(4, s8);
  const LLT nxv8s8 = LLT::scalable_vector(8, s8);
  const LLT nxv16s8 = LLT::scalable_vector(16, s8);
  const LLT nxv32s8 = LLT::scalable_vector(32, s8);
  const LLT nxv64s8 = LLT::scalable_vector(64, s8);

  const LLT nxv1s16 = LLT::scalable_vector(1, s16);
  const LLT nxv2s16 = LLT::scalable_vector(2, s16);
  const LLT nxv4s16 = LLT::scalable_vector(4, s16);
  const LLT nxv8s16 = LLT::scalable_vector(8, s16);
  const LLT nxv16s16 = LLT::scalable_vector(16, s16);
  const LLT nxv32s16 = LLT::scalable_vector(32, s16);

  const LLT nxv1s32 = LLT::scalable_vector(1, s32);
  const LLT nxv2s32 = LLT::scalable_vector(2, s32);
  const LLT nxv4s32 = LLT::scalable_vector(4, s32);
  const LLT nxv8s32 = LLT::scalable_vector(8, s32);
  const LLT nxv16s32 = LLT::scalable_vector(16, s32);

  const LLT nxv1s64 = LLT::scalable_vector(1, s64);
  const LLT nxv2s64 = LLT::scalable_vector(2, s64);
  const LLT nxv4s64 = LLT::scalable_vector(4, s64);
  const LLT nxv8s64 = LLT::scalable_vector(8, s64);

  using namespace TargetOpcode;

  // Candidate mask (s1-element) vector types. Each use is further filtered
  // by typeIsLegalBoolVec against the subtarget.
  auto BoolVecTys = {nxv1s1, nxv2s1, nxv4s1, nxv8s1, nxv16s1, nxv32s1, nxv64s1};

  // Candidate integer/FP data vector types. Each use is further filtered by
  // typeIsLegalIntOrFPVec against the subtarget.
  auto IntOrFPVecTys = {nxv1s8,   nxv2s8,  nxv4s8,  nxv8s8,  nxv16s8, nxv32s8,
                        nxv64s8,  nxv1s16, nxv2s16, nxv4s16, nxv8s16, nxv16s16,
                        nxv32s16, nxv1s32, nxv2s32, nxv4s32, nxv8s32, nxv16s32,
                        nxv1s64,  nxv2s64, nxv4s64, nxv8s64};

  // Basic integer arithmetic/logic is legal at s32, at XLEN, and on usable
  // RVV data vectors. Other scalar widths are widened/clamped into
  // [s32, XLEN].
  getActionDefinitionsBuilder({G_ADD, G_SUB, G_AND, G_OR, G_XOR})
      .legalFor({s32, sXLen})
      .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST))
      .widenScalarToNextPow2(0)
      .clampScalar(0, s32, sXLen);

  getActionDefinitionsBuilder(
      {G_UADDE, G_UADDO, G_USUBE, G_USUBO}).lower();

  getActionDefinitionsBuilder({G_SADDO, G_SSUBO}).minScalar(0, sXLen).lower();

  // TODO: Use Vector Single-Width Saturating Instructions for vector types.
  getActionDefinitionsBuilder({G_UADDSAT, G_SADDSAT, G_USUBSAT, G_SSUBSAT})
      .lower();

  // On RV64, {s32, s32} shifts are custom-legalized. This rule is added
  // before the legalFor rules, so it is checked first. The handler is
  // presumably legalizeShlAshrLshr; confirm in legalizeCustom.
  auto &ShiftActions = getActionDefinitionsBuilder({G_ASHR, G_LSHR, G_SHL});
  if (ST.is64Bit())
    ShiftActions.customFor({{s32, s32}});
  ShiftActions.legalFor({{s32, s32}, {s32, sXLen}, {sXLen, sXLen}})
      .widenScalarToNextPow2(0)
      .clampScalar(1, s32, sXLen)
      .clampScalar(0, s32, sXLen)
      .minScalarSameAs(1, 0)
      .widenScalarToNextPow2(1);

  // Extensions between usable data vectors are legal. Extensions from a mask
  // vector are custom; the handler is presumably legalizeExt, confirm in
  // legalizeCustom. On RV64, s32 -> XLEN is legal and G_SEXT_INREG at XLEN
  // is custom.
  auto &ExtActions =
      getActionDefinitionsBuilder({G_ZEXT, G_SEXT, G_ANYEXT})
          .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
                       typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)));
  if (ST.is64Bit()) {
    ExtActions.legalFor({{sXLen, s32}});
    getActionDefinitionsBuilder(G_SEXT_INREG)
        .customFor({sXLen})
        .maxScalar(0, sXLen)
        .lower();
  } else {
    getActionDefinitionsBuilder(G_SEXT_INREG).maxScalar(0, sXLen).lower();
  }
  ExtActions.customIf(typeIsLegalBoolVec(1, BoolVecTys, ST))
      .maxScalar(0, sXLen);

  // Merge/Unmerge
  // RV32 with D accepts s64 <-> 2 x s32 directly. Otherwise both the wide
  // and the narrow types are widened and clamped to XLEN.
  for (unsigned Op : {G_MERGE_VALUES, G_UNMERGE_VALUES}) {
    auto &MergeUnmergeActions = getActionDefinitionsBuilder(Op);
    unsigned BigTyIdx = Op == G_MERGE_VALUES ? 0 : 1;
    unsigned LitTyIdx = Op == G_MERGE_VALUES ? 1 : 0;
    if (XLen == 32 && ST.hasStdExtD()) {
      MergeUnmergeActions.legalIf(
          all(typeIs(BigTyIdx, s64), typeIs(LitTyIdx, s32)));
    }
    MergeUnmergeActions.widenScalarToNextPow2(LitTyIdx, XLen)
        .widenScalarToNextPow2(BigTyIdx, XLen)
        .clampScalar(LitTyIdx, sXLen, sXLen)
        .clampScalar(BigTyIdx, sXLen, sXLen);
  }

  getActionDefinitionsBuilder({G_FSHL, G_FSHR}).lower();

  // Rotates are legal only with Zbb or Zbkb; otherwise they are lowered.
  auto &RotateActions = getActionDefinitionsBuilder({G_ROTL, G_ROTR});
  if (ST.hasStdExtZbb() || ST.hasStdExtZbkb()) {
    RotateActions.legalFor({{s32, sXLen}, {sXLen, sXLen}});
    // Widen s32 rotate amount to s64 so SDAG patterns will match.
    if (ST.is64Bit())
      RotateActions.widenScalarIf(all(typeIs(0, s32), typeIs(1, s32)),
                                  changeTo(1, sXLen));
  }
  RotateActions.lower();

  getActionDefinitionsBuilder(G_BITREVERSE).maxScalar(0, sXLen).lower();

  // Bitcasts are legal between any two usable RVV types, data or mask.
  getActionDefinitionsBuilder(G_BITCAST).legalIf(
      all(LegalityPredicates::any(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
                                  typeIsLegalBoolVec(0, BoolVecTys, ST)),
          LegalityPredicates::any(typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST),
                                  typeIsLegalBoolVec(1, BoolVecTys, ST))));

  auto &BSWAPActions = getActionDefinitionsBuilder(G_BSWAP);
  if (ST.hasStdExtZbb() || ST.hasStdExtZbkb())
    BSWAPActions.legalFor({sXLen}).clampScalar(0, sXLen, sXLen);
  else
    BSWAPActions.maxScalar(0, sXLen).lower();

  // With Zbb, CTLZ/CTTZ are legal at s32 and XLEN. The *_ZERO_UNDEF forms
  // are always lowered.
  auto &CountZerosActions = getActionDefinitionsBuilder({G_CTLZ, G_CTTZ});
  auto &CountZerosUndefActions =
      getActionDefinitionsBuilder({G_CTLZ_ZERO_UNDEF, G_CTTZ_ZERO_UNDEF});
  if (ST.hasStdExtZbb()) {
    CountZerosActions.legalFor({{s32, s32}, {sXLen, sXLen}})
        .clampScalar(0, s32, sXLen)
        .widenScalarToNextPow2(0)
        .scalarSameSizeAs(1, 0);
  } else {
    CountZerosActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0).lower();
    CountZerosUndefActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0);
  }
  CountZerosUndefActions.lower();

  auto &CTPOPActions = getActionDefinitionsBuilder(G_CTPOP);
  if (ST.hasStdExtZbb()) {
    CTPOPActions.legalFor({{s32, s32}, {sXLen, sXLen}})
        .clampScalar(0, s32, sXLen)
        .widenScalarToNextPow2(0)
        .scalarSameSizeAs(1, 0);
  } else {
    CTPOPActions.maxScalar(0, sXLen).scalarSameSizeAs(1, 0).lower();
  }

  // s32 and pointer constants are legal. On RV64, s64 constants are custom;
  // presumably this is so that large immediates can be weighed via
  // shouldBeInConstantPool. Confirm in legalizeCustom.
  auto &ConstantActions = getActionDefinitionsBuilder(G_CONSTANT);
  ConstantActions.legalFor({s32, p0});
  if (ST.is64Bit())
    ConstantActions.customFor({s64});
  ConstantActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen);

  // TODO: transform illegal vector types into legal vector type
  getActionDefinitionsBuilder(
      {G_IMPLICIT_DEF, G_CONSTANT_FOLD_BARRIER, G_FREEZE})
      .legalFor({s32, sXLen, p0})
      .legalIf(typeIsLegalBoolVec(0, BoolVecTys, ST))
      .legalIf(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST))
      .widenScalarToNextPow2(0)
      .clampScalar(0, s32, sXLen);

  // Scalar compares produce an XLEN result. Vector compares produce a mask
  // vector from a usable data vector.
  getActionDefinitionsBuilder(G_ICMP)
      .legalFor({{sXLen, sXLen}, {sXLen, p0}})
      .legalIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST),
                   typeIsLegalIntOrFPVec(1, IntOrFPVecTys, ST)))
      .widenScalarOrEltToNextPow2OrMinSize(1, 8)
      .clampScalar(1, sXLen, sXLen)
      .clampScalar(0, sXLen, sXLen);

  // Selected values may be s64 on RV64, or on RV32 when D is available. The
  // condition is always clamped to XLEN.
  auto &SelectActions =
      getActionDefinitionsBuilder(G_SELECT)
          .legalFor({{s32, sXLen}, {p0, sXLen}})
          .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
                       typeIsLegalBoolVec(1, BoolVecTys, ST)));
  if (XLen == 64 || ST.hasStdExtD())
    SelectActions.legalFor({{s64, sXLen}});
  SelectActions.widenScalarToNextPow2(0)
      .clampScalar(0, s32, (XLen == 64 || ST.hasStdExtD()) ? s64 : s32)
      .clampScalar(1, sXLen, sXLen);

  // Loads/stores. Each entry is {value type, pointer type, memory type, N};
  // N presumably gives the minimum alignment in bits (confirm against
  // LegalizerInfo). Unmatched cases are clamped into [s32, XLEN] and lowered.
  auto &LoadStoreActions =
      getActionDefinitionsBuilder({G_LOAD, G_STORE})
          .legalForTypesWithMemDesc({{s32, p0, s8, 8},
                                     {s32, p0, s16, 16},
                                     {s32, p0, s32, 32},
                                     {p0, p0, sXLen, XLen}});
  auto &ExtLoadActions =
      getActionDefinitionsBuilder({G_SEXTLOAD, G_ZEXTLOAD})
          .legalForTypesWithMemDesc({{s32, p0, s8, 8}, {s32, p0, s16, 16}});
  if (XLen == 64) {
    LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s8, 8},
                                               {s64, p0, s16, 16},
                                               {s64, p0, s32, 32},
                                               {s64, p0, s64, 64}});
    ExtLoadActions.legalForTypesWithMemDesc(
        {{s64, p0, s8, 8}, {s64, p0, s16, 16}, {s64, p0, s32, 32}});
  } else if (ST.hasStdExtD()) {
    // RV32 with D additionally accepts full-width s64 accesses.
    LoadStoreActions.legalForTypesWithMemDesc({{s64, p0, s64, 64}});
  }
  LoadStoreActions.clampScalar(0, s32, sXLen).lower();
  ExtLoadActions.widenScalarToNextPow2(0).clampScalar(0, s32, sXLen).lower();

  // Pointer arithmetic and pointer <-> integer conversions use XLEN integers.
  getActionDefinitionsBuilder({G_PTR_ADD, G_PTRMASK}).legalFor({{p0, sXLen}});

  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalFor({{sXLen, p0}})
      .clampScalar(0, sXLen, sXLen);

  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalFor({{p0, sXLen}})
      .clampScalar(1, sXLen, sXLen);

  getActionDefinitionsBuilder(G_BRCOND).legalFor({sXLen}).minScalar(0, sXLen);

  getActionDefinitionsBuilder(G_BRJT).legalFor({{p0, sXLen}});

  getActionDefinitionsBuilder(G_BRINDIRECT).legalFor({p0});

  getActionDefinitionsBuilder(G_PHI)
      .legalFor({p0, sXLen})
      .widenScalarToNextPow2(0)
      .clampScalar(0, sXLen, sXLen);

  getActionDefinitionsBuilder({G_GLOBAL_VALUE, G_JUMP_TABLE, G_CONSTANT_POOL})
      .legalFor({p0});

  // Multiplication is native with Zmmul. Without it, XLEN and 2*XLEN
  // multiplies become libcalls.
  if (ST.hasStdExtZmmul()) {
    getActionDefinitionsBuilder(G_MUL)
        .legalFor({s32, sXLen})
        .widenScalarToNextPow2(0)
        .clampScalar(0, s32, sXLen);

    // clang-format off
    getActionDefinitionsBuilder({G_SMULH, G_UMULH})
        .legalFor({sXLen})
        .lower();
    // clang-format on

    getActionDefinitionsBuilder({G_SMULO, G_UMULO}).minScalar(0, sXLen).lower();
  } else {
    getActionDefinitionsBuilder(G_MUL)
        .libcallFor({sXLen, sDoubleXLen})
        .widenScalarToNextPow2(0)
        .clampScalar(0, sXLen, sDoubleXLen);

    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).lowerFor({sXLen});

    getActionDefinitionsBuilder({G_SMULO, G_UMULO})
        .minScalar(0, sXLen)
        // Widen sXLen to sDoubleXLen so we can use a single libcall to get
        // the low bits for the mul result and high bits to do the overflow
        // check.
        .widenScalarIf(typeIs(0, sXLen),
                       LegalizeMutations::changeTo(0, sDoubleXLen))
        .lower();
  }

  // Division/remainder is native at s32/XLEN with M. 2*XLEN is a libcall,
  // and without M so is XLEN.
  if (ST.hasStdExtM()) {
    getActionDefinitionsBuilder({G_UDIV, G_SDIV, G_UREM, G_SREM})
        .legalFor({s32, sXLen})
        .libcallFor({sDoubleXLen})
        .clampScalar(0, s32, sDoubleXLen)
        .widenScalarToNextPow2(0);
  } else {
    getActionDefinitionsBuilder({G_UDIV, G_SDIV, G_UREM, G_SREM})
        .libcallFor({sXLen, sDoubleXLen})
        .clampScalar(0, sXLen, sDoubleXLen)
        .widenScalarToNextPow2(0);
  }

  // TODO: Use libcall for sDoubleXLen.
  getActionDefinitionsBuilder({G_UDIVREM, G_SDIVREM}).lower();

  auto &AbsActions = getActionDefinitionsBuilder(G_ABS);
  if (ST.hasStdExtZbb())
    AbsActions.customFor({s32, sXLen}).minScalar(0, sXLen);
  AbsActions.lower();

  auto &MinMaxActions =
      getActionDefinitionsBuilder({G_UMAX, G_UMIN, G_SMAX, G_SMIN});
  if (ST.hasStdExtZbb())
    MinMaxActions.legalFor({sXLen}).minScalar(0, sXLen);
  MinMaxActions.lower();

  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE, G_MEMSET}).libcall();

  getActionDefinitionsBuilder(G_DYN_STACKALLOC).lower();

  // FP Operations
  // Most scalar FP rules are gated on typeIsScalarFPArith: Zfh for s16, F for
  // s32, and D for s64.

  getActionDefinitionsBuilder({G_FADD, G_FSUB, G_FMUL, G_FDIV, G_FMA, G_FNEG,
                               G_FABS, G_FSQRT, G_FMAXNUM, G_FMINNUM})
      .legalIf(typeIsScalarFPArith(0, ST));

  getActionDefinitionsBuilder(G_FREM)
      .libcallFor({s32, s64})
      .minScalar(0, s32)
      .scalarize(0);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalIf(all(typeIsScalarFPArith(0, ST), typeIsScalarFPArith(1, ST)));

  // FIXME: Use Zfhmin.
  getActionDefinitionsBuilder(G_FPTRUNC).legalIf(
      [=, &ST](const LegalityQuery &Query) -> bool {
        return (ST.hasStdExtD() && typeIs(0, s32)(Query) &&
                typeIs(1, s64)(Query)) ||
               (ST.hasStdExtZfh() && typeIs(0, s16)(Query) &&
                typeIs(1, s32)(Query)) ||
               (ST.hasStdExtZfh() && ST.hasStdExtD() && typeIs(0, s16)(Query) &&
                typeIs(1, s64)(Query));
      });
  getActionDefinitionsBuilder(G_FPEXT).legalIf(
      [=, &ST](const LegalityQuery &Query) -> bool {
        return (ST.hasStdExtD() && typeIs(0, s64)(Query) &&
                typeIs(1, s32)(Query)) ||
               (ST.hasStdExtZfh() && typeIs(0, s32)(Query) &&
                typeIs(1, s16)(Query)) ||
               (ST.hasStdExtZfh() && ST.hasStdExtD() && typeIs(0, s64)(Query) &&
                typeIs(1, s16)(Query));
      });

  getActionDefinitionsBuilder(G_FCMP)
      .legalIf(all(typeIs(0, sXLen), typeIsScalarFPArith(1, ST)))
      .clampScalar(0, sXLen, sXLen);

  // TODO: Support vector version of G_IS_FPCLASS.
  getActionDefinitionsBuilder(G_IS_FPCLASS)
      .customIf(all(typeIs(0, s1), typeIsScalarFPArith(1, ST)));

  getActionDefinitionsBuilder(G_FCONSTANT)
      .legalIf(typeIsScalarFPArith(0, ST))
      .lowerFor({s32, s64});

  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalIf(all(typeInSet(0, {s32, sXLen}), typeIsScalarFPArith(1, ST)))
      .widenScalarToNextPow2(0)
      .clampScalar(0, s32, sXLen)
      .libcall();

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalIf(all(typeIsScalarFPArith(0, ST), typeInSet(1, {s32, sXLen})))
      .widenScalarToNextPow2(1)
      .clampScalar(1, s32, sXLen);

  // FIXME: We can do custom inline expansion like SelectionDAG.
  // FIXME: Legal with Zfa.
  getActionDefinitionsBuilder({G_FCEIL, G_FFLOOR})
      .libcallFor({s32, s64});

  getActionDefinitionsBuilder(G_VASTART).customFor({p0});

  // va_list must be a pointer, but most sized types are pretty easy to handle
  // as the destination.
  getActionDefinitionsBuilder(G_VAARG)
      // TODO: Implement narrowScalar and widenScalar for G_VAARG for types
      // outside the [s32, sXLen] range.
      .clampScalar(0, s32, sXLen)
      .lowerForCartesianProduct({s32, sXLen, p0}, {p0});

  getActionDefinitionsBuilder(G_VSCALE)
      .clampScalar(0, sXLen, sXLen)
      .customFor({sXLen});

  // Splats of an XLEN scalar into a usable data vector are legal. Splats of
  // s1 into a usable mask vector are custom.
  auto &SplatActions =
      getActionDefinitionsBuilder(G_SPLAT_VECTOR)
          .legalIf(all(typeIsLegalIntOrFPVec(0, IntOrFPVecTys, ST),
                       typeIs(1, sXLen)))
          .customIf(all(typeIsLegalBoolVec(0, BoolVecTys, ST), typeIs(1, s1)));
  // Handle case of s64 element vectors on RV32. If the subtarget does not have
  // f64, then try to lower it to G_SPLAT_VECTOR_SPLIT_I64_VL. If the subtarget
  // does have f64, then we don't know whether the type is an f64 or an i64,
  // so mark the G_SPLAT_VECTOR as legal and decide later what to do with it,
  // depending on how the instructions it consumes are legalized. They are not
  // legalized yet since legalization is in reverse postorder, so we cannot
  // make the decision at this moment.
  if (XLen == 32) {
    if (ST.hasVInstructionsF64() && ST.hasStdExtD())
      SplatActions.legalIf(all(
          typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
    else if (ST.hasVInstructionsI64())
      SplatActions.customIf(all(
          typeInSet(0, {nxv1s64, nxv2s64, nxv4s64, nxv8s64}), typeIs(1, s64)));
  }

  SplatActions.clampScalar(1, sXLen, sXLen);

  // Finalize the legacy legalizer tables now that every rule is registered.
  getLegacyLegalizerInfo().computeTables();
}
474 
475 bool RISCVLegalizerInfo::legalizeIntrinsic(LegalizerHelper &Helper,
476                                            MachineInstr &MI) const {
477   Intrinsic::ID IntrinsicID = cast<GIntrinsic>(MI).getIntrinsicID();
478   switch (IntrinsicID) {
479   default:
480     return false;
481   case Intrinsic::vacopy: {
482     // vacopy arguments must be legal because of the intrinsic signature.
483     // No need to check here.
484 
485     MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
486     MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
487     MachineFunction &MF = *MI.getMF();
488     const DataLayout &DL = MIRBuilder.getDataLayout();
489     LLVMContext &Ctx = MF.getFunction().getContext();
490 
491     Register DstLst = MI.getOperand(1).getReg();
492     LLT PtrTy = MRI.getType(DstLst);
493 
494     // Load the source va_list
495     Align Alignment = DL.getABITypeAlign(getTypeForLLT(PtrTy, Ctx));
496     MachineMemOperand *LoadMMO = MF.getMachineMemOperand(
497         MachinePointerInfo(), MachineMemOperand::MOLoad, PtrTy, Alignment);
498     auto Tmp = MIRBuilder.buildLoad(PtrTy, MI.getOperand(2), *LoadMMO);
499 
500     // Store the result in the destination va_list
501     MachineMemOperand *StoreMMO = MF.getMachineMemOperand(
502         MachinePointerInfo(), MachineMemOperand::MOStore, PtrTy, Alignment);
503     MIRBuilder.buildStore(Tmp, DstLst, *StoreMMO);
504 
505     MI.eraseFromParent();
506     return true;
507   }
508   }
509 }
510 
511 bool RISCVLegalizerInfo::legalizeShlAshrLshr(
512     MachineInstr &MI, MachineIRBuilder &MIRBuilder,
513     GISelChangeObserver &Observer) const {
514   assert(MI.getOpcode() == TargetOpcode::G_ASHR ||
515          MI.getOpcode() == TargetOpcode::G_LSHR ||
516          MI.getOpcode() == TargetOpcode::G_SHL);
517   MachineRegisterInfo &MRI = *MIRBuilder.getMRI();
518   // If the shift amount is a G_CONSTANT, promote it to a 64 bit type so the
519   // imported patterns can select it later. Either way, it will be legal.
520   Register AmtReg = MI.getOperand(2).getReg();
521   auto VRegAndVal = getIConstantVRegValWithLookThrough(AmtReg, MRI);
522   if (!VRegAndVal)
523     return true;
524   // Check the shift amount is in range for an immediate form.
525   uint64_t Amount = VRegAndVal->Value.getZExtValue();
526   if (Amount > 31)
527     return true; // This will have to remain a register variant.
528   auto ExtCst = MIRBuilder.buildConstant(LLT::scalar(64), Amount);
529   Observer.changingInstr(MI);
530   MI.getOperand(2).setReg(ExtCst.getReg(0));
531   Observer.changedInstr(MI);
532   return true;
533 }
534 
535 bool RISCVLegalizerInfo::legalizeVAStart(MachineInstr &MI,
536                                          MachineIRBuilder &MIRBuilder) const {
537   // Stores the address of the VarArgsFrameIndex slot into the memory location
538   assert(MI.getOpcode() == TargetOpcode::G_VASTART);
539   MachineFunction *MF = MI.getParent()->getParent();
540   RISCVMachineFunctionInfo *FuncInfo = MF->getInfo<RISCVMachineFunctionInfo>();
541   int FI = FuncInfo->getVarArgsFrameIndex();
542   LLT AddrTy = MIRBuilder.getMRI()->getType(MI.getOperand(0).getReg());
543   auto FINAddr = MIRBuilder.buildFrameIndex(AddrTy, FI);
544   assert(MI.hasOneMemOperand());
545   MIRBuilder.buildStore(FINAddr, MI.getOperand(0).getReg(),
546                         *MI.memoperands()[0]);
547   MI.eraseFromParent();
548   return true;
549 }
550 
551 bool RISCVLegalizerInfo::shouldBeInConstantPool(APInt APImm,
552                                                 bool ShouldOptForSize) const {
553   assert(APImm.getBitWidth() == 32 || APImm.getBitWidth() == 64);
554   int64_t Imm = APImm.getSExtValue();
555   // All simm32 constants should be handled by isel.
556   // NOTE: The getMaxBuildIntsCost call below should return a value >= 2 making
557   // this check redundant, but small immediates are common so this check
558   // should have better compile time.
559   if (isInt<32>(Imm))
560     return false;
561 
562   // We only need to cost the immediate, if constant pool lowering is enabled.
563   if (!STI.useConstantPoolForLargeInts())
564     return false;
565 
566   RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Imm, STI);
567   if (Seq.size() <= STI.getMaxBuildIntsCost())
568     return false;
569 
570   // Optimizations below are disabled for opt size. If we're optimizing for
571   // size, use a constant pool.
572   if (ShouldOptForSize)
573     return true;
574   //
575   // Special case. See if we can build the constant as (ADD (SLLI X, C), X) do
576   // that if it will avoid a constant pool.
577   // It will require an extra temporary register though.
578   // If we have Zba we can use (ADD_UW X, (SLLI X, 32)) to handle cases where
579   // low and high 32 bits are the same and bit 31 and 63 are set.
580   unsigned ShiftAmt, AddOpc;
581   RISCVMatInt::InstSeq SeqLo =
582       RISCVMatInt::generateTwoRegInstSeq(Imm, STI, ShiftAmt, AddOpc);
583   return !(!SeqLo.empty() && (SeqLo.size() + 2) <= STI.getMaxBuildIntsCost());
584 }
585 
586 bool RISCVLegalizerInfo::legalizeVScale(MachineInstr &MI,
587                                         MachineIRBuilder &MIB) const {
588   const LLT XLenTy(STI.getXLenVT());
589   Register Dst = MI.getOperand(0).getReg();
590 
591   // We define our scalable vector types for lmul=1 to use a 64 bit known
592   // minimum size. e.g. <vscale x 2 x i32>. VLENB is in bytes so we calculate
593   // vscale as VLENB / 8.
594   static_assert(RISCV::RVVBitsPerBlock == 64, "Unexpected bits per block!");
595   if (STI.getRealMinVLen() < RISCV::RVVBitsPerBlock)
596     // Support for VLEN==32 is incomplete.
597     return false;
598 
599   // We assume VLENB is a multiple of 8. We manually choose the best shift
600   // here because SimplifyDemandedBits isn't always able to simplify it.
601   uint64_t Val = MI.getOperand(1).getCImm()->getZExtValue();
602   if (isPowerOf2_64(Val)) {
603     uint64_t Log2 = Log2_64(Val);
604     if (Log2 < 3) {
605       auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {});
606       MIB.buildLShr(Dst, VLENB, MIB.buildConstant(XLenTy, 3 - Log2));
607     } else if (Log2 > 3) {
608       auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {});
609       MIB.buildShl(Dst, VLENB, MIB.buildConstant(XLenTy, Log2 - 3));
610     } else {
611       MIB.buildInstr(RISCV::G_READ_VLENB, {Dst}, {});
612     }
613   } else if ((Val % 8) == 0) {
614     // If the multiplier is a multiple of 8, scale it down to avoid needing
615     // to shift the VLENB value.
616     auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {});
617     MIB.buildMul(Dst, VLENB, MIB.buildConstant(XLenTy, Val / 8));
618   } else {
619     auto VLENB = MIB.buildInstr(RISCV::G_READ_VLENB, {XLenTy}, {});
620     auto VScale = MIB.buildLShr(XLenTy, VLENB, MIB.buildConstant(XLenTy, 3));
621     MIB.buildMul(Dst, VScale, MIB.buildConstant(XLenTy, Val));
622   }
623   MI.eraseFromParent();
624   return true;
625 }
626 
627 // Custom-lower extensions from mask vectors by using a vselect either with 1
628 // for zero/any-extension or -1 for sign-extension:
629 //   (vXiN = (s|z)ext vXi1:vmask) -> (vXiN = vselect vmask, (-1 or 1), 0)
630 // Note that any-extension is lowered identically to zero-extension.
631 bool RISCVLegalizerInfo::legalizeExt(MachineInstr &MI,
632                                      MachineIRBuilder &MIB) const {
633 
634   unsigned Opc = MI.getOpcode();
635   assert(Opc == TargetOpcode::G_ZEXT || Opc == TargetOpcode::G_SEXT ||
636          Opc == TargetOpcode::G_ANYEXT);
637 
638   MachineRegisterInfo &MRI = *MIB.getMRI();
639   Register Dst = MI.getOperand(0).getReg();
640   Register Src = MI.getOperand(1).getReg();
641 
642   LLT DstTy = MRI.getType(Dst);
643   int64_t ExtTrueVal = Opc == TargetOpcode::G_SEXT ? -1 : 1;
644   LLT DstEltTy = DstTy.getElementType();
645   auto SplatZero = MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, 0));
646   auto SplatTrue =
647       MIB.buildSplatVector(DstTy, MIB.buildConstant(DstEltTy, ExtTrueVal));
648   MIB.buildSelect(Dst, Src, SplatTrue, SplatZero);
649 
650   MI.eraseFromParent();
651   return true;
652 }
653 
654 /// Return the type of the mask type suitable for masking the provided
655 /// vector type.  This is simply an i1 element type vector of the same
656 /// (possibly scalable) length.
657 static LLT getMaskTypeFor(LLT VecTy) {
658   assert(VecTy.isVector());
659   ElementCount EC = VecTy.getElementCount();
660   return LLT::vector(EC, LLT::scalar(1));
661 }
662 
663 /// Creates an all ones mask suitable for masking a vector of type VecTy with
664 /// vector length VL.
665 static MachineInstrBuilder buildAllOnesMask(LLT VecTy, const SrcOp &VL,
666                                             MachineIRBuilder &MIB,
667                                             MachineRegisterInfo &MRI) {
668   LLT MaskTy = getMaskTypeFor(VecTy);
669   return MIB.buildInstr(RISCV::G_VMSET_VL, {MaskTy}, {VL});
670 }
671 
672 /// Gets the two common "VL" operands: an all-ones mask and the vector length.
673 /// VecTy is a scalable vector type.
674 static std::pair<MachineInstrBuilder, Register>
675 buildDefaultVLOps(const DstOp &Dst, MachineIRBuilder &MIB,
676                   MachineRegisterInfo &MRI) {
677   LLT VecTy = Dst.getLLTTy(MRI);
678   assert(VecTy.isScalableVector() && "Expecting scalable container type");
679   Register VL(RISCV::X0);
680   MachineInstrBuilder Mask = buildAllOnesMask(VecTy, VL, MIB, MRI);
681   return {Mask, VL};
682 }
683 
684 static MachineInstrBuilder
685 buildSplatPartsS64WithVL(const DstOp &Dst, const SrcOp &Passthru, Register Lo,
686                          Register Hi, Register VL, MachineIRBuilder &MIB,
687                          MachineRegisterInfo &MRI) {
688   // TODO: If the Hi bits of the splat are undefined, then it's fine to just
689   // splat Lo even if it might be sign extended. I don't think we have
690   // introduced a case where we're build a s64 where the upper bits are undef
691   // yet.
692 
693   // Fall back to a stack store and stride x0 vector load.
694   // TODO: need to lower G_SPLAT_VECTOR_SPLIT_I64. This is done in
695   // preprocessDAG in SDAG.
696   return MIB.buildInstr(RISCV::G_SPLAT_VECTOR_SPLIT_I64_VL, {Dst},
697                         {Passthru, Lo, Hi, VL});
698 }
699 
700 static MachineInstrBuilder
701 buildSplatSplitS64WithVL(const DstOp &Dst, const SrcOp &Passthru,
702                          const SrcOp &Scalar, Register VL,
703                          MachineIRBuilder &MIB, MachineRegisterInfo &MRI) {
704   assert(Scalar.getLLTTy(MRI) == LLT::scalar(64) && "Unexpected VecTy!");
705   auto Unmerge = MIB.buildUnmerge(LLT::scalar(32), Scalar);
706   return buildSplatPartsS64WithVL(Dst, Passthru, Unmerge.getReg(0),
707                                   Unmerge.getReg(1), VL, MIB, MRI);
708 }
709 
710 // Lower splats of s1 types to G_ICMP. For each mask vector type, we have a
711 // legal equivalently-sized i8 type, so we can use that as a go-between.
712 // Splats of s1 types that have constant value can be legalized as VMSET_VL or
713 // VMCLR_VL.
714 bool RISCVLegalizerInfo::legalizeSplatVector(MachineInstr &MI,
715                                              MachineIRBuilder &MIB) const {
716   assert(MI.getOpcode() == TargetOpcode::G_SPLAT_VECTOR);
717 
718   MachineRegisterInfo &MRI = *MIB.getMRI();
719 
720   Register Dst = MI.getOperand(0).getReg();
721   Register SplatVal = MI.getOperand(1).getReg();
722 
723   LLT VecTy = MRI.getType(Dst);
724   LLT XLenTy(STI.getXLenVT());
725 
726   // Handle case of s64 element vectors on rv32
727   if (XLenTy.getSizeInBits() == 32 &&
728       VecTy.getElementType().getSizeInBits() == 64) {
729     auto [_, VL] = buildDefaultVLOps(Dst, MIB, MRI);
730     buildSplatSplitS64WithVL(Dst, MIB.buildUndef(VecTy), SplatVal, VL, MIB,
731                              MRI);
732     MI.eraseFromParent();
733     return true;
734   }
735 
736   // All-zeros or all-ones splats are handled specially.
737   MachineInstr &SplatValMI = *MRI.getVRegDef(SplatVal);
738   if (isAllOnesOrAllOnesSplat(SplatValMI, MRI)) {
739     auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
740     MIB.buildInstr(RISCV::G_VMSET_VL, {Dst}, {VL});
741     MI.eraseFromParent();
742     return true;
743   }
744   if (isNullOrNullSplat(SplatValMI, MRI)) {
745     auto VL = buildDefaultVLOps(VecTy, MIB, MRI).second;
746     MIB.buildInstr(RISCV::G_VMCLR_VL, {Dst}, {VL});
747     MI.eraseFromParent();
748     return true;
749   }
750 
751   // Handle non-constant mask splat (i.e. not sure if it's all zeros or all
752   // ones) by promoting it to an s8 splat.
753   LLT InterEltTy = LLT::scalar(8);
754   LLT InterTy = VecTy.changeElementType(InterEltTy);
755   auto ZExtSplatVal = MIB.buildZExt(InterEltTy, SplatVal);
756   auto And =
757       MIB.buildAnd(InterEltTy, ZExtSplatVal, MIB.buildConstant(InterEltTy, 1));
758   auto LHS = MIB.buildSplatVector(InterTy, And);
759   auto ZeroSplat =
760       MIB.buildSplatVector(InterTy, MIB.buildConstant(InterEltTy, 0));
761   MIB.buildICmp(CmpInst::Predicate::ICMP_NE, Dst, LHS, ZeroSplat);
762   MI.eraseFromParent();
763   return true;
764 }
765 
766 bool RISCVLegalizerInfo::legalizeCustom(
767     LegalizerHelper &Helper, MachineInstr &MI,
768     LostDebugLocObserver &LocObserver) const {
769   MachineIRBuilder &MIRBuilder = Helper.MIRBuilder;
770   GISelChangeObserver &Observer = Helper.Observer;
771   MachineFunction &MF = *MI.getParent()->getParent();
772   switch (MI.getOpcode()) {
773   default:
774     // No idea what to do.
775     return false;
776   case TargetOpcode::G_ABS:
777     return Helper.lowerAbsToMaxNeg(MI);
778   // TODO: G_FCONSTANT
779   case TargetOpcode::G_CONSTANT: {
780     const Function &F = MF.getFunction();
781     // TODO: if PSI and BFI are present, add " ||
782     // llvm::shouldOptForSize(*CurMBB, PSI, BFI)".
783     bool ShouldOptForSize = F.hasOptSize() || F.hasMinSize();
784     const ConstantInt *ConstVal = MI.getOperand(1).getCImm();
785     if (!shouldBeInConstantPool(ConstVal->getValue(), ShouldOptForSize))
786       return true;
787     return Helper.lowerConstant(MI);
788   }
789   case TargetOpcode::G_SHL:
790   case TargetOpcode::G_ASHR:
791   case TargetOpcode::G_LSHR:
792     return legalizeShlAshrLshr(MI, MIRBuilder, Observer);
793   case TargetOpcode::G_SEXT_INREG: {
794     // Source size of 32 is sext.w.
795     int64_t SizeInBits = MI.getOperand(2).getImm();
796     if (SizeInBits == 32)
797       return true;
798 
799     return Helper.lower(MI, 0, /* Unused hint type */ LLT()) ==
800            LegalizerHelper::Legalized;
801   }
802   case TargetOpcode::G_IS_FPCLASS: {
803     Register GISFPCLASS = MI.getOperand(0).getReg();
804     Register Src = MI.getOperand(1).getReg();
805     const MachineOperand &ImmOp = MI.getOperand(2);
806     MachineIRBuilder MIB(MI);
807 
808     // Turn LLVM IR's floating point classes to that in RISC-V,
809     // by simply rotating the 10-bit immediate right by two bits.
810     APInt GFpClassImm(10, static_cast<uint64_t>(ImmOp.getImm()));
811     auto FClassMask = MIB.buildConstant(sXLen, GFpClassImm.rotr(2).zext(XLen));
812     auto ConstZero = MIB.buildConstant(sXLen, 0);
813 
814     auto GFClass = MIB.buildInstr(RISCV::G_FCLASS, {sXLen}, {Src});
815     auto And = MIB.buildAnd(sXLen, GFClass, FClassMask);
816     MIB.buildICmp(CmpInst::ICMP_NE, GISFPCLASS, And, ConstZero);
817 
818     MI.eraseFromParent();
819     return true;
820   }
821   case TargetOpcode::G_VASTART:
822     return legalizeVAStart(MI, MIRBuilder);
823   case TargetOpcode::G_VSCALE:
824     return legalizeVScale(MI, MIRBuilder);
825   case TargetOpcode::G_ZEXT:
826   case TargetOpcode::G_SEXT:
827   case TargetOpcode::G_ANYEXT:
828     return legalizeExt(MI, MIRBuilder);
829   case TargetOpcode::G_SPLAT_VECTOR:
830     return legalizeSplatVector(MI, MIRBuilder);
831   }
832 
833   llvm_unreachable("expected switch to return");
834 }
835