xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp (revision f5f40dd63bc7acbb5312b26ac1ea1103c12352a6)
1 //===-- RISCVRegisterBankInfo.cpp -------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 /// This file implements the targeting of the RegisterBankInfo class for RISC-V.
10 /// \todo This should be generated by TableGen.
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVRegisterBankInfo.h"
14 #include "MCTargetDesc/RISCVMCTargetDesc.h"
15 #include "RISCVSubtarget.h"
16 #include "llvm/CodeGen/MachineRegisterInfo.h"
17 #include "llvm/CodeGen/RegisterBank.h"
18 #include "llvm/CodeGen/RegisterBankInfo.h"
19 #include "llvm/CodeGen/TargetRegisterInfo.h"
20 
21 #define GET_TARGET_REGBANK_IMPL
22 #include "RISCVGenRegisterBank.inc"
23 
24 namespace llvm {
25 namespace RISCV {
26 
27 const RegisterBankInfo::PartialMapping PartMappings[] = {
28     {0, 32, GPRBRegBank},
29     {0, 64, GPRBRegBank},
30     {0, 32, FPRBRegBank},
31     {0, 64, FPRBRegBank},
32 };
33 
34 enum PartialMappingIdx {
35   PMI_GPRB32 = 0,
36   PMI_GPRB64 = 1,
37   PMI_FPRB32 = 2,
38   PMI_FPRB64 = 3,
39 };
40 
41 const RegisterBankInfo::ValueMapping ValueMappings[] = {
42     // Invalid value mapping.
43     {nullptr, 0},
44     // Maximum 3 GPR operands; 32 bit.
45     {&PartMappings[PMI_GPRB32], 1},
46     {&PartMappings[PMI_GPRB32], 1},
47     {&PartMappings[PMI_GPRB32], 1},
48     // Maximum 3 GPR operands; 64 bit.
49     {&PartMappings[PMI_GPRB64], 1},
50     {&PartMappings[PMI_GPRB64], 1},
51     {&PartMappings[PMI_GPRB64], 1},
52     // Maximum 3 FPR operands; 32 bit.
53     {&PartMappings[PMI_FPRB32], 1},
54     {&PartMappings[PMI_FPRB32], 1},
55     {&PartMappings[PMI_FPRB32], 1},
56     // Maximum 3 FPR operands; 64 bit.
57     {&PartMappings[PMI_FPRB64], 1},
58     {&PartMappings[PMI_FPRB64], 1},
59     {&PartMappings[PMI_FPRB64], 1},
60 };
61 
62 enum ValueMappingIdx {
63   InvalidIdx = 0,
64   GPRB32Idx = 1,
65   GPRB64Idx = 4,
66   FPRB32Idx = 7,
67   FPRB64Idx = 10,
68 };
69 } // namespace RISCV
70 } // namespace llvm
71 
72 using namespace llvm;
73 
74 RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
75     : RISCVGenRegisterBankInfo(HwMode) {}
76 
77 const RegisterBank &
78 RISCVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
79                                               LLT Ty) const {
80   switch (RC.getID()) {
81   default:
82     llvm_unreachable("Register class not supported");
83   case RISCV::GPRRegClassID:
84   case RISCV::GPRF16RegClassID:
85   case RISCV::GPRF32RegClassID:
86   case RISCV::GPRNoX0RegClassID:
87   case RISCV::GPRNoX0X2RegClassID:
88   case RISCV::GPRJALRRegClassID:
89   case RISCV::GPRTCRegClassID:
90   case RISCV::GPRC_and_GPRTCRegClassID:
91   case RISCV::GPRCRegClassID:
92   case RISCV::GPRC_and_SR07RegClassID:
93   case RISCV::SR07RegClassID:
94   case RISCV::SPRegClassID:
95   case RISCV::GPRX0RegClassID:
96     return getRegBank(RISCV::GPRBRegBankID);
97   case RISCV::FPR64RegClassID:
98   case RISCV::FPR16RegClassID:
99   case RISCV::FPR32RegClassID:
100   case RISCV::FPR64CRegClassID:
101   case RISCV::FPR32CRegClassID:
102     return getRegBank(RISCV::FPRBRegBankID);
103   case RISCV::VMRegClassID:
104   case RISCV::VRRegClassID:
105   case RISCV::VRNoV0RegClassID:
106   case RISCV::VRM2RegClassID:
107   case RISCV::VRM2NoV0RegClassID:
108   case RISCV::VRM4RegClassID:
109   case RISCV::VRM4NoV0RegClassID:
110   case RISCV::VMV0RegClassID:
111   case RISCV::VRM2_with_sub_vrm1_0_in_VMV0RegClassID:
112   case RISCV::VRM4_with_sub_vrm1_0_in_VMV0RegClassID:
113   case RISCV::VRM8RegClassID:
114   case RISCV::VRM8NoV0RegClassID:
115   case RISCV::VRM8_with_sub_vrm1_0_in_VMV0RegClassID:
116     return getRegBank(RISCV::VRBRegBankID);
117   }
118 }
119 
120 static const RegisterBankInfo::ValueMapping *getFPValueMapping(unsigned Size) {
121   assert(Size == 32 || Size == 64);
122   unsigned Idx = Size == 64 ? RISCV::FPRB64Idx : RISCV::FPRB32Idx;
123   return &RISCV::ValueMappings[Idx];
124 }
125 
126 /// Returns whether opcode \p Opc is a pre-isel generic floating-point opcode,
127 /// having only floating-point operands.
128 /// FIXME: this is copied from target AArch64. Needs some code refactor here to
129 /// put this function in GlobalISel/Utils.cpp.
130 static bool isPreISelGenericFloatingPointOpcode(unsigned Opc) {
131   switch (Opc) {
132   case TargetOpcode::G_FADD:
133   case TargetOpcode::G_FSUB:
134   case TargetOpcode::G_FMUL:
135   case TargetOpcode::G_FMA:
136   case TargetOpcode::G_FDIV:
137   case TargetOpcode::G_FCONSTANT:
138   case TargetOpcode::G_FPEXT:
139   case TargetOpcode::G_FPTRUNC:
140   case TargetOpcode::G_FCEIL:
141   case TargetOpcode::G_FFLOOR:
142   case TargetOpcode::G_FNEARBYINT:
143   case TargetOpcode::G_FNEG:
144   case TargetOpcode::G_FCOPYSIGN:
145   case TargetOpcode::G_FCOS:
146   case TargetOpcode::G_FSIN:
147   case TargetOpcode::G_FLOG10:
148   case TargetOpcode::G_FLOG:
149   case TargetOpcode::G_FLOG2:
150   case TargetOpcode::G_FSQRT:
151   case TargetOpcode::G_FABS:
152   case TargetOpcode::G_FEXP:
153   case TargetOpcode::G_FRINT:
154   case TargetOpcode::G_INTRINSIC_TRUNC:
155   case TargetOpcode::G_INTRINSIC_ROUND:
156   case TargetOpcode::G_INTRINSIC_ROUNDEVEN:
157   case TargetOpcode::G_FMAXNUM:
158   case TargetOpcode::G_FMINNUM:
159   case TargetOpcode::G_FMAXIMUM:
160   case TargetOpcode::G_FMINIMUM:
161     return true;
162   }
163   return false;
164 }
165 
166 // TODO: Make this more like AArch64?
167 bool RISCVRegisterBankInfo::hasFPConstraints(
168     const MachineInstr &MI, const MachineRegisterInfo &MRI,
169     const TargetRegisterInfo &TRI) const {
170   if (isPreISelGenericFloatingPointOpcode(MI.getOpcode()))
171     return true;
172 
173   // If we have a copy instruction, we could be feeding floating point
174   // instructions.
175   if (MI.getOpcode() != TargetOpcode::COPY)
176     return false;
177 
178   return getRegBank(MI.getOperand(0).getReg(), MRI, TRI) == &RISCV::FPRBRegBank;
179 }
180 
181 bool RISCVRegisterBankInfo::onlyUsesFP(const MachineInstr &MI,
182                                        const MachineRegisterInfo &MRI,
183                                        const TargetRegisterInfo &TRI) const {
184   switch (MI.getOpcode()) {
185   case TargetOpcode::G_FPTOSI:
186   case TargetOpcode::G_FPTOUI:
187   case TargetOpcode::G_FCMP:
188     return true;
189   default:
190     break;
191   }
192 
193   return hasFPConstraints(MI, MRI, TRI);
194 }
195 
196 bool RISCVRegisterBankInfo::onlyDefinesFP(const MachineInstr &MI,
197                                           const MachineRegisterInfo &MRI,
198                                           const TargetRegisterInfo &TRI) const {
199   switch (MI.getOpcode()) {
200   case TargetOpcode::G_SITOFP:
201   case TargetOpcode::G_UITOFP:
202     return true;
203   default:
204     break;
205   }
206 
207   return hasFPConstraints(MI, MRI, TRI);
208 }
209 
210 bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
211     Register Def, const MachineRegisterInfo &MRI,
212     const TargetRegisterInfo &TRI) const {
213   return any_of(
214       MRI.use_nodbg_instructions(Def),
215       [&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
216 }
217 
218 const RegisterBankInfo::InstructionMapping &
219 RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
220   const unsigned Opc = MI.getOpcode();
221 
222   // Try the default logic for non-generic instructions that are either copies
223   // or already have some operands assigned to banks.
224   if (!isPreISelGenericOpcode(Opc) || Opc == TargetOpcode::G_PHI) {
225     const InstructionMapping &Mapping = getInstrMappingImpl(MI);
226     if (Mapping.isValid())
227       return Mapping;
228   }
229 
230   const MachineFunction &MF = *MI.getParent()->getParent();
231   const MachineRegisterInfo &MRI = MF.getRegInfo();
232   const TargetSubtargetInfo &STI = MF.getSubtarget();
233   const TargetRegisterInfo &TRI = *STI.getRegisterInfo();
234 
235   unsigned GPRSize = getMaximumSize(RISCV::GPRBRegBankID);
236   assert((GPRSize == 32 || GPRSize == 64) && "Unexpected GPR size");
237 
238   unsigned NumOperands = MI.getNumOperands();
239   const ValueMapping *GPRValueMapping =
240       &RISCV::ValueMappings[GPRSize == 64 ? RISCV::GPRB64Idx
241                                           : RISCV::GPRB32Idx];
242 
243   switch (Opc) {
244   case TargetOpcode::G_ADD:
245   case TargetOpcode::G_SUB:
246   case TargetOpcode::G_SHL:
247   case TargetOpcode::G_ASHR:
248   case TargetOpcode::G_LSHR:
249   case TargetOpcode::G_AND:
250   case TargetOpcode::G_OR:
251   case TargetOpcode::G_XOR:
252   case TargetOpcode::G_MUL:
253   case TargetOpcode::G_SDIV:
254   case TargetOpcode::G_SREM:
255   case TargetOpcode::G_SMULH:
256   case TargetOpcode::G_SMAX:
257   case TargetOpcode::G_SMIN:
258   case TargetOpcode::G_UDIV:
259   case TargetOpcode::G_UREM:
260   case TargetOpcode::G_UMULH:
261   case TargetOpcode::G_UMAX:
262   case TargetOpcode::G_UMIN:
263   case TargetOpcode::G_PTR_ADD:
264   case TargetOpcode::G_PTRTOINT:
265   case TargetOpcode::G_INTTOPTR:
266   case TargetOpcode::G_TRUNC:
267   case TargetOpcode::G_ANYEXT:
268   case TargetOpcode::G_SEXT:
269   case TargetOpcode::G_ZEXT:
270   case TargetOpcode::G_SEXTLOAD:
271   case TargetOpcode::G_ZEXTLOAD:
272     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, GPRValueMapping,
273                                  NumOperands);
274   case TargetOpcode::G_FADD:
275   case TargetOpcode::G_FSUB:
276   case TargetOpcode::G_FMUL:
277   case TargetOpcode::G_FDIV:
278   case TargetOpcode::G_FABS:
279   case TargetOpcode::G_FNEG:
280   case TargetOpcode::G_FSQRT:
281   case TargetOpcode::G_FMAXNUM:
282   case TargetOpcode::G_FMINNUM: {
283     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
284     return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
285                                  getFPValueMapping(Ty.getSizeInBits()),
286                                  NumOperands);
287   }
288   case TargetOpcode::G_IMPLICIT_DEF: {
289     Register Dst = MI.getOperand(0).getReg();
290     auto Mapping = GPRValueMapping;
291     // FIXME: May need to do a better job determining when to use FPRB.
292     // For example, the look through COPY case:
293     // %0:_(s32) = G_IMPLICIT_DEF
294     // %1:_(s32) = COPY %0
295     // $f10_d = COPY %1(s32)
296     if (anyUseOnlyUseFP(Dst, MRI, TRI))
297       Mapping = getFPValueMapping(MRI.getType(Dst).getSizeInBits());
298     return getInstructionMapping(DefaultMappingID, /*Cost=*/1, Mapping,
299                                  NumOperands);
300   }
301   }
302 
303   SmallVector<const ValueMapping *, 4> OpdsMapping(NumOperands);
304 
305   switch (Opc) {
306   case TargetOpcode::G_LOAD: {
307     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
308     OpdsMapping[0] = GPRValueMapping;
309     OpdsMapping[1] = GPRValueMapping;
310     // Use FPR64 for s64 loads on rv32.
311     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
312       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
313       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
314       break;
315     }
316 
317     // Check if that load feeds fp instructions.
318     // In that case, we want the default mapping to be on FPR
319     // instead of blind map every scalar to GPR.
320     if (anyUseOnlyUseFP(MI.getOperand(0).getReg(), MRI, TRI))
321       // If we have at least one direct use in a FP instruction,
322       // assume this was a floating point load in the IR. If it was
323       // not, we would have had a bitcast before reaching that
324       // instruction.
325       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
326 
327     break;
328   }
329   case TargetOpcode::G_STORE: {
330     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
331     OpdsMapping[0] = GPRValueMapping;
332     OpdsMapping[1] = GPRValueMapping;
333     // Use FPR64 for s64 stores on rv32.
334     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
335       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
336       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
337       break;
338     }
339 
340     MachineInstr *DefMI = MRI.getVRegDef(MI.getOperand(0).getReg());
341     if (onlyDefinesFP(*DefMI, MRI, TRI))
342       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
343     break;
344   }
345   case TargetOpcode::G_SELECT: {
346     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
347 
348     // Try to minimize the number of copies. If we have more floating point
349     // constrained values than not, then we'll put everything on FPR. Otherwise,
350     // everything has to be on GPR.
351     unsigned NumFP = 0;
352 
353     // Use FPR64 for s64 select on rv32.
354     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
355       NumFP = 3;
356     } else {
357       // Check if the uses of the result always produce floating point values.
358       //
359       // For example:
360       //
361       // %z = G_SELECT %cond %x %y
362       // fpr = G_FOO %z ...
363       if (any_of(MRI.use_nodbg_instructions(MI.getOperand(0).getReg()),
364                  [&](const MachineInstr &UseMI) {
365                    return onlyUsesFP(UseMI, MRI, TRI);
366                  }))
367         ++NumFP;
368 
369       // Check if the defs of the source values always produce floating point
370       // values.
371       //
372       // For example:
373       //
374       // %x = G_SOMETHING_ALWAYS_FLOAT %a ...
375       // %z = G_SELECT %cond %x %y
376       //
377       // Also check whether or not the sources have already been decided to be
378       // FPR. Keep track of this.
379       //
380       // This doesn't check the condition, since the condition is always an
381       // integer.
382       for (unsigned Idx = 2; Idx < 4; ++Idx) {
383         Register VReg = MI.getOperand(Idx).getReg();
384         MachineInstr *DefMI = MRI.getVRegDef(VReg);
385         if (getRegBank(VReg, MRI, TRI) == &RISCV::FPRBRegBank ||
386             onlyDefinesFP(*DefMI, MRI, TRI))
387           ++NumFP;
388       }
389     }
390 
391     // Condition operand is always GPR.
392     OpdsMapping[1] = GPRValueMapping;
393 
394     const ValueMapping *Mapping = GPRValueMapping;
395     if (NumFP >= 2)
396       Mapping = getFPValueMapping(Ty.getSizeInBits());
397 
398     OpdsMapping[0] = OpdsMapping[2] = OpdsMapping[3] = Mapping;
399     break;
400   }
401   case TargetOpcode::G_FPTOSI:
402   case TargetOpcode::G_FPTOUI:
403   case RISCV::G_FCLASS: {
404     LLT Ty = MRI.getType(MI.getOperand(1).getReg());
405     OpdsMapping[0] = GPRValueMapping;
406     OpdsMapping[1] = getFPValueMapping(Ty.getSizeInBits());
407     break;
408   }
409   case TargetOpcode::G_SITOFP:
410   case TargetOpcode::G_UITOFP: {
411     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
412     OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
413     OpdsMapping[1] = GPRValueMapping;
414     break;
415   }
416   case TargetOpcode::G_FCMP: {
417     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
418 
419     unsigned Size = Ty.getSizeInBits();
420     assert((Size == 32 || Size == 64) && "Unsupported size for G_FCMP");
421 
422     OpdsMapping[0] = GPRValueMapping;
423     OpdsMapping[2] = OpdsMapping[3] = getFPValueMapping(Size);
424     break;
425   }
426   case TargetOpcode::G_MERGE_VALUES: {
427     // Use FPR64 for s64 merge on rv32.
428     LLT Ty = MRI.getType(MI.getOperand(0).getReg());
429     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
430       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
431       OpdsMapping[0] = getFPValueMapping(Ty.getSizeInBits());
432       OpdsMapping[1] = GPRValueMapping;
433       OpdsMapping[2] = GPRValueMapping;
434     }
435     break;
436   }
437   case TargetOpcode::G_UNMERGE_VALUES: {
438     // Use FPR64 for s64 unmerge on rv32.
439     LLT Ty = MRI.getType(MI.getOperand(2).getReg());
440     if (GPRSize == 32 && Ty.getSizeInBits() == 64) {
441       assert(MF.getSubtarget<RISCVSubtarget>().hasStdExtD());
442       OpdsMapping[0] = GPRValueMapping;
443       OpdsMapping[1] = GPRValueMapping;
444       OpdsMapping[2] = getFPValueMapping(Ty.getSizeInBits());
445     }
446     break;
447   }
448   default:
449     // By default map all scalars to GPR.
450     for (unsigned Idx = 0; Idx < NumOperands; ++Idx) {
451        auto &MO = MI.getOperand(Idx);
452        if (!MO.isReg() || !MO.getReg())
453          continue;
454        LLT Ty = MRI.getType(MO.getReg());
455        if (!Ty.isValid())
456          continue;
457 
458        if (isPreISelGenericFloatingPointOpcode(Opc))
459          OpdsMapping[Idx] = getFPValueMapping(Ty.getSizeInBits());
460        else
461          OpdsMapping[Idx] = GPRValueMapping;
462     }
463     break;
464   }
465 
466   return getInstructionMapping(DefaultMappingID, /*Cost=*/1,
467                                getOperandsMapping(OpdsMapping), NumOperands);
468 }
469