xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===-- RISCVInstrInfo.cpp - RISC-V Instruction Information -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the RISC-V implementation of the TargetInstrInfo class.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "RISCVInstrInfo.h"
14 #include "MCTargetDesc/RISCVBaseInfo.h"
15 #include "MCTargetDesc/RISCVMatInt.h"
16 #include "RISCV.h"
17 #include "RISCVMachineFunctionInfo.h"
18 #include "RISCVSubtarget.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/Analysis/MemoryLocation.h"
23 #include "llvm/Analysis/ValueTracking.h"
24 #include "llvm/CodeGen/LiveIntervals.h"
25 #include "llvm/CodeGen/LiveVariables.h"
26 #include "llvm/CodeGen/MachineCombinerPattern.h"
27 #include "llvm/CodeGen/MachineInstrBuilder.h"
28 #include "llvm/CodeGen/MachineRegisterInfo.h"
29 #include "llvm/CodeGen/MachineTraceMetrics.h"
30 #include "llvm/CodeGen/RegisterScavenging.h"
31 #include "llvm/CodeGen/StackMaps.h"
32 #include "llvm/IR/DebugInfoMetadata.h"
33 #include "llvm/IR/Module.h"
34 #include "llvm/MC/MCInstBuilder.h"
35 #include "llvm/MC/TargetRegistry.h"
36 #include "llvm/Support/ErrorHandling.h"
37 
38 using namespace llvm;
39 
40 #define GEN_CHECK_COMPRESS_INSTR
41 #include "RISCVGenCompressInstEmitter.inc"
42 
43 #define GET_INSTRINFO_CTOR_DTOR
44 #define GET_INSTRINFO_NAMED_OPS
45 #include "RISCVGenInstrInfo.inc"
46 
// Debug type used by the LLVM_DEBUG/STATISTIC machinery in this file.
47 #define DEBUG_TYPE "riscv-instr-info"
// Spill/reload traffic counters, counted in units of single vector registers
// (a register group of LMUL N / NF segments contributes N counts).
48 STATISTIC(NumVRegSpilled,
49           "Number of registers within vector register groups spilled");
50 STATISTIC(NumVRegReloaded,
51           "Number of registers within vector register groups reloaded");
52 
// Debugging knob: when set, keep vmv<n>r.v whole-register moves instead of
// attempting the COPY -> vmv.v.v conversion (see isConvertibleToVMV_V_V).
53 static cl::opt<bool> PreferWholeRegisterMove(
54     "riscv-prefer-whole-register-move", cl::init(false), cl::Hidden,
55     cl::desc("Prefer whole register move for vector registers."));
56 
// Debugging knob overriding the machine combiner's trace strategy; the
// TS_NumStrategies default presumably means "no override" — confirm at uses.
57 static cl::opt<MachineTraceStrategy> ForceMachineCombinerStrategy(
58     "riscv-force-machine-combiner-strategy", cl::Hidden,
59     cl::desc("Force machine combiner to use a specific strategy for machine "
60              "trace metrics evaluation."),
61     cl::init(MachineTraceStrategy::TS_NumStrategies),
62     cl::values(clEnumValN(MachineTraceStrategy::TS_Local, "local",
63                           "Local strategy."),
64                clEnumValN(MachineTraceStrategy::TS_MinInstrCount, "min-instr",
65                           "MinInstrCount strategy.")));
66 
// TableGen-generated searchable table mapping RVV pseudos to metadata.
67 namespace llvm::RISCVVPseudosTable {
68 
69 using namespace RISCV;
70 
71 #define GET_RISCVVPseudosTable_IMPL
72 #include "RISCVGenSearchableTables.inc"
73 
74 } // namespace llvm::RISCVVPseudosTable
75 
// TableGen-generated table for masked RVV pseudo instructions.
76 namespace llvm::RISCV {
77 
78 #define GET_RISCVMaskedPseudosTable_IMPL
79 #include "RISCVGenSearchableTables.inc"
80 
81 } // end namespace llvm::RISCV
82 
/// Construct the target instruction info, registering the call-frame
/// setup/destroy pseudo opcodes with the TableGen'erated base class and
/// keeping a reference to the subtarget for feature queries.
RISCVInstrInfo(RISCVSubtarget & STI)83 RISCVInstrInfo::RISCVInstrInfo(RISCVSubtarget &STI)
84     : RISCVGenInstrInfo(RISCV::ADJCALLSTACKDOWN, RISCV::ADJCALLSTACKUP),
85       STI(STI) {}
86 
87 #define GET_INSTRINFO_HELPERS
88 #include "RISCVGenInstrInfo.inc"
89 
getNop() const90 MCInst RISCVInstrInfo::getNop() const {
91   if (STI.hasStdExtZca())
92     return MCInstBuilder(RISCV::C_NOP);
93   return MCInstBuilder(RISCV::ADDI)
94       .addReg(RISCV::X0)
95       .addReg(RISCV::X0)
96       .addImm(0);
97 }
98 
isLoadFromStackSlot(const MachineInstr & MI,int & FrameIndex) const99 Register RISCVInstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
100                                              int &FrameIndex) const {
101   TypeSize Dummy = TypeSize::getZero();
102   return isLoadFromStackSlot(MI, FrameIndex, Dummy);
103 }
104 
getLMULForRVVWholeLoadStore(unsigned Opcode)105 static std::optional<unsigned> getLMULForRVVWholeLoadStore(unsigned Opcode) {
106   switch (Opcode) {
107   default:
108     return std::nullopt;
109   case RISCV::VS1R_V:
110   case RISCV::VL1RE8_V:
111   case RISCV::VL1RE16_V:
112   case RISCV::VL1RE32_V:
113   case RISCV::VL1RE64_V:
114     return 1;
115   case RISCV::VS2R_V:
116   case RISCV::VL2RE8_V:
117   case RISCV::VL2RE16_V:
118   case RISCV::VL2RE32_V:
119   case RISCV::VL2RE64_V:
120     return 2;
121   case RISCV::VS4R_V:
122   case RISCV::VL4RE8_V:
123   case RISCV::VL4RE16_V:
124   case RISCV::VL4RE32_V:
125   case RISCV::VL4RE64_V:
126     return 4;
127   case RISCV::VS8R_V:
128   case RISCV::VL8RE8_V:
129   case RISCV::VL8RE16_V:
130   case RISCV::VL8RE32_V:
131   case RISCV::VL8RE64_V:
132     return 8;
133   }
134 }
135 
/// Test whether MI is a direct load from a stack slot. On a match, set
/// FrameIndex to the slot index and MemBytes to the access size (scalable
/// for RVV whole-register loads) and return the destination register;
/// otherwise return 0 (the invalid register).
isLoadFromStackSlot(const MachineInstr & MI,int & FrameIndex,TypeSize & MemBytes) const136 Register RISCVInstrInfo::isLoadFromStackSlot(const MachineInstr &MI,
137                                              int &FrameIndex,
138                                              TypeSize &MemBytes) const {
139   switch (MI.getOpcode()) {
140   default:
141     return 0;
142   case RISCV::LB:
143   case RISCV::LBU:
144     MemBytes = TypeSize::getFixed(1);
145     break;
146   case RISCV::LH:
147   case RISCV::LH_INX:
148   case RISCV::LHU:
149   case RISCV::FLH:
150     MemBytes = TypeSize::getFixed(2);
151     break;
152   case RISCV::LW:
153   case RISCV::LW_INX:
154   case RISCV::FLW:
155   case RISCV::LWU:
156     MemBytes = TypeSize::getFixed(4);
157     break;
158   case RISCV::LD:
159   case RISCV::LD_RV32:
160   case RISCV::FLD:
161     MemBytes = TypeSize::getFixed(8);
162     break;
  // RVV whole-register loads take the frame index directly in operand 1 (no
  // immediate-offset operand), and their size scales with VLEN by the LMUL
  // encoded in the opcode.
163   case RISCV::VL1RE8_V:
164   case RISCV::VL2RE8_V:
165   case RISCV::VL4RE8_V:
166   case RISCV::VL8RE8_V:
167     if (!MI.getOperand(1).isFI())
168       return Register();
169     FrameIndex = MI.getOperand(1).getIndex();
170     unsigned LMUL = *getLMULForRVVWholeLoadStore(MI.getOpcode());
171     MemBytes = TypeSize::getScalable(RISCV::RVVBytesPerBlock * LMUL);
172     return MI.getOperand(0).getReg();
173   }
174 
  // Scalar loads fall through to here: operands are (dst, base, imm). Only a
  // frame-index base with a zero immediate offset qualifies as a slot load.
175   if (MI.getOperand(1).isFI() && MI.getOperand(2).isImm() &&
176       MI.getOperand(2).getImm() == 0) {
177     FrameIndex = MI.getOperand(1).getIndex();
178     return MI.getOperand(0).getReg();
179   }
180 
181   return 0;
182 }
183 
isStoreToStackSlot(const MachineInstr & MI,int & FrameIndex) const184 Register RISCVInstrInfo::isStoreToStackSlot(const MachineInstr &MI,
185                                             int &FrameIndex) const {
186   TypeSize Dummy = TypeSize::getZero();
187   return isStoreToStackSlot(MI, FrameIndex, Dummy);
188 }
189 
/// Test whether MI is a direct store to a stack slot. On a match, set
/// FrameIndex to the slot index and MemBytes to the access size (scalable
/// for RVV whole-register stores) and return the stored register; otherwise
/// return 0 (the invalid register). Mirrors isLoadFromStackSlot above.
isStoreToStackSlot(const MachineInstr & MI,int & FrameIndex,TypeSize & MemBytes) const190 Register RISCVInstrInfo::isStoreToStackSlot(const MachineInstr &MI,
191                                             int &FrameIndex,
192                                             TypeSize &MemBytes) const {
193   switch (MI.getOpcode()) {
194   default:
195     return 0;
196   case RISCV::SB:
197     MemBytes = TypeSize::getFixed(1);
198     break;
199   case RISCV::SH:
200   case RISCV::SH_INX:
201   case RISCV::FSH:
202     MemBytes = TypeSize::getFixed(2);
203     break;
204   case RISCV::SW:
205   case RISCV::SW_INX:
206   case RISCV::FSW:
207     MemBytes = TypeSize::getFixed(4);
208     break;
209   case RISCV::SD:
210   case RISCV::SD_RV32:
211   case RISCV::FSD:
212     MemBytes = TypeSize::getFixed(8);
213     break;
  // RVV whole-register stores take the frame index directly in operand 1 (no
  // immediate-offset operand); the size scales with VLEN by the opcode's LMUL.
214   case RISCV::VS1R_V:
215   case RISCV::VS2R_V:
216   case RISCV::VS4R_V:
217   case RISCV::VS8R_V:
218     if (!MI.getOperand(1).isFI())
219       return Register();
220     FrameIndex = MI.getOperand(1).getIndex();
221     unsigned LMUL = *getLMULForRVVWholeLoadStore(MI.getOpcode());
222     MemBytes = TypeSize::getScalable(RISCV::RVVBytesPerBlock * LMUL);
223     return MI.getOperand(0).getReg();
224   }
225 
  // Scalar stores fall through to here: operands are (src, base, imm). Only a
  // frame-index base with a zero immediate offset qualifies as a slot store.
226   if (MI.getOperand(1).isFI() && MI.getOperand(2).isImm() &&
227       MI.getOperand(2).getImm() == 0) {
228     FrameIndex = MI.getOperand(1).getIndex();
229     return MI.getOperand(0).getReg();
230   }
231 
232   return 0;
233 }
234 
isReallyTriviallyReMaterializable(const MachineInstr & MI) const235 bool RISCVInstrInfo::isReallyTriviallyReMaterializable(
236     const MachineInstr &MI) const {
237   switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
238   case RISCV::VMV_V_X:
239   case RISCV::VFMV_V_F:
240   case RISCV::VMV_V_I:
241   case RISCV::VMV_S_X:
242   case RISCV::VFMV_S_F:
243   case RISCV::VID_V:
244     return MI.getOperand(1).isUndef();
245   default:
246     return TargetInstrInfo::isReallyTriviallyReMaterializable(MI);
247   }
248 }
249 
/// Return true when a forward (low-to-high) copy of a NumRegs-register tuple
/// from SrcReg to DstReg would overwrite part of its own source, i.e. the
/// destination base register lies strictly inside the source tuple.
/// Subtraction (rather than SrcReg + NumRegs) avoids unsigned wraparound.
static bool forwardCopyWillClobberTuple(unsigned DstReg, unsigned SrcReg,
                                        unsigned NumRegs) {
  if (DstReg <= SrcReg)
    return false;
  return (DstReg - SrcReg) < NumRegs;
}
254 
/// Decide whether the whole-register COPY at \p MBBI (with register-class
/// LMUL \p LMul) can be lowered to vmv.v.v instead of vmv<n>r.v. Scans
/// backwards from the COPY toward the instruction defining its source,
/// rejecting the conversion if an intervening vsetvli (other than a
/// VL-preserving one), call, inline asm, or VL-modifying instruction makes
/// the element-wise move unsafe. On success sets \p DefMBBI to the defining
/// instruction and returns true.
isConvertibleToVMV_V_V(const RISCVSubtarget & STI,const MachineBasicBlock & MBB,MachineBasicBlock::const_iterator MBBI,MachineBasicBlock::const_iterator & DefMBBI,RISCVVType::VLMUL LMul)255 static bool isConvertibleToVMV_V_V(const RISCVSubtarget &STI,
256                                    const MachineBasicBlock &MBB,
257                                    MachineBasicBlock::const_iterator MBBI,
258                                    MachineBasicBlock::const_iterator &DefMBBI,
259                                    RISCVVType::VLMUL LMul) {
260   if (PreferWholeRegisterMove)
261     return false;
262 
263   assert(MBBI->getOpcode() == TargetOpcode::COPY &&
264          "Unexpected COPY instruction.");
265   Register SrcReg = MBBI->getOperand(1).getReg();
266   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
267 
268   bool FoundDef = false;
269   bool FirstVSetVLI = false;
270   unsigned FirstSEW = 0;
  // Walk backwards from the COPY within this basic block only.
271   while (MBBI != MBB.begin()) {
272     --MBBI;
273     if (MBBI->isMetaInstruction())
274       continue;
275 
276     if (RISCVInstrInfo::isVectorConfigInstr(*MBBI)) {
277       // There is a vsetvli between COPY and source define instruction.
278       // vy = def_vop ...  (producing instruction)
279       // ...
280       // vsetvli
281       // ...
282       // vx = COPY vy
283       if (!FoundDef) {
284         if (!FirstVSetVLI) {
285           FirstVSetVLI = true;
286           unsigned FirstVType = MBBI->getOperand(2).getImm();
287           RISCVVType::VLMUL FirstLMul = RISCVVType::getVLMUL(FirstVType);
288           FirstSEW = RISCVVType::getSEW(FirstVType);
289           // The first encountered vsetvli must have the same lmul as the
290           // register class of COPY.
291           if (FirstLMul != LMul)
292             return false;
293         }
294         // Only permit `vsetvli x0, x0, vtype` between COPY and the source
295         // define instruction.
296         if (!RISCVInstrInfo::isVLPreservingConfig(*MBBI))
297           return false;
298         continue;
299       }
300 
301       // MBBI is the first vsetvli before the producing instruction.
302       unsigned VType = MBBI->getOperand(2).getImm();
303       // If there is a vsetvli between COPY and the producing instruction.
304       if (FirstVSetVLI) {
305         // If SEW is different, return false.
306         if (RISCVVType::getSEW(VType) != FirstSEW)
307           return false;
308       }
309 
310       // If the vsetvli is tail undisturbed, keep the whole register move.
311       if (!RISCVVType::isTailAgnostic(VType))
312         return false;
313 
314       // The checking is conservative. We only have register classes for
315       // LMUL = 1/2/4/8. We should be able to convert vmv1r.v to vmv.v.v
316       // for fractional LMUL operations. However, we could not use the vsetvli
317       // lmul for widening operations. The result of widening operation is
318       // 2 x LMUL.
319       return LMul == RISCVVType::getVLMUL(VType);
320     } else if (MBBI->isInlineAsm() || MBBI->isCall()) {
321       return false;
322     } else if (MBBI->getNumDefs()) {
323       // Check all the instructions which will change VL.
324       // For example, vleff has implicit def VL.
325       if (MBBI->modifiesRegister(RISCV::VL, /*TRI=*/nullptr))
326         return false;
327 
328       // Only converting whole register copies to vmv.v.v when the defining
329       // value appears in the explicit operands.
330       for (const MachineOperand &MO : MBBI->explicit_operands()) {
331         if (!MO.isReg() || !MO.isDef())
332           continue;
333         if (!FoundDef && TRI->regsOverlap(MO.getReg(), SrcReg)) {
334           // We only permit the source of COPY has the same LMUL as the defined
335           // operand.
336           // There are cases we need to keep the whole register copy if the LMUL
337           // is different.
338           // For example,
339           // $x0 = PseudoVSETIVLI 4, 73   // vsetivli zero, 4, e16,m2,ta,m
340           // $v28m4 = PseudoVWADD_VV_M2 $v26m2, $v8m2
341           // # The COPY may be created by vlmul_trunc intrinsic.
342           // $v26m2 = COPY renamable $v28m2, implicit killed $v28m4
343           //
344           // After widening, the valid value will be 4 x e32 elements. If we
345           // convert the COPY to vmv.v.v, it will only copy 4 x e16 elements.
346           // FIXME: The COPY of subregister of Zvlsseg register will not be able
347           // to convert to vmv.v.[v|i] under the constraint.
348           if (MO.getReg() != SrcReg)
349             return false;
350 
351           // In widening reduction instructions with LMUL_1 input vector case,
352           // only checking the LMUL is insufficient due to reduction result is
353           // always LMUL_1.
354           // For example,
355           // $x11 = PseudoVSETIVLI 1, 64 // vsetivli a1, 1, e8, m1, ta, mu
356           // $v8m1 = PseudoVWREDSUM_VS_M1 $v26, $v27
357           // $v26 = COPY killed renamable $v8
358           // After widening, The valid value will be 1 x e16 elements. If we
359           // convert the COPY to vmv.v.v, it will only copy 1 x e8 elements.
360           uint64_t TSFlags = MBBI->getDesc().TSFlags;
361           if (RISCVII::isRVVWideningReduction(TSFlags))
362             return false;
363 
364           // If the producing instruction does not depend on vsetvli, do not
365           // convert COPY to vmv.v.v. For example, VL1R_V or PseudoVRELOAD.
366           if (!RISCVII::hasSEWOp(TSFlags) || !RISCVII::hasVLOp(TSFlags))
367             return false;
368 
369           // Found the definition.
370           FoundDef = true;
371           DefMBBI = MBBI;
372           break;
373         }
374       }
375     }
376   }
377 
  // No dominating vsetvli was found in this block; be conservative and keep
  // the whole-register move.
378   return false;
379 }
380 
/// Copy a (possibly segmented) vector register group SrcReg -> DstReg of
/// class \p RegClass. The group (NF segments x LMUL registers) is split into
/// the largest aligned chunks (LMUL 8/4/2/1); when source and destination
/// encodings overlap, the chunks are emitted from the high end downward to
/// avoid clobbering not-yet-copied source registers. Each chunk uses
/// vmv<n>r.v, or vmv.v.v / vmv.v.i when isConvertibleToVMV_V_V proves the
/// element-wise form safe.
copyPhysRegVector(MachineBasicBlock & MBB,MachineBasicBlock::iterator MBBI,const DebugLoc & DL,MCRegister DstReg,MCRegister SrcReg,bool KillSrc,const TargetRegisterClass * RegClass) const381 void RISCVInstrInfo::copyPhysRegVector(
382     MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
383     const DebugLoc &DL, MCRegister DstReg, MCRegister SrcReg, bool KillSrc,
384     const TargetRegisterClass *RegClass) const {
385   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
386   RISCVVType::VLMUL LMul = RISCVRI::getLMul(RegClass->TSFlags);
387   unsigned NF = RISCVRI::getNF(RegClass->TSFlags);
388 
389   uint16_t SrcEncoding = TRI->getEncodingValue(SrcReg);
390   uint16_t DstEncoding = TRI->getEncodingValue(DstReg);
391   auto [LMulVal, Fractional] = RISCVVType::decodeVLMUL(LMul);
392   assert(!Fractional && "It is impossible be fractional lmul here.");
393   unsigned NumRegs = NF * LMulVal;
394   bool ReversedCopy =
395       forwardCopyWillClobberTuple(DstEncoding, SrcEncoding, NumRegs);
396   if (ReversedCopy) {
397     // If the src and dest overlap when copying a tuple, we need to copy the
398     // registers in reverse.
399     SrcEncoding += NumRegs - 1;
400     DstEncoding += NumRegs - 1;
401   }
402 
403   unsigned I = 0;
  // Pick the largest chunk (LMUL, register class, whole-register opcode and
  // the matching vmv.v.v / vmv.v.i pseudos) legal at the current encodings.
404   auto GetCopyInfo = [&](uint16_t SrcEncoding, uint16_t DstEncoding)
405       -> std::tuple<RISCVVType::VLMUL, const TargetRegisterClass &, unsigned,
406                     unsigned, unsigned> {
407     if (ReversedCopy) {
408       // For reversed copying, if there are enough aligned registers(8/4/2), we
409       // can do a larger copy(LMUL8/4/2).
410       // Besides, we have already known that DstEncoding is larger than
411       // SrcEncoding in forwardCopyWillClobberTuple, so the difference between
412       // DstEncoding and SrcEncoding should be >= LMUL value we try to use to
413       // avoid clobbering.
414       uint16_t Diff = DstEncoding - SrcEncoding;
415       if (I + 8 <= NumRegs && Diff >= 8 && SrcEncoding % 8 == 7 &&
416           DstEncoding % 8 == 7)
417         return {RISCVVType::LMUL_8, RISCV::VRM8RegClass, RISCV::VMV8R_V,
418                 RISCV::PseudoVMV_V_V_M8, RISCV::PseudoVMV_V_I_M8};
419       if (I + 4 <= NumRegs && Diff >= 4 && SrcEncoding % 4 == 3 &&
420           DstEncoding % 4 == 3)
421         return {RISCVVType::LMUL_4, RISCV::VRM4RegClass, RISCV::VMV4R_V,
422                 RISCV::PseudoVMV_V_V_M4, RISCV::PseudoVMV_V_I_M4};
423       if (I + 2 <= NumRegs && Diff >= 2 && SrcEncoding % 2 == 1 &&
424           DstEncoding % 2 == 1)
425         return {RISCVVType::LMUL_2, RISCV::VRM2RegClass, RISCV::VMV2R_V,
426                 RISCV::PseudoVMV_V_V_M2, RISCV::PseudoVMV_V_I_M2};
427       // Or we should do LMUL1 copying.
428       return {RISCVVType::LMUL_1, RISCV::VRRegClass, RISCV::VMV1R_V,
429               RISCV::PseudoVMV_V_V_M1, RISCV::PseudoVMV_V_I_M1};
430     }
431 
432     // For forward copying, if source register encoding and destination register
433     // encoding are aligned to 8/4/2, we can do a LMUL8/4/2 copying.
434     if (I + 8 <= NumRegs && SrcEncoding % 8 == 0 && DstEncoding % 8 == 0)
435       return {RISCVVType::LMUL_8, RISCV::VRM8RegClass, RISCV::VMV8R_V,
436               RISCV::PseudoVMV_V_V_M8, RISCV::PseudoVMV_V_I_M8};
437     if (I + 4 <= NumRegs && SrcEncoding % 4 == 0 && DstEncoding % 4 == 0)
438       return {RISCVVType::LMUL_4, RISCV::VRM4RegClass, RISCV::VMV4R_V,
439               RISCV::PseudoVMV_V_V_M4, RISCV::PseudoVMV_V_I_M4};
440     if (I + 2 <= NumRegs && SrcEncoding % 2 == 0 && DstEncoding % 2 == 0)
441       return {RISCVVType::LMUL_2, RISCV::VRM2RegClass, RISCV::VMV2R_V,
442               RISCV::PseudoVMV_V_V_M2, RISCV::PseudoVMV_V_I_M2};
443     // Or we should do LMUL1 copying.
444     return {RISCVVType::LMUL_1, RISCV::VRRegClass, RISCV::VMV1R_V,
445             RISCV::PseudoVMV_V_V_M1, RISCV::PseudoVMV_V_I_M1};
446   };
  // Translate a raw encoding back to the register of the chunk's class
  // (V0-based for LMUL_1, otherwise the covering group register).
447   auto FindRegWithEncoding = [TRI](const TargetRegisterClass &RegClass,
448                                    uint16_t Encoding) {
449     MCRegister Reg = RISCV::V0 + Encoding;
450     if (RISCVRI::getLMul(RegClass.TSFlags) == RISCVVType::LMUL_1)
451       return Reg;
452     return TRI->getMatchingSuperReg(Reg, RISCV::sub_vrm1_0, &RegClass);
453   };
454   while (I != NumRegs) {
455     // For non-segment copying, we only do this once as the registers are always
456     // aligned.
457     // For segment copying, we may do this several times. If the registers are
458     // aligned to larger LMUL, we can eliminate some copyings.
459     auto [LMulCopied, RegClass, Opc, VVOpc, VIOpc] =
460         GetCopyInfo(SrcEncoding, DstEncoding);
461     auto [NumCopied, _] = RISCVVType::decodeVLMUL(LMulCopied);
462 
463     MachineBasicBlock::const_iterator DefMBBI;
464     if (LMul == LMulCopied &&
465         isConvertibleToVMV_V_V(STI, MBB, MBBI, DefMBBI, LMul)) {
466       Opc = VVOpc;
467       if (DefMBBI->getOpcode() == VIOpc)
468         Opc = VIOpc;
469     }
470 
471     // Emit actual copying.
472     // For reversed copying, the encoding should be decreased.
473     MCRegister ActualSrcReg = FindRegWithEncoding(
474         RegClass, ReversedCopy ? (SrcEncoding - NumCopied + 1) : SrcEncoding);
475     MCRegister ActualDstReg = FindRegWithEncoding(
476         RegClass, ReversedCopy ? (DstEncoding - NumCopied + 1) : DstEncoding);
477 
478     auto MIB = BuildMI(MBB, MBBI, DL, get(Opc), ActualDstReg);
479     bool UseVMV_V_I = RISCV::getRVVMCOpcode(Opc) == RISCV::VMV_V_I;
480     bool UseVMV = UseVMV_V_I || RISCV::getRVVMCOpcode(Opc) == RISCV::VMV_V_V;
481     if (UseVMV)
482       MIB.addReg(ActualDstReg, RegState::Undef);
483     if (UseVMV_V_I)
484       MIB = MIB.add(DefMBBI->getOperand(2));
485     else
486       MIB = MIB.addReg(ActualSrcReg, getKillRegState(KillSrc));
487     if (UseVMV) {
      // vmv.v.v / vmv.v.i pseudos need the AVL/SEW/policy operands copied
      // from the defining instruction, plus implicit VL/VTYPE reads.
488       const MCInstrDesc &Desc = DefMBBI->getDesc();
489       MIB.add(DefMBBI->getOperand(RISCVII::getVLOpNum(Desc)));  // AVL
490       unsigned Log2SEW =
491           DefMBBI->getOperand(RISCVII::getSEWOpNum(Desc)).getImm();
492       MIB.addImm(Log2SEW ? Log2SEW : 3);                        // SEW
493       MIB.addImm(0);                                            // tu, mu
494       MIB.addReg(RISCV::VL, RegState::Implicit);
495       MIB.addReg(RISCV::VTYPE, RegState::Implicit);
496     }
497     // Add an implicit read of the original source to silence the verifier
498     // in the cases where some of the smaller VRs we're copying from might be
499     // undef, caused by the fact that the original, larger source VR might not
500     // be fully initialized at the time this COPY happens.
501     MIB.addReg(SrcReg, RegState::Implicit);
502 
503     // If we are copying reversely, we should decrease the encoding.
504     SrcEncoding += (ReversedCopy ? -NumCopied : NumCopied);
505     DstEncoding += (ReversedCopy ? -NumCopied : NumCopied);
506     I += NumCopied;
507   }
508 }
509 
/// Emit a physical register-to-register copy before \p MBBI, dispatching on
/// the register classes of DstReg/SrcReg: GPR via `addi 0`, GPR16/32 via
/// pseudo moves, GPR pairs as two `addi`s, vector CSR reads via csrrs, FP
/// copies via sign-injection (fsgnj), GPR<->FPR via fmv, and finally RVV
/// vector groups via copyPhysRegVector. Unreachable for any other pairing.
copyPhysReg(MachineBasicBlock & MBB,MachineBasicBlock::iterator MBBI,const DebugLoc & DL,Register DstReg,Register SrcReg,bool KillSrc,bool RenamableDest,bool RenamableSrc) const510 void RISCVInstrInfo::copyPhysReg(MachineBasicBlock &MBB,
511                                  MachineBasicBlock::iterator MBBI,
512                                  const DebugLoc &DL, Register DstReg,
513                                  Register SrcReg, bool KillSrc,
514                                  bool RenamableDest, bool RenamableSrc) const {
515   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
516   unsigned KillFlag = getKillRegState(KillSrc);
517 
518   if (RISCV::GPRRegClass.contains(DstReg, SrcReg)) {
519     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI), DstReg)
520         .addReg(SrcReg, KillFlag | getRenamableRegState(RenamableSrc))
521         .addImm(0);
522     return;
523   }
524 
525   if (RISCV::GPRF16RegClass.contains(DstReg, SrcReg)) {
526     BuildMI(MBB, MBBI, DL, get(RISCV::PseudoMV_FPR16INX), DstReg)
527         .addReg(SrcReg, KillFlag | getRenamableRegState(RenamableSrc));
528     return;
529   }
530 
531   if (RISCV::GPRF32RegClass.contains(DstReg, SrcReg)) {
532     BuildMI(MBB, MBBI, DL, get(RISCV::PseudoMV_FPR32INX), DstReg)
533         .addReg(SrcReg, KillFlag | getRenamableRegState(RenamableSrc));
534     return;
535   }
536 
537   if (RISCV::GPRPairRegClass.contains(DstReg, SrcReg)) {
538     MCRegister EvenReg = TRI->getSubReg(SrcReg, RISCV::sub_gpr_even);
539     MCRegister OddReg = TRI->getSubReg(SrcReg, RISCV::sub_gpr_odd);
540     // We need to correct the odd register of X0_Pair.
541     if (OddReg == RISCV::DUMMY_REG_PAIR_WITH_X0)
542       OddReg = RISCV::X0;
543     assert(DstReg != RISCV::X0_Pair && "Cannot write to X0_Pair");
544 
545     // Emit an ADDI for both parts of GPRPair.
546     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
547             TRI->getSubReg(DstReg, RISCV::sub_gpr_even))
548         .addReg(EvenReg, KillFlag)
549         .addImm(0);
550     BuildMI(MBB, MBBI, DL, get(RISCV::ADDI),
551             TRI->getSubReg(DstReg, RISCV::sub_gpr_odd))
552         .addReg(OddReg, KillFlag)
553         .addImm(0);
554     return;
555   }
556 
557   // Handle copy from csr
558   if (RISCV::VCSRRegClass.contains(SrcReg) &&
559       RISCV::GPRRegClass.contains(DstReg)) {
560     BuildMI(MBB, MBBI, DL, get(RISCV::CSRRS), DstReg)
561         .addImm(RISCVSysReg::lookupSysRegByName(TRI->getName(SrcReg))->Encoding)
562         .addReg(RISCV::X0)
563     return;
564   }
565 
566   if (RISCV::FPR16RegClass.contains(DstReg, SrcReg)) {
567     unsigned Opc;
568     if (STI.hasStdExtZfh()) {
569       Opc = RISCV::FSGNJ_H;
570     } else {
571       assert(STI.hasStdExtF() &&
572              (STI.hasStdExtZfhmin() || STI.hasStdExtZfbfmin()) &&
573              "Unexpected extensions");
574       // Zfhmin/Zfbfmin doesn't have FSGNJ_H, replace FSGNJ_H with FSGNJ_S.
575       DstReg = TRI->getMatchingSuperReg(DstReg, RISCV::sub_16,
576                                         &RISCV::FPR32RegClass);
577       SrcReg = TRI->getMatchingSuperReg(SrcReg, RISCV::sub_16,
578                                         &RISCV::FPR32RegClass);
579       Opc = RISCV::FSGNJ_S;
580     }
    // fsgnj rd, rs, rs is the canonical FP register move.
581     BuildMI(MBB, MBBI, DL, get(Opc), DstReg)
582         .addReg(SrcReg, KillFlag)
583         .addReg(SrcReg, KillFlag);
584     return;
585   }
586 
587   if (RISCV::FPR32RegClass.contains(DstReg, SrcReg)) {
588     BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_S), DstReg)
589         .addReg(SrcReg, KillFlag)
590         .addReg(SrcReg, KillFlag);
591     return;
592   }
593 
594   if (RISCV::FPR64RegClass.contains(DstReg, SrcReg)) {
595     BuildMI(MBB, MBBI, DL, get(RISCV::FSGNJ_D), DstReg)
596         .addReg(SrcReg, KillFlag)
597         .addReg(SrcReg, KillFlag);
598     return;
599   }
600 
601   if (RISCV::FPR32RegClass.contains(DstReg) &&
602       RISCV::GPRRegClass.contains(SrcReg)) {
603     BuildMI(MBB, MBBI, DL, get(RISCV::FMV_W_X), DstReg)
604         .addReg(SrcReg, KillFlag);
605     return;
606   }
607 
608   if (RISCV::GPRRegClass.contains(DstReg) &&
609       RISCV::FPR32RegClass.contains(SrcReg)) {
610     BuildMI(MBB, MBBI, DL, get(RISCV::FMV_X_W), DstReg)
611         .addReg(SrcReg, KillFlag);
612     return;
613   }
614 
615   if (RISCV::FPR64RegClass.contains(DstReg) &&
616       RISCV::GPRRegClass.contains(SrcReg)) {
617     assert(STI.getXLen() == 64 && "Unexpected GPR size");
618     BuildMI(MBB, MBBI, DL, get(RISCV::FMV_D_X), DstReg)
619         .addReg(SrcReg, KillFlag);
620     return;
621   }
622 
623   if (RISCV::GPRRegClass.contains(DstReg) &&
624       RISCV::FPR64RegClass.contains(SrcReg)) {
625     assert(STI.getXLen() == 64 && "Unexpected GPR size");
626     BuildMI(MBB, MBBI, DL, get(RISCV::FMV_X_D), DstReg)
627         .addReg(SrcReg, KillFlag);
628     return;
629   }
630 
631   // VR->VR copies.
632   const TargetRegisterClass *RegClass =
633       TRI->getCommonMinimalPhysRegClass(SrcReg, DstReg);
634   if (RISCVRegisterInfo::isRVVRegClass(RegClass)) {
635     copyPhysRegVector(MBB, MBBI, DL, DstReg, SrcReg, KillSrc, RegClass);
636     return;
637   }
638 
639   llvm_unreachable("Impossible reg-to-reg copy");
640 }
641 
/// Spill \p SrcReg to stack slot \p FI before iterator \p I. Selects the
/// store opcode from the register class (scalar stores, FP stores, RVV
/// whole-register stores, or segment-spill pseudos), then builds the store
/// with a matching memory operand; RVV classes get a scalable-sized MMO and
/// force the slot into the ScalableVector stack area.
storeRegToStackSlot(MachineBasicBlock & MBB,MachineBasicBlock::iterator I,Register SrcReg,bool IsKill,int FI,const TargetRegisterClass * RC,const TargetRegisterInfo * TRI,Register VReg,MachineInstr::MIFlag Flags) const642 void RISCVInstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
643                                          MachineBasicBlock::iterator I,
644                                          Register SrcReg, bool IsKill, int FI,
645                                          const TargetRegisterClass *RC,
646                                          const TargetRegisterInfo *TRI,
647                                          Register VReg,
648                                          MachineInstr::MIFlag Flags) const {
649   MachineFunction *MF = MBB.getParent();
650   MachineFrameInfo &MFI = MF->getFrameInfo();
651 
  // Pick the widest store matching the register class; GPR width follows
  // XLEN (SW on RV32, SD on RV64).
652   unsigned Opcode;
653   if (RISCV::GPRRegClass.hasSubClassEq(RC)) {
654     Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
655              RISCV::SW : RISCV::SD;
656   } else if (RISCV::GPRF16RegClass.hasSubClassEq(RC)) {
657     Opcode = RISCV::SH_INX;
658   } else if (RISCV::GPRF32RegClass.hasSubClassEq(RC)) {
659     Opcode = RISCV::SW_INX;
660   } else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
661     Opcode = RISCV::PseudoRV32ZdinxSD;
662   } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
663     Opcode = RISCV::FSH;
664   } else if (RISCV::FPR32RegClass.hasSubClassEq(RC)) {
665     Opcode = RISCV::FSW;
666   } else if (RISCV::FPR64RegClass.hasSubClassEq(RC)) {
667     Opcode = RISCV::FSD;
668   } else if (RISCV::VRRegClass.hasSubClassEq(RC)) {
669     Opcode = RISCV::VS1R_V;
670   } else if (RISCV::VRM2RegClass.hasSubClassEq(RC)) {
671     Opcode = RISCV::VS2R_V;
672   } else if (RISCV::VRM4RegClass.hasSubClassEq(RC)) {
673     Opcode = RISCV::VS4R_V;
674   } else if (RISCV::VRM8RegClass.hasSubClassEq(RC)) {
675     Opcode = RISCV::VS8R_V;
676   } else if (RISCV::VRN2M1RegClass.hasSubClassEq(RC))
677     Opcode = RISCV::PseudoVSPILL2_M1;
678   else if (RISCV::VRN2M2RegClass.hasSubClassEq(RC))
679     Opcode = RISCV::PseudoVSPILL2_M2;
680   else if (RISCV::VRN2M4RegClass.hasSubClassEq(RC))
681     Opcode = RISCV::PseudoVSPILL2_M4;
682   else if (RISCV::VRN3M1RegClass.hasSubClassEq(RC))
683     Opcode = RISCV::PseudoVSPILL3_M1;
684   else if (RISCV::VRN3M2RegClass.hasSubClassEq(RC))
685     Opcode = RISCV::PseudoVSPILL3_M2;
686   else if (RISCV::VRN4M1RegClass.hasSubClassEq(RC))
687     Opcode = RISCV::PseudoVSPILL4_M1;
688   else if (RISCV::VRN4M2RegClass.hasSubClassEq(RC))
689     Opcode = RISCV::PseudoVSPILL4_M2;
690   else if (RISCV::VRN5M1RegClass.hasSubClassEq(RC))
691     Opcode = RISCV::PseudoVSPILL5_M1;
692   else if (RISCV::VRN6M1RegClass.hasSubClassEq(RC))
693     Opcode = RISCV::PseudoVSPILL6_M1;
694   else if (RISCV::VRN7M1RegClass.hasSubClassEq(RC))
695     Opcode = RISCV::PseudoVSPILL7_M1;
696   else if (RISCV::VRN8M1RegClass.hasSubClassEq(RC))
697     Opcode = RISCV::PseudoVSPILL8_M1;
698   else
699     llvm_unreachable("Can't store this register to stack slot");
700 
701   if (RISCVRegisterInfo::isRVVRegClass(RC)) {
    // RVV spill: scalable memory size, slot placed in the scalable-vector
    // stack region, and no immediate-offset operand on the store.
702     MachineMemOperand *MMO = MF->getMachineMemOperand(
703         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOStore,
704         TypeSize::getScalable(MFI.getObjectSize(FI)), MFI.getObjectAlign(FI));
705 
706     MFI.setStackID(FI, TargetStackID::ScalableVector);
707     BuildMI(MBB, I, DebugLoc(), get(Opcode))
708         .addReg(SrcReg, getKillRegState(IsKill))
709         .addFrameIndex(FI)
710         .addMemOperand(MMO)
711         .setMIFlag(Flags);
712     NumVRegSpilled += TRI->getRegSizeInBits(*RC) / RISCV::RVVBitsPerBlock;
713   } else {
    // Scalar/FP spill: fixed-size MMO and an explicit zero offset operand.
714     MachineMemOperand *MMO = MF->getMachineMemOperand(
715         MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOStore,
716         MFI.getObjectSize(FI), MFI.getObjectAlign(FI));
717 
718     BuildMI(MBB, I, DebugLoc(), get(Opcode))
719         .addReg(SrcReg, getKillRegState(IsKill))
720         .addFrameIndex(FI)
721         .addImm(0)
722         .addMemOperand(MMO)
723         .setMIFlag(Flags);
724   }
725 }
726 
/// Reload register \p DstReg of class \p RC from stack slot \p FI, inserting
/// the load before \p I in \p MBB. Scalar classes (GPR/FPR and their
/// in-X-register float variants) use a single load; RVV classes use
/// whole-register loads or segment-reload pseudos and mark the slot as a
/// scalable-vector stack object.
void RISCVInstrInfo::loadRegFromStackSlot(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator I, Register DstReg,
    int FI, const TargetRegisterClass *RC, const TargetRegisterInfo *TRI,
    Register VReg, MachineInstr::MIFlag Flags) const {
  MachineFunction *MF = MBB.getParent();
  MachineFrameInfo &MFI = MF->getFrameInfo();
  // Only use a real debug location for epilogue (FrameDestroy) reloads;
  // prologue-style reloads deliberately carry an empty location.
  DebugLoc DL =
      Flags & MachineInstr::FrameDestroy ? MBB.findDebugLoc(I) : DebugLoc();

  // Select the load opcode from the register class. The GPR case checks the
  // register size to pick LW (RV32) vs. LD (RV64); vector register groups and
  // tuples map onto whole-register loads / reload pseudos.
  unsigned Opcode;
  if (RISCV::GPRRegClass.hasSubClassEq(RC)) {
    Opcode = TRI->getRegSizeInBits(RISCV::GPRRegClass) == 32 ?
             RISCV::LW : RISCV::LD;
  } else if (RISCV::GPRF16RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::LH_INX;
  } else if (RISCV::GPRF32RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::LW_INX;
  } else if (RISCV::GPRPairRegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::PseudoRV32ZdinxLD;
  } else if (RISCV::FPR16RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::FLH;
  } else if (RISCV::FPR32RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::FLW;
  } else if (RISCV::FPR64RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::FLD;
  } else if (RISCV::VRRegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::VL1RE8_V;
  } else if (RISCV::VRM2RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::VL2RE8_V;
  } else if (RISCV::VRM4RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::VL4RE8_V;
  } else if (RISCV::VRM8RegClass.hasSubClassEq(RC)) {
    Opcode = RISCV::VL8RE8_V;
  } else if (RISCV::VRN2M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD2_M1;
  else if (RISCV::VRN2M2RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD2_M2;
  else if (RISCV::VRN2M4RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD2_M4;
  else if (RISCV::VRN3M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD3_M1;
  else if (RISCV::VRN3M2RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD3_M2;
  else if (RISCV::VRN4M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD4_M1;
  else if (RISCV::VRN4M2RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD4_M2;
  else if (RISCV::VRN5M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD5_M1;
  else if (RISCV::VRN6M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD6_M1;
  else if (RISCV::VRN7M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD7_M1;
  else if (RISCV::VRN8M1RegClass.hasSubClassEq(RC))
    Opcode = RISCV::PseudoVRELOAD8_M1;
  else
    llvm_unreachable("Can't load this register from stack slot");

  if (RISCVRegisterInfo::isRVVRegClass(RC)) {
    // RVV reloads: the slot size is in vscale units, so the memory operand is
    // a scalable TypeSize and the stack ID is forced to ScalableVector.
    MachineMemOperand *MMO = MF->getMachineMemOperand(
        MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad,
        TypeSize::getScalable(MFI.getObjectSize(FI)), MFI.getObjectAlign(FI));

    MFI.setStackID(FI, TargetStackID::ScalableVector);
    // Vector loads/reload pseudos take no immediate offset operand.
    BuildMI(MBB, I, DL, get(Opcode), DstReg)
        .addFrameIndex(FI)
        .addMemOperand(MMO)
        .setMIFlag(Flags);
    // Count one reload per vector register in the group (LMUL-aware stat).
    NumVRegReloaded += TRI->getRegSizeInBits(*RC) / RISCV::RVVBitsPerBlock;
  } else {
    MachineMemOperand *MMO = MF->getMachineMemOperand(
        MachinePointerInfo::getFixedStack(*MF, FI), MachineMemOperand::MOLoad,
        MFI.getObjectSize(FI), MFI.getObjectAlign(FI));

    // Scalar loads use the standard (base, imm-offset) addressing; the offset
    // is filled in when the frame index is eliminated.
    BuildMI(MBB, I, DL, get(Opcode), DstReg)
        .addFrameIndex(FI)
        .addImm(0)
        .addMemOperand(MMO)
        .setMIFlag(Flags);
  }
}
getFoldedOpcode(MachineFunction & MF,MachineInstr & MI,ArrayRef<unsigned> Ops,const RISCVSubtarget & ST)808 std::optional<unsigned> getFoldedOpcode(MachineFunction &MF, MachineInstr &MI,
809                                         ArrayRef<unsigned> Ops,
810                                         const RISCVSubtarget &ST) {
811 
812   // The below optimizations narrow the load so they are only valid for little
813   // endian.
814   // TODO: Support big endian by adding an offset into the frame object?
815   if (MF.getDataLayout().isBigEndian())
816     return std::nullopt;
817 
818   // Fold load from stack followed by sext.b/sext.h/sext.w/zext.b/zext.h/zext.w.
819   if (Ops.size() != 1 || Ops[0] != 1)
820     return std::nullopt;
821 
822   switch (MI.getOpcode()) {
823   default:
824     if (RISCVInstrInfo::isSEXT_W(MI))
825       return RISCV::LW;
826     if (RISCVInstrInfo::isZEXT_W(MI))
827       return RISCV::LWU;
828     if (RISCVInstrInfo::isZEXT_B(MI))
829       return RISCV::LBU;
830     break;
831   case RISCV::SEXT_H:
832     return RISCV::LH;
833   case RISCV::SEXT_B:
834     return RISCV::LB;
835   case RISCV::ZEXT_H_RV32:
836   case RISCV::ZEXT_H_RV64:
837     return RISCV::LHU;
838   }
839 
840   switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
841   default:
842     return std::nullopt;
843   case RISCV::VMV_X_S: {
844     unsigned Log2SEW =
845         MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
846     if (ST.getXLen() < (1U << Log2SEW))
847       return std::nullopt;
848     switch (Log2SEW) {
849     case 3:
850       return RISCV::LB;
851     case 4:
852       return RISCV::LH;
853     case 5:
854       return RISCV::LW;
855     case 6:
856       return RISCV::LD;
857     default:
858       llvm_unreachable("Unexpected SEW");
859     }
860   }
861   case RISCV::VFMV_F_S: {
862     unsigned Log2SEW =
863         MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
864     switch (Log2SEW) {
865     case 4:
866       return RISCV::FLH;
867     case 5:
868       return RISCV::FLW;
869     case 6:
870       return RISCV::FLD;
871     default:
872       llvm_unreachable("Unexpected SEW");
873     }
874   }
875   }
876 }
877 
878 // This is the version used during inline spilling
foldMemoryOperandImpl(MachineFunction & MF,MachineInstr & MI,ArrayRef<unsigned> Ops,MachineBasicBlock::iterator InsertPt,int FrameIndex,LiveIntervals * LIS,VirtRegMap * VRM) const879 MachineInstr *RISCVInstrInfo::foldMemoryOperandImpl(
880     MachineFunction &MF, MachineInstr &MI, ArrayRef<unsigned> Ops,
881     MachineBasicBlock::iterator InsertPt, int FrameIndex, LiveIntervals *LIS,
882     VirtRegMap *VRM) const {
883 
884   std::optional<unsigned> LoadOpc = getFoldedOpcode(MF, MI, Ops, STI);
885   if (!LoadOpc)
886     return nullptr;
887   Register DstReg = MI.getOperand(0).getReg();
888   return BuildMI(*MI.getParent(), InsertPt, MI.getDebugLoc(), get(*LoadOpc),
889                  DstReg)
890       .addFrameIndex(FrameIndex)
891       .addImm(0);
892 }
893 
movImm(MachineBasicBlock & MBB,MachineBasicBlock::iterator MBBI,const DebugLoc & DL,Register DstReg,uint64_t Val,MachineInstr::MIFlag Flag,bool DstRenamable,bool DstIsDead) const894 void RISCVInstrInfo::movImm(MachineBasicBlock &MBB,
895                             MachineBasicBlock::iterator MBBI,
896                             const DebugLoc &DL, Register DstReg, uint64_t Val,
897                             MachineInstr::MIFlag Flag, bool DstRenamable,
898                             bool DstIsDead) const {
899   Register SrcReg = RISCV::X0;
900 
901   // For RV32, allow a sign or unsigned 32 bit value.
902   if (!STI.is64Bit() && !isInt<32>(Val)) {
903     // If have a uimm32 it will still fit in a register so we can allow it.
904     if (!isUInt<32>(Val))
905       report_fatal_error("Should only materialize 32-bit constants for RV32");
906 
907     // Sign extend for generateInstSeq.
908     Val = SignExtend64<32>(Val);
909   }
910 
911   RISCVMatInt::InstSeq Seq = RISCVMatInt::generateInstSeq(Val, STI);
912   assert(!Seq.empty());
913 
914   bool SrcRenamable = false;
915   unsigned Num = 0;
916 
917   for (const RISCVMatInt::Inst &Inst : Seq) {
918     bool LastItem = ++Num == Seq.size();
919     unsigned DstRegState = getDeadRegState(DstIsDead && LastItem) |
920                            getRenamableRegState(DstRenamable);
921     unsigned SrcRegState = getKillRegState(SrcReg != RISCV::X0) |
922                            getRenamableRegState(SrcRenamable);
923     switch (Inst.getOpndKind()) {
924     case RISCVMatInt::Imm:
925       BuildMI(MBB, MBBI, DL, get(Inst.getOpcode()))
926           .addReg(DstReg, RegState::Define | DstRegState)
927           .addImm(Inst.getImm())
928           .setMIFlag(Flag);
929       break;
930     case RISCVMatInt::RegX0:
931       BuildMI(MBB, MBBI, DL, get(Inst.getOpcode()))
932           .addReg(DstReg, RegState::Define | DstRegState)
933           .addReg(SrcReg, SrcRegState)
934           .addReg(RISCV::X0)
935           .setMIFlag(Flag);
936       break;
937     case RISCVMatInt::RegReg:
938       BuildMI(MBB, MBBI, DL, get(Inst.getOpcode()))
939           .addReg(DstReg, RegState::Define | DstRegState)
940           .addReg(SrcReg, SrcRegState)
941           .addReg(SrcReg, SrcRegState)
942           .setMIFlag(Flag);
943       break;
944     case RISCVMatInt::RegImm:
945       BuildMI(MBB, MBBI, DL, get(Inst.getOpcode()))
946           .addReg(DstReg, RegState::Define | DstRegState)
947           .addReg(SrcReg, SrcRegState)
948           .addImm(Inst.getImm())
949           .setMIFlag(Flag);
950       break;
951     }
952 
953     // Only the first instruction has X0 as its source.
954     SrcReg = DstReg;
955     SrcRenamable = DstRenamable;
956   }
957 }
958 
getCondFromBranchOpc(unsigned Opc)959 RISCVCC::CondCode RISCVInstrInfo::getCondFromBranchOpc(unsigned Opc) {
960   switch (Opc) {
961   default:
962     return RISCVCC::COND_INVALID;
963   case RISCV::BEQ:
964   case RISCV::CV_BEQIMM:
965   case RISCV::QC_BEQI:
966   case RISCV::QC_E_BEQI:
967   case RISCV::NDS_BBC:
968   case RISCV::NDS_BEQC:
969     return RISCVCC::COND_EQ;
970   case RISCV::BNE:
971   case RISCV::QC_BNEI:
972   case RISCV::QC_E_BNEI:
973   case RISCV::CV_BNEIMM:
974   case RISCV::NDS_BBS:
975   case RISCV::NDS_BNEC:
976     return RISCVCC::COND_NE;
977   case RISCV::BLT:
978   case RISCV::QC_BLTI:
979   case RISCV::QC_E_BLTI:
980     return RISCVCC::COND_LT;
981   case RISCV::BGE:
982   case RISCV::QC_BGEI:
983   case RISCV::QC_E_BGEI:
984     return RISCVCC::COND_GE;
985   case RISCV::BLTU:
986   case RISCV::QC_BLTUI:
987   case RISCV::QC_E_BLTUI:
988     return RISCVCC::COND_LTU;
989   case RISCV::BGEU:
990   case RISCV::QC_BGEUI:
991   case RISCV::QC_E_BGEUI:
992     return RISCVCC::COND_GEU;
993   }
994 }
995 
evaluateCondBranch(RISCVCC::CondCode CC,int64_t C0,int64_t C1)996 bool RISCVInstrInfo::evaluateCondBranch(RISCVCC::CondCode CC, int64_t C0,
997                                         int64_t C1) {
998   switch (CC) {
999   default:
1000     llvm_unreachable("Unexpected CC");
1001   case RISCVCC::COND_EQ:
1002     return C0 == C1;
1003   case RISCVCC::COND_NE:
1004     return C0 != C1;
1005   case RISCVCC::COND_LT:
1006     return C0 < C1;
1007   case RISCVCC::COND_GE:
1008     return C0 >= C1;
1009   case RISCVCC::COND_LTU:
1010     return (uint64_t)C0 < (uint64_t)C1;
1011   case RISCVCC::COND_GEU:
1012     return (uint64_t)C0 >= (uint64_t)C1;
1013   }
1014 }
1015 
1016 // The contents of values added to Cond are not examined outside of
1017 // RISCVInstrInfo, giving us flexibility in what to push to it. For RISCV, we
1018 // push BranchOpcode, Reg1, Reg2.
parseCondBranch(MachineInstr & LastInst,MachineBasicBlock * & Target,SmallVectorImpl<MachineOperand> & Cond)1019 static void parseCondBranch(MachineInstr &LastInst, MachineBasicBlock *&Target,
1020                             SmallVectorImpl<MachineOperand> &Cond) {
1021   // Block ends with fall-through condbranch.
1022   assert(LastInst.getDesc().isConditionalBranch() &&
1023          "Unknown conditional branch");
1024   Target = LastInst.getOperand(2).getMBB();
1025   Cond.push_back(MachineOperand::CreateImm(LastInst.getOpcode()));
1026   Cond.push_back(LastInst.getOperand(0));
1027   Cond.push_back(LastInst.getOperand(1));
1028 }
1029 
/// Return the conditional-branch opcode implementing condition \p CC.
/// \p SelectOpc selects the branch family: for the base-ISA default every
/// condition maps to a B* opcode; for a vendor Select_* pseudo only the
/// conditions that family supports are legal, and the matching vendor branch
/// is returned. Any other combination is llvm_unreachable.
unsigned RISCVCC::getBrCond(RISCVCC::CondCode CC, unsigned SelectOpc) {
  switch (SelectOpc) {
  default:
    // Base ISA register-register branches.
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::BEQ;
    case RISCVCC::COND_NE:
      return RISCV::BNE;
    case RISCVCC::COND_LT:
      return RISCV::BLT;
    case RISCVCC::COND_GE:
      return RISCV::BGE;
    case RISCVCC::COND_LTU:
      return RISCV::BLTU;
    case RISCVCC::COND_GEU:
      return RISCV::BGEU;
    }
    break;
  // CORE-V immediate branches (EQ/NE only).
  case RISCV::Select_GPR_Using_CC_SImm5_CV:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::CV_BEQIMM;
    case RISCVCC::COND_NE:
      return RISCV::CV_BNEIMM;
    }
    break;
  // Qualcomm signed-immediate branches.
  case RISCV::Select_GPRNoX0_Using_CC_SImm5NonZero_QC:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::QC_BEQI;
    case RISCVCC::COND_NE:
      return RISCV::QC_BNEI;
    case RISCVCC::COND_LT:
      return RISCV::QC_BLTI;
    case RISCVCC::COND_GE:
      return RISCV::QC_BGEI;
    }
    break;
  // Qualcomm unsigned-immediate branches (unsigned conditions only).
  case RISCV::Select_GPRNoX0_Using_CC_UImm5NonZero_QC:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_LTU:
      return RISCV::QC_BLTUI;
    case RISCVCC::COND_GEU:
      return RISCV::QC_BGEUI;
    }
    break;
  // Qualcomm 48-bit (E_) signed-immediate branches.
  case RISCV::Select_GPRNoX0_Using_CC_SImm16NonZero_QC:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::QC_E_BEQI;
    case RISCVCC::COND_NE:
      return RISCV::QC_E_BNEI;
    case RISCVCC::COND_LT:
      return RISCV::QC_E_BLTI;
    case RISCVCC::COND_GE:
      return RISCV::QC_E_BGEI;
    }
    break;
  // Qualcomm 48-bit (E_) unsigned-immediate branches.
  case RISCV::Select_GPRNoX0_Using_CC_UImm16NonZero_QC:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_LTU:
      return RISCV::QC_E_BLTUI;
    case RISCVCC::COND_GEU:
      return RISCV::QC_E_BGEUI;
    }
    break;
  // Andes bit-test branches (EQ = bit clear, NE = bit set).
  case RISCV::Select_GPR_Using_CC_UImmLog2XLen_NDS:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::NDS_BBC;
    case RISCVCC::COND_NE:
      return RISCV::NDS_BBS;
    }
    break;
  // Andes compare-immediate branches.
  case RISCV::Select_GPR_Using_CC_UImm7_NDS:
    switch (CC) {
    default:
      llvm_unreachable("Unexpected condition code!");
    case RISCVCC::COND_EQ:
      return RISCV::NDS_BEQC;
    case RISCVCC::COND_NE:
      return RISCV::NDS_BNEC;
    }
    break;
  }
}
1130 
getOppositeBranchCondition(RISCVCC::CondCode CC)1131 RISCVCC::CondCode RISCVCC::getOppositeBranchCondition(RISCVCC::CondCode CC) {
1132   switch (CC) {
1133   default:
1134     llvm_unreachable("Unrecognized conditional branch");
1135   case RISCVCC::COND_EQ:
1136     return RISCVCC::COND_NE;
1137   case RISCVCC::COND_NE:
1138     return RISCVCC::COND_EQ;
1139   case RISCVCC::COND_LT:
1140     return RISCVCC::COND_GE;
1141   case RISCVCC::COND_GE:
1142     return RISCVCC::COND_LT;
1143   case RISCVCC::COND_LTU:
1144     return RISCVCC::COND_GEU;
1145   case RISCVCC::COND_GEU:
1146     return RISCVCC::COND_LTU;
1147   }
1148 }
1149 
/// Analyze the terminators of \p MBB. Handles three shapes: fallthrough
/// (no terminators), a single unconditional or conditional branch, and a
/// conditional branch followed by an unconditional one. Returns false with
/// TBB/FBB/Cond filled in on success, true when the block cannot be
/// analyzed (indirect branch, pre-ISel opcode, >2 terminators). When
/// \p AllowModify is set, dead terminators after the first unconditional or
/// indirect branch are erased.
bool RISCVInstrInfo::analyzeBranch(MachineBasicBlock &MBB,
                                   MachineBasicBlock *&TBB,
                                   MachineBasicBlock *&FBB,
                                   SmallVectorImpl<MachineOperand> &Cond,
                                   bool AllowModify) const {
  TBB = FBB = nullptr;
  Cond.clear();

  // If the block has no terminators, it just falls into the block after it.
  MachineBasicBlock::iterator I = MBB.getLastNonDebugInstr();
  if (I == MBB.end() || !isUnpredicatedTerminator(*I))
    return false;

  // Count the number of terminators and find the first unconditional or
  // indirect branch.
  MachineBasicBlock::iterator FirstUncondOrIndirectBr = MBB.end();
  int NumTerminators = 0;
  for (auto J = I.getReverse(); J != MBB.rend() && isUnpredicatedTerminator(*J);
       J++) {
    NumTerminators++;
    if (J->getDesc().isUnconditionalBranch() ||
        J->getDesc().isIndirectBranch()) {
      FirstUncondOrIndirectBr = J.getReverse();
    }
  }

  // If AllowModify is true, we can erase any terminators after
  // FirstUncondOrIndirectBR.
  if (AllowModify && FirstUncondOrIndirectBr != MBB.end()) {
    while (std::next(FirstUncondOrIndirectBr) != MBB.end()) {
      std::next(FirstUncondOrIndirectBr)->eraseFromParent();
      NumTerminators--;
    }
    I = FirstUncondOrIndirectBr;
  }

  // We can't handle blocks that end in an indirect branch.
  if (I->getDesc().isIndirectBranch())
    return true;

  // We can't handle Generic branch opcodes from Global ISel.
  if (I->isPreISelOpcode())
    return true;

  // We can't handle blocks with more than 2 terminators.
  if (NumTerminators > 2)
    return true;

  // Handle a single unconditional branch.
  if (NumTerminators == 1 && I->getDesc().isUnconditionalBranch()) {
    TBB = getBranchDestBlock(*I);
    return false;
  }

  // Handle a single conditional branch.
  if (NumTerminators == 1 && I->getDesc().isConditionalBranch()) {
    parseCondBranch(*I, TBB, Cond);
    return false;
  }

  // Handle a conditional branch followed by an unconditional branch.
  if (NumTerminators == 2 && std::prev(I)->getDesc().isConditionalBranch() &&
      I->getDesc().isUnconditionalBranch()) {
    parseCondBranch(*std::prev(I), TBB, Cond);
    FBB = getBranchDestBlock(*I);
    return false;
  }

  // Otherwise, we can't handle this.
  return true;
}
1221 
removeBranch(MachineBasicBlock & MBB,int * BytesRemoved) const1222 unsigned RISCVInstrInfo::removeBranch(MachineBasicBlock &MBB,
1223                                       int *BytesRemoved) const {
1224   if (BytesRemoved)
1225     *BytesRemoved = 0;
1226   MachineBasicBlock::iterator I = MBB.getLastNonDebugInstr();
1227   if (I == MBB.end())
1228     return 0;
1229 
1230   if (!I->getDesc().isUnconditionalBranch() &&
1231       !I->getDesc().isConditionalBranch())
1232     return 0;
1233 
1234   // Remove the branch.
1235   if (BytesRemoved)
1236     *BytesRemoved += getInstSizeInBytes(*I);
1237   I->eraseFromParent();
1238 
1239   I = MBB.end();
1240 
1241   if (I == MBB.begin())
1242     return 1;
1243   --I;
1244   if (!I->getDesc().isConditionalBranch())
1245     return 1;
1246 
1247   // Remove the branch.
1248   if (BytesRemoved)
1249     *BytesRemoved += getInstSizeInBytes(*I);
1250   I->eraseFromParent();
1251   return 2;
1252 }
1253 
1254 // Inserts a branch into the end of the specific MachineBasicBlock, returning
1255 // the number of instructions inserted.
insertBranch(MachineBasicBlock & MBB,MachineBasicBlock * TBB,MachineBasicBlock * FBB,ArrayRef<MachineOperand> Cond,const DebugLoc & DL,int * BytesAdded) const1256 unsigned RISCVInstrInfo::insertBranch(
1257     MachineBasicBlock &MBB, MachineBasicBlock *TBB, MachineBasicBlock *FBB,
1258     ArrayRef<MachineOperand> Cond, const DebugLoc &DL, int *BytesAdded) const {
1259   if (BytesAdded)
1260     *BytesAdded = 0;
1261 
1262   // Shouldn't be a fall through.
1263   assert(TBB && "insertBranch must not be told to insert a fallthrough");
1264   assert((Cond.size() == 3 || Cond.size() == 0) &&
1265          "RISC-V branch conditions have two components!");
1266 
1267   // Unconditional branch.
1268   if (Cond.empty()) {
1269     MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(TBB);
1270     if (BytesAdded)
1271       *BytesAdded += getInstSizeInBytes(MI);
1272     return 1;
1273   }
1274 
1275   // Either a one or two-way conditional branch.
1276   MachineInstr &CondMI = *BuildMI(&MBB, DL, get(Cond[0].getImm()))
1277                               .add(Cond[1])
1278                               .add(Cond[2])
1279                               .addMBB(TBB);
1280   if (BytesAdded)
1281     *BytesAdded += getInstSizeInBytes(CondMI);
1282 
1283   // One-way conditional branch.
1284   if (!FBB)
1285     return 1;
1286 
1287   // Two-way conditional branch.
1288   MachineInstr &MI = *BuildMI(&MBB, DL, get(RISCV::PseudoBR)).addMBB(FBB);
1289   if (BytesAdded)
1290     *BytesAdded += getInstSizeInBytes(MI);
1291   return 2;
1292 }
1293 
/// Insert an unconditional long (indirect) branch from the new, empty block
/// \p MBB to \p DestBB using a PseudoJump through a scratch GPR. The scratch
/// register is normally obtained from the register scavenger; if none is
/// free, s11 (s1 under RVE) is spilled to a reserved frame index, the jump is
/// redirected through \p RestoreBB, and the register is reloaded there.
void RISCVInstrInfo::insertIndirectBranch(MachineBasicBlock &MBB,
                                          MachineBasicBlock &DestBB,
                                          MachineBasicBlock &RestoreBB,
                                          const DebugLoc &DL, int64_t BrOffset,
                                          RegScavenger *RS) const {
  assert(RS && "RegScavenger required for long branching");
  assert(MBB.empty() &&
         "new block should be inserted for expanding unconditional branch");
  assert(MBB.pred_size() == 1);
  assert(RestoreBB.empty() &&
         "restore block should be inserted for restoring clobbered registers");

  MachineFunction *MF = MBB.getParent();
  MachineRegisterInfo &MRI = MF->getRegInfo();
  RISCVMachineFunctionInfo *RVFI = MF->getInfo<RISCVMachineFunctionInfo>();
  const TargetRegisterInfo *TRI = MF->getSubtarget().getRegisterInfo();

  // PseudoJump expands to auipc+jalr, which only reaches +/-2 GiB.
  if (!isInt<32>(BrOffset))
    report_fatal_error(
        "Branch offsets outside of the signed 32-bit range not supported");

  // FIXME: A virtual register must be used initially, as the register
  // scavenger won't work with empty blocks (SIInstrInfo::insertIndirectBranch
  // uses the same workaround).
  Register ScratchReg = MRI.createVirtualRegister(&RISCV::GPRJALRRegClass);
  auto II = MBB.end();
  // We may also update the jump target to RestoreBB later.
  MachineInstr &MI = *BuildMI(MBB, II, DL, get(RISCV::PseudoJump))
                          .addReg(ScratchReg, RegState::Define | RegState::Dead)
                          .addMBB(&DestBB, RISCVII::MO_CALL);

  // Try to find a free physical register for the scratch vreg.
  RS->enterBasicBlockEnd(MBB);
  Register TmpGPR =
      RS->scavengeRegisterBackwards(RISCV::GPRRegClass, MI.getIterator(),
                                    /*RestoreAfter=*/false, /*SpAdj=*/0,
                                    /*AllowSpill=*/false);
  if (TmpGPR != RISCV::NoRegister)
    RS->setRegUsed(TmpGPR);
  else {
    // The case when there is no scavenged register needs special handling.

    // Pick s11(or s1 for rve) because it doesn't make a difference.
    TmpGPR = STI.hasStdExtE() ? RISCV::X9 : RISCV::X27;

    // Frame index reserved earlier by branch relaxation for exactly this case.
    int FrameIndex = RVFI->getBranchRelaxationScratchFrameIndex();
    if (FrameIndex == -1)
      report_fatal_error("underestimated function size");

    // Spill TmpGPR before the jump, and resolve the frame index immediately
    // since frame-index elimination has already run.
    storeRegToStackSlot(MBB, MI, TmpGPR, /*IsKill=*/true, FrameIndex,
                        &RISCV::GPRRegClass, TRI, Register());
    TRI->eliminateFrameIndex(std::prev(MI.getIterator()),
                             /*SpAdj=*/0, /*FIOperandNum=*/1);

    // Route the jump through RestoreBB so TmpGPR can be reloaded there.
    MI.getOperand(1).setMBB(&RestoreBB);

    loadRegFromStackSlot(RestoreBB, RestoreBB.end(), TmpGPR, FrameIndex,
                         &RISCV::GPRRegClass, TRI, Register());
    TRI->eliminateFrameIndex(RestoreBB.back(),
                             /*SpAdj=*/0, /*FIOperandNum=*/1);
  }

  // Rewrite the placeholder vreg to the chosen physical register.
  MRI.replaceRegWith(ScratchReg, TmpGPR);
  MRI.clearVirtRegs();
}
1358 
/// Invert the condition in \p Cond (as produced by parseCondBranch) by
/// swapping the branch opcode in Cond[0] with its logical opposite, covering
/// base-ISA branches and the CORE-V, Qualcomm and Andes vendor variants.
/// Always returns false (reversal never fails).
bool RISCVInstrInfo::reverseBranchCondition(
    SmallVectorImpl<MachineOperand> &Cond) const {
  assert((Cond.size() == 3) && "Invalid branch condition!");
  switch (Cond[0].getImm()) {
  default:
    llvm_unreachable("Unknown conditional branch!");
  // Base ISA branches.
  case RISCV::BEQ:
    Cond[0].setImm(RISCV::BNE);
    break;
  case RISCV::BNE:
    Cond[0].setImm(RISCV::BEQ);
    break;
  case RISCV::BLT:
    Cond[0].setImm(RISCV::BGE);
    break;
  case RISCV::BGE:
    Cond[0].setImm(RISCV::BLT);
    break;
  case RISCV::BLTU:
    Cond[0].setImm(RISCV::BGEU);
    break;
  case RISCV::BGEU:
    Cond[0].setImm(RISCV::BLTU);
    break;
  // CORE-V immediate branches.
  case RISCV::CV_BEQIMM:
    Cond[0].setImm(RISCV::CV_BNEIMM);
    break;
  case RISCV::CV_BNEIMM:
    Cond[0].setImm(RISCV::CV_BEQIMM);
    break;
  // Qualcomm immediate branches.
  case RISCV::QC_BEQI:
    Cond[0].setImm(RISCV::QC_BNEI);
    break;
  case RISCV::QC_BNEI:
    Cond[0].setImm(RISCV::QC_BEQI);
    break;
  case RISCV::QC_BGEI:
    Cond[0].setImm(RISCV::QC_BLTI);
    break;
  case RISCV::QC_BLTI:
    Cond[0].setImm(RISCV::QC_BGEI);
    break;
  case RISCV::QC_BGEUI:
    Cond[0].setImm(RISCV::QC_BLTUI);
    break;
  case RISCV::QC_BLTUI:
    Cond[0].setImm(RISCV::QC_BGEUI);
    break;
  // Qualcomm 48-bit (E_) immediate branches.
  case RISCV::QC_E_BEQI:
    Cond[0].setImm(RISCV::QC_E_BNEI);
    break;
  case RISCV::QC_E_BNEI:
    Cond[0].setImm(RISCV::QC_E_BEQI);
    break;
  case RISCV::QC_E_BGEI:
    Cond[0].setImm(RISCV::QC_E_BLTI);
    break;
  case RISCV::QC_E_BLTI:
    Cond[0].setImm(RISCV::QC_E_BGEI);
    break;
  case RISCV::QC_E_BGEUI:
    Cond[0].setImm(RISCV::QC_E_BLTUI);
    break;
  case RISCV::QC_E_BLTUI:
    Cond[0].setImm(RISCV::QC_E_BGEUI);
    break;
  // Andes bit-test and compare-immediate branches.
  case RISCV::NDS_BBC:
    Cond[0].setImm(RISCV::NDS_BBS);
    break;
  case RISCV::NDS_BBS:
    Cond[0].setImm(RISCV::NDS_BBC);
    break;
  case RISCV::NDS_BEQC:
    Cond[0].setImm(RISCV::NDS_BNEC);
    break;
  case RISCV::NDS_BNEC:
    Cond[0].setImm(RISCV::NDS_BEQC);
    break;
  }

  return false;
}
1441 
1442 // Return true if the instruction is a load immediate instruction (i.e.
1443 // ADDI x0, imm).
isLoadImm(const MachineInstr * MI,int64_t & Imm)1444 static bool isLoadImm(const MachineInstr *MI, int64_t &Imm) {
1445   if (MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
1446       MI->getOperand(1).getReg() == RISCV::X0) {
1447     Imm = MI->getOperand(2).getImm();
1448     return true;
1449   }
1450   return false;
1451 }
1452 
isFromLoadImm(const MachineRegisterInfo & MRI,const MachineOperand & Op,int64_t & Imm)1453 bool RISCVInstrInfo::isFromLoadImm(const MachineRegisterInfo &MRI,
1454                                    const MachineOperand &Op, int64_t &Imm) {
1455   // Either a load from immediate instruction or X0.
1456   if (!Op.isReg())
1457     return false;
1458 
1459   Register Reg = Op.getReg();
1460   if (Reg == RISCV::X0) {
1461     Imm = 0;
1462     return true;
1463   }
1464   return Reg.isVirtual() && isLoadImm(MRI.getVRegDef(Reg), Imm);
1465 }
1466 
/// Peephole-optimize the conditional branch \p MI:
///  1. If both operands are known constants, fold the comparison and replace
///     the branch with "beqz/bnez x0" (always/never taken; the CFG itself is
///     not modified here).
///  2. For signed/unsigned inequalities, rewrite "cmp reg, C" as the inverted
///     comparison against an already-materialized constant C+/-1, removing
///     the need for the original constant's register.
/// Returns true if the branch was rewritten.
bool RISCVInstrInfo::optimizeCondBranch(MachineInstr &MI) const {
  bool IsSigned = false;
  bool IsEquality = false;
  switch (MI.getOpcode()) {
  default:
    return false;
  case RISCV::BEQ:
  case RISCV::BNE:
    IsEquality = true;
    break;
  case RISCV::BGE:
  case RISCV::BLT:
    IsSigned = true;
    break;
  case RISCV::BGEU:
  case RISCV::BLTU:
    break;
  }

  MachineBasicBlock *MBB = MI.getParent();
  MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();

  // Branch layout: (lhs, rhs, target) — see parseCondBranch.
  const MachineOperand &LHS = MI.getOperand(0);
  const MachineOperand &RHS = MI.getOperand(1);
  MachineBasicBlock *TBB = MI.getOperand(2).getMBB();

  RISCVCC::CondCode CC = getCondFromBranchOpc(MI.getOpcode());
  assert(CC != RISCVCC::COND_INVALID);

  // Canonicalize conditional branches which can be constant folded into
  // beqz or bnez.  We can't modify the CFG here.
  int64_t C0, C1;
  if (isFromLoadImm(MRI, LHS, C0) && isFromLoadImm(MRI, RHS, C1)) {
    unsigned NewOpc = evaluateCondBranch(CC, C0, C1) ? RISCV::BEQ : RISCV::BNE;
    // Build the new branch and remove the old one.
    BuildMI(*MBB, MI, MI.getDebugLoc(), get(NewOpc))
        .addReg(RISCV::X0)
        .addReg(RISCV::X0)
        .addMBB(TBB);
    MI.eraseFromParent();
    return true;
  }

  // The off-by-one rewrites below only make sense for inequalities.
  if (IsEquality)
    return false;

  // For two constants C0 and C1 from
  // ```
  // li Y, C0
  // li Z, C1
  // ```
  // 1. if C1 = C0 + 1
  // we can turn:
  //  (a) blt Y, X -> bge X, Z
  //  (b) bge Y, X -> blt X, Z
  //
  // 2. if C1 = C0 - 1
  // we can turn:
  //  (a) blt X, Y -> bge Z, X
  //  (b) bge X, Y -> blt Z, X
  //
  // To make sure this optimization is really beneficial, we only
  // optimize for cases where Y had only one use (i.e. only used by the branch).
  // Try to find the register for constant Z; return
  // invalid register otherwise.
  auto searchConst = [&](int64_t C1) -> Register {
    // Scan backwards from the branch for a load-immediate of C1.
    MachineBasicBlock::reverse_iterator II(&MI), E = MBB->rend();
    auto DefC1 = std::find_if(++II, E, [&](const MachineInstr &I) -> bool {
      int64_t Imm;
      return isLoadImm(&I, Imm) && Imm == C1 &&
             I.getOperand(0).getReg().isVirtual();
    });
    if (DefC1 != E)
      return DefC1->getOperand(0).getReg();

    return Register();
  };

  unsigned NewOpc = RISCVCC::getBrCond(getOppositeBranchCondition(CC));

  // Might be case 1.
  // Don't change 0 to 1 since we can use x0.
  // For unsigned cases changing -1U to 0 would be incorrect.
  // The incorrect case for signed would be INT_MAX, but isFromLoadImm can't
  // return that.
  if (isFromLoadImm(MRI, LHS, C0) && C0 != 0 && LHS.getReg().isVirtual() &&
      MRI.hasOneUse(LHS.getReg()) && (IsSigned || C0 != -1)) {
    assert(isInt<12>(C0) && "Unexpected immediate");
    if (Register RegZ = searchConst(C0 + 1)) {
      BuildMI(*MBB, MI, MI.getDebugLoc(), get(NewOpc))
          .add(RHS)
          .addReg(RegZ)
          .addMBB(TBB);
      // We might extend the live range of Z, clear its kill flag to
      // account for this.
      MRI.clearKillFlags(RegZ);
      MI.eraseFromParent();
      return true;
    }
  }

  // Might be case 2.
  // For signed cases we don't want to change 0 since we can use x0.
  // For unsigned cases changing 0 to -1U would be incorrect.
  // The incorrect case for signed would be INT_MIN, but isFromLoadImm can't
  // return that.
  if (isFromLoadImm(MRI, RHS, C0) && C0 != 0 && RHS.getReg().isVirtual() &&
      MRI.hasOneUse(RHS.getReg())) {
    assert(isInt<12>(C0) && "Unexpected immediate");
    if (Register RegZ = searchConst(C0 - 1)) {
      BuildMI(*MBB, MI, MI.getDebugLoc(), get(NewOpc))
          .addReg(RegZ)
          .add(LHS)
          .addMBB(TBB);
      // We might extend the live range of Z, clear its kill flag to
      // account for this.
      MRI.clearKillFlags(RegZ);
      MI.eraseFromParent();
      return true;
    }
  }

  return false;
}
1591 
1592 MachineBasicBlock *
getBranchDestBlock(const MachineInstr & MI) const1593 RISCVInstrInfo::getBranchDestBlock(const MachineInstr &MI) const {
1594   assert(MI.getDesc().isBranch() && "Unexpected opcode!");
1595   // The branch target is always the last operand.
1596   int NumOp = MI.getNumExplicitOperands();
1597   return MI.getOperand(NumOp - 1).getMBB();
1598 }
1599 
// Return true if a branch with opcode BranchOp can reach a destination that
// is BrOffset bytes away from the branch instruction itself.
bool RISCVInstrInfo::isBranchOffsetInRange(unsigned BranchOp,
                                           int64_t BrOffset) const {
  unsigned XLen = STI.getXLen();
  // Ideally we could determine the supported branch offset from the
  // RISCVII::FormMask, but this can't be used for Pseudo instructions like
  // PseudoBR.
  switch (BranchOp) {
  default:
    llvm_unreachable("Unexpected opcode!");
  // Andes conditional branches: signed 11-bit byte offset.
  case RISCV::NDS_BBC:
  case RISCV::NDS_BBS:
  case RISCV::NDS_BEQC:
  case RISCV::NDS_BNEC:
    return isInt<11>(BrOffset);
  // Standard B-type branches and the CORE-V / Qualcomm immediate-compare
  // branches: signed 13-bit byte offset.
  case RISCV::BEQ:
  case RISCV::BNE:
  case RISCV::BLT:
  case RISCV::BGE:
  case RISCV::BLTU:
  case RISCV::BGEU:
  case RISCV::CV_BEQIMM:
  case RISCV::CV_BNEIMM:
  case RISCV::QC_BEQI:
  case RISCV::QC_BNEI:
  case RISCV::QC_BGEI:
  case RISCV::QC_BLTI:
  case RISCV::QC_BLTUI:
  case RISCV::QC_BGEUI:
  case RISCV::QC_E_BEQI:
  case RISCV::QC_E_BNEI:
  case RISCV::QC_E_BGEI:
  case RISCV::QC_E_BLTI:
  case RISCV::QC_E_BLTUI:
  case RISCV::QC_E_BGEUI:
    return isInt<13>(BrOffset);
  // JAL (and its pseudo): signed 21-bit byte offset.
  case RISCV::JAL:
  case RISCV::PseudoBR:
    return isInt<21>(BrOffset);
  // PseudoJump expands to AUIPC+JALR; account for the +0x800 rounding in the
  // %pcrel_hi relocation and sign-extend to XLen before the 32-bit check.
  case RISCV::PseudoJump:
    return isInt<32>(SignExtend64(BrOffset + 0x800, XLen));
  }
}
1642 
1643 // If the operation has a predicated pseudo instruction, return the pseudo
1644 // instruction opcode. Otherwise, return RISCV::INSTRUCTION_LIST_END.
1645 // TODO: Support more operations.
getPredicatedOpcode(unsigned Opcode)1646 unsigned getPredicatedOpcode(unsigned Opcode) {
1647   switch (Opcode) {
1648   case RISCV::ADD:   return RISCV::PseudoCCADD;   break;
1649   case RISCV::SUB:   return RISCV::PseudoCCSUB;   break;
1650   case RISCV::SLL:   return RISCV::PseudoCCSLL;   break;
1651   case RISCV::SRL:   return RISCV::PseudoCCSRL;   break;
1652   case RISCV::SRA:   return RISCV::PseudoCCSRA;   break;
1653   case RISCV::AND:   return RISCV::PseudoCCAND;   break;
1654   case RISCV::OR:    return RISCV::PseudoCCOR;    break;
1655   case RISCV::XOR:   return RISCV::PseudoCCXOR;   break;
1656 
1657   case RISCV::ADDI:  return RISCV::PseudoCCADDI;  break;
1658   case RISCV::SLLI:  return RISCV::PseudoCCSLLI;  break;
1659   case RISCV::SRLI:  return RISCV::PseudoCCSRLI;  break;
1660   case RISCV::SRAI:  return RISCV::PseudoCCSRAI;  break;
1661   case RISCV::ANDI:  return RISCV::PseudoCCANDI;  break;
1662   case RISCV::ORI:   return RISCV::PseudoCCORI;   break;
1663   case RISCV::XORI:  return RISCV::PseudoCCXORI;  break;
1664 
1665   case RISCV::ADDW:  return RISCV::PseudoCCADDW;  break;
1666   case RISCV::SUBW:  return RISCV::PseudoCCSUBW;  break;
1667   case RISCV::SLLW:  return RISCV::PseudoCCSLLW;  break;
1668   case RISCV::SRLW:  return RISCV::PseudoCCSRLW;  break;
1669   case RISCV::SRAW:  return RISCV::PseudoCCSRAW;  break;
1670 
1671   case RISCV::ADDIW: return RISCV::PseudoCCADDIW; break;
1672   case RISCV::SLLIW: return RISCV::PseudoCCSLLIW; break;
1673   case RISCV::SRLIW: return RISCV::PseudoCCSRLIW; break;
1674   case RISCV::SRAIW: return RISCV::PseudoCCSRAIW; break;
1675 
1676   case RISCV::ANDN:  return RISCV::PseudoCCANDN;  break;
1677   case RISCV::ORN:   return RISCV::PseudoCCORN;   break;
1678   case RISCV::XNOR:  return RISCV::PseudoCCXNOR;  break;
1679 
1680   case RISCV::NDS_BFOS:  return RISCV::PseudoCCNDS_BFOS;  break;
1681   case RISCV::NDS_BFOZ:  return RISCV::PseudoCCNDS_BFOZ;  break;
1682   }
1683 
1684   return RISCV::INSTRUCTION_LIST_END;
1685 }
1686 
/// Identify instructions that can be folded into a CCMOV instruction, and
/// return the defining instruction.
///
/// \param Reg virtual register whose single (non-debug-use) definition is the
///        fold candidate.
/// \param MRI used to look up the def and count uses.
/// \param TII unused by the body; kept for signature symmetry with callers.
/// \returns the defining MachineInstr if every legality check passes,
///          otherwise nullptr.
static MachineInstr *canFoldAsPredicatedOp(Register Reg,
                                           const MachineRegisterInfo &MRI,
                                           const TargetInstrInfo *TII) {
  // Only virtual registers with exactly one non-debug use are candidates;
  // folding would otherwise leave other users without the def.
  if (!Reg.isVirtual())
    return nullptr;
  if (!MRI.hasOneNonDBGUse(Reg))
    return nullptr;
  MachineInstr *MI = MRI.getVRegDef(Reg);
  if (!MI)
    return nullptr;
  // Check if MI can be predicated and folded into the CCMOV.
  if (getPredicatedOpcode(MI->getOpcode()) == RISCV::INSTRUCTION_LIST_END)
    return nullptr;
  // Don't predicate li idiom.
  if (MI->getOpcode() == RISCV::ADDI && MI->getOperand(1).isReg() &&
      MI->getOperand(1).getReg() == RISCV::X0)
    return nullptr;
  // Check if MI has any other defs or physreg uses.
  // (Operand 0 is the def being folded, so scan starts at operand 1.)
  for (const MachineOperand &MO : llvm::drop_begin(MI->operands())) {
    // Reject frame index operands, PEI can't handle the predicated pseudos.
    if (MO.isFI() || MO.isCPI() || MO.isJTI())
      return nullptr;
    if (!MO.isReg())
      continue;
    // MI can't have any tied operands, that would conflict with predication.
    if (MO.isTied())
      return nullptr;
    if (MO.isDef())
      return nullptr;
    // Allow constant physregs.
    if (MO.getReg().isPhysical() && !MRI.isConstantPhysReg(MO.getReg()))
      return nullptr;
  }
  // Folding moves MI to the CCMOV's position; it must be safe to move there
  // (conservatively: not across any store).
  bool DontMoveAcrossStores = true;
  if (!MI->isSafeToMove(DontMoveAcrossStores))
    return nullptr;
  return MI;
}
1727 
analyzeSelect(const MachineInstr & MI,SmallVectorImpl<MachineOperand> & Cond,unsigned & TrueOp,unsigned & FalseOp,bool & Optimizable) const1728 bool RISCVInstrInfo::analyzeSelect(const MachineInstr &MI,
1729                                    SmallVectorImpl<MachineOperand> &Cond,
1730                                    unsigned &TrueOp, unsigned &FalseOp,
1731                                    bool &Optimizable) const {
1732   assert(MI.getOpcode() == RISCV::PseudoCCMOVGPR &&
1733          "Unknown select instruction");
1734   // CCMOV operands:
1735   // 0: Def.
1736   // 1: LHS of compare.
1737   // 2: RHS of compare.
1738   // 3: Condition code.
1739   // 4: False use.
1740   // 5: True use.
1741   TrueOp = 5;
1742   FalseOp = 4;
1743   Cond.push_back(MI.getOperand(1));
1744   Cond.push_back(MI.getOperand(2));
1745   Cond.push_back(MI.getOperand(3));
1746   // We can only fold when we support short forward branch opt.
1747   Optimizable = STI.hasShortForwardBranchOpt();
1748   return false;
1749 }
1750 
// Try to fold the single-use instruction defining the true (or, inverted,
// the false) operand of a PseudoCCMOVGPR into a predicated pseudo, replacing
// both the select and the defining instruction with one conditional op.
// Returns the new instruction, or nullptr if folding is not possible.
// The caller erases MI; this function erases DefMI itself.
MachineInstr *
RISCVInstrInfo::optimizeSelect(MachineInstr &MI,
                               SmallPtrSetImpl<MachineInstr *> &SeenMIs,
                               bool PreferFalse) const {
  assert(MI.getOpcode() == RISCV::PseudoCCMOVGPR &&
         "Unknown select instruction");
  if (!STI.hasShortForwardBranchOpt())
    return nullptr;

  MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();
  // Prefer folding the true-use (operand 5); fall back to the false-use
  // (operand 4), which requires inverting the condition.
  MachineInstr *DefMI =
      canFoldAsPredicatedOp(MI.getOperand(5).getReg(), MRI, this);
  bool Invert = !DefMI;
  if (!DefMI)
    DefMI = canFoldAsPredicatedOp(MI.getOperand(4).getReg(), MRI, this);
  if (!DefMI)
    return nullptr;

  // Find new register class to use.
  MachineOperand FalseReg = MI.getOperand(Invert ? 5 : 4);
  Register DestReg = MI.getOperand(0).getReg();
  const TargetRegisterClass *PreviousClass = MRI.getRegClass(FalseReg.getReg());
  if (!MRI.constrainRegClass(DestReg, PreviousClass))
    return nullptr;

  unsigned PredOpc = getPredicatedOpcode(DefMI->getOpcode());
  assert(PredOpc != RISCV::INSTRUCTION_LIST_END && "Unexpected opcode!");

  // Create a new predicated version of DefMI.
  MachineInstrBuilder NewMI =
      BuildMI(*MI.getParent(), MI, MI.getDebugLoc(), get(PredOpc), DestReg);

  // Copy the condition portion (compare LHS and RHS).
  NewMI.add(MI.getOperand(1));
  NewMI.add(MI.getOperand(2));

  // Add condition code, inverting if necessary.
  auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
  if (Invert)
    CC = RISCVCC::getOppositeBranchCondition(CC);
  NewMI.addImm(CC);

  // Copy the false register.
  NewMI.add(FalseReg);

  // Copy all the DefMI operands (skipping its def at index 0).
  const MCInstrDesc &DefDesc = DefMI->getDesc();
  for (unsigned i = 1, e = DefDesc.getNumOperands(); i != e; ++i)
    NewMI.add(DefMI->getOperand(i));

  // Update SeenMIs set: register newly created MI and erase removed DefMI.
  SeenMIs.insert(NewMI);
  SeenMIs.erase(DefMI);

  // If MI is inside a loop, and DefMI is outside the loop, then kill flags on
  // DefMI would be invalid when transferred inside the loop.  Checking for a
  // loop is expensive, but at least remove kill flags if they are in different
  // BBs.
  if (DefMI->getParent() != MI.getParent())
    NewMI->clearKillInfo();

  // The caller will erase MI, but not DefMI.
  DefMI->eraseFromParent();
  return NewMI;
}
1816 
// Return the size of MI in bytes as it will be emitted, accounting for
// compressed (16-bit) encodings, inline asm, bundles, non-temporal-hint
// prefixes, and patchable/stackmap pseudos.
unsigned RISCVInstrInfo::getInstSizeInBytes(const MachineInstr &MI) const {
  // Meta instructions (debug info, CFI, etc.) emit no bytes.
  if (MI.isMetaInstruction())
    return 0;

  unsigned Opcode = MI.getOpcode();

  // Inline asm size is estimated from the asm string by MCAsmInfo rules.
  if (Opcode == TargetOpcode::INLINEASM ||
      Opcode == TargetOpcode::INLINEASM_BR) {
    const MachineFunction &MF = *MI.getParent()->getParent();
    return getInlineAsmLength(MI.getOperand(0).getSymbolName(),
                              *MF.getTarget().getMCAsmInfo());
  }

  // A non-temporal memory access gets an ntl.all (or c.ntl.all) hint
  // instruction emitted immediately before it; include its size.
  if (!MI.memoperands_empty()) {
    MachineMemOperand *MMO = *(MI.memoperands_begin());
    if (STI.hasStdExtZihintntl() && MMO->isNonTemporal()) {
      if (STI.hasStdExtZca()) {
        if (isCompressibleInst(MI, STI))
          return 4; // c.ntl.all + c.load/c.store
        return 6;   // c.ntl.all + load/store
      }
      return 8; // ntl.all + load/store
    }
  }

  if (Opcode == TargetOpcode::BUNDLE)
    return getInstBundleLength(MI);

  // isCompressibleInst needs a parent function, so guard on it being present.
  if (MI.getParent() && MI.getParent()->getParent()) {
    if (isCompressibleInst(MI, STI))
      return 2;
  }

  switch (Opcode) {
  case RISCV::PseudoMV_FPR16INX:
  case RISCV::PseudoMV_FPR32INX:
    // MV is always compressible to either c.mv or c.li rd, 0.
    return STI.hasStdExtZca() ? 2 : 4;
  case TargetOpcode::STACKMAP:
    // The upper bound for a stackmap intrinsic is the full length of its shadow
    return StackMapOpers(&MI).getNumPatchBytes();
  case TargetOpcode::PATCHPOINT:
    // The size of the patchpoint intrinsic is the number of bytes requested
    return PatchPointOpers(&MI).getNumPatchBytes();
  case TargetOpcode::STATEPOINT: {
    // The size of the statepoint intrinsic is the number of bytes requested
    unsigned NumBytes = StatepointOpers(&MI).getNumPatchBytes();
    // No patch bytes means at most a PseudoCall is emitted
    return std::max(NumBytes, 8U);
  }
  case TargetOpcode::PATCHABLE_FUNCTION_ENTER:
  case TargetOpcode::PATCHABLE_FUNCTION_EXIT:
  case TargetOpcode::PATCHABLE_TAIL_CALL: {
    const MachineFunction &MF = *MI.getParent()->getParent();
    const Function &F = MF.getFunction();
    if (Opcode == TargetOpcode::PATCHABLE_FUNCTION_ENTER &&
        F.hasFnAttribute("patchable-function-entry")) {
      unsigned Num;
      // If the attribute value fails to parse, fall back to the generic size.
      if (F.getFnAttribute("patchable-function-entry")
              .getValueAsString()
              .getAsInteger(10, Num))
        return get(Opcode).getSize();

      // Number of C.NOP or NOP
      return (STI.hasStdExtZca() ? 2 : 4) * Num;
    }
    // XRay uses C.JAL + 21 or 33 C.NOP for each sled in RV32 and RV64,
    // respectively.
    return STI.is64Bit() ? 68 : 44;
  }
  default:
    return get(Opcode).getSize();
  }
}
1891 
getInstBundleLength(const MachineInstr & MI) const1892 unsigned RISCVInstrInfo::getInstBundleLength(const MachineInstr &MI) const {
1893   unsigned Size = 0;
1894   MachineBasicBlock::const_instr_iterator I = MI.getIterator();
1895   MachineBasicBlock::const_instr_iterator E = MI.getParent()->instr_end();
1896   while (++I != E && I->isInsideBundle()) {
1897     assert(!I->isBundle() && "No nested bundle!");
1898     Size += getInstSizeInBytes(*I);
1899   }
1900   return Size;
1901 }
1902 
isAsCheapAsAMove(const MachineInstr & MI) const1903 bool RISCVInstrInfo::isAsCheapAsAMove(const MachineInstr &MI) const {
1904   const unsigned Opcode = MI.getOpcode();
1905   switch (Opcode) {
1906   default:
1907     break;
1908   case RISCV::FSGNJ_D:
1909   case RISCV::FSGNJ_S:
1910   case RISCV::FSGNJ_H:
1911   case RISCV::FSGNJ_D_INX:
1912   case RISCV::FSGNJ_D_IN32X:
1913   case RISCV::FSGNJ_S_INX:
1914   case RISCV::FSGNJ_H_INX:
1915     // The canonical floating-point move is fsgnj rd, rs, rs.
1916     return MI.getOperand(1).isReg() && MI.getOperand(2).isReg() &&
1917            MI.getOperand(1).getReg() == MI.getOperand(2).getReg();
1918   case RISCV::ADDI:
1919   case RISCV::ORI:
1920   case RISCV::XORI:
1921     return (MI.getOperand(1).isReg() &&
1922             MI.getOperand(1).getReg() == RISCV::X0) ||
1923            (MI.getOperand(2).isImm() && MI.getOperand(2).getImm() == 0);
1924   }
1925   return MI.isAsCheapAsAMove();
1926 }
1927 
// Recognize instructions that are effectively register copies and return the
// (destination, source) operand pair, or std::nullopt if MI is not a copy.
std::optional<DestSourcePair>
RISCVInstrInfo::isCopyInstrImpl(const MachineInstr &MI) const {
  if (MI.isMoveReg())
    return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
  switch (MI.getOpcode()) {
  default:
    break;
  case RISCV::ADD:
  case RISCV::OR:
  case RISCV::XOR:
    // x0 is the identity for add/or/xor, so either source may be x0.
    if (MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0 &&
        MI.getOperand(2).isReg())
      return DestSourcePair{MI.getOperand(0), MI.getOperand(2)};
    if (MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0 &&
        MI.getOperand(1).isReg())
      return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
    break;
  case RISCV::ADDI:
    // addi rd, rs, 0 is the canonical register move.
    // Operand 1 can be a frameindex but callers expect registers
    if (MI.getOperand(1).isReg() && MI.getOperand(2).isImm() &&
        MI.getOperand(2).getImm() == 0)
      return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
    break;
  case RISCV::SUB:
    // Subtracting x0 leaves the first source unchanged.
    if (MI.getOperand(2).isReg() && MI.getOperand(2).getReg() == RISCV::X0 &&
        MI.getOperand(1).isReg())
      return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
    break;
  case RISCV::SH1ADD:
  case RISCV::SH1ADD_UW:
  case RISCV::SH2ADD:
  case RISCV::SH2ADD_UW:
  case RISCV::SH3ADD:
  case RISCV::SH3ADD_UW:
    // shNadd rd, x0, rs2 shifts x0 (still zero) and adds rs2 — a copy.
    if (MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0 &&
        MI.getOperand(2).isReg())
      return DestSourcePair{MI.getOperand(0), MI.getOperand(2)};
    break;
  case RISCV::FSGNJ_D:
  case RISCV::FSGNJ_S:
  case RISCV::FSGNJ_H:
  case RISCV::FSGNJ_D_INX:
  case RISCV::FSGNJ_D_IN32X:
  case RISCV::FSGNJ_S_INX:
  case RISCV::FSGNJ_H_INX:
    // The canonical floating-point move is fsgnj rd, rs, rs.
    if (MI.getOperand(1).isReg() && MI.getOperand(2).isReg() &&
        MI.getOperand(1).getReg() == MI.getOperand(2).getReg())
      return DestSourcePair{MI.getOperand(0), MI.getOperand(1)};
    break;
  }
  return std::nullopt;
}
1981 
getMachineCombinerTraceStrategy() const1982 MachineTraceStrategy RISCVInstrInfo::getMachineCombinerTraceStrategy() const {
1983   if (ForceMachineCombinerStrategy.getNumOccurrences() == 0) {
1984     // The option is unused. Choose Local strategy only for in-order cores. When
1985     // scheduling model is unspecified, use MinInstrCount strategy as more
1986     // generic one.
1987     const auto &SchedModel = STI.getSchedModel();
1988     return (!SchedModel.hasInstrSchedModel() || SchedModel.isOutOfOrder())
1989                ? MachineTraceStrategy::TS_MinInstrCount
1990                : MachineTraceStrategy::TS_Local;
1991   }
1992   // The strategy was forced by the option.
1993   return ForceMachineCombinerStrategy;
1994 }
1995 
// After the machine combiner builds replacement instructions, propagate the
// floating-point rounding-mode (frm) operand from Root onto each new
// instruction that expects one but has not received it yet.
void RISCVInstrInfo::finalizeInsInstrs(
    MachineInstr &Root, unsigned &Pattern,
    SmallVectorImpl<MachineInstr *> &InsInstrs) const {
  int16_t FrmOpIdx =
      RISCV::getNamedOperandIdx(Root.getOpcode(), RISCV::OpName::frm);
  if (FrmOpIdx < 0) {
    // Root has no frm operand; none of the new instructions may need one
    // either, since there would be nothing to copy from.
    assert(all_of(InsInstrs,
                  [](MachineInstr *MI) {
                    return RISCV::getNamedOperandIdx(MI->getOpcode(),
                                                     RISCV::OpName::frm) < 0;
                  }) &&
           "New instructions require FRM whereas the old one does not have it");
    return;
  }

  const MachineOperand &FRM = Root.getOperand(FrmOpIdx);
  MachineFunction &MF = *Root.getMF();

  for (auto *NewMI : InsInstrs) {
    // We'd already added the FRM operand.
    // (If the frm slot index is not exactly the next operand position, the
    // operand is already present, so skip this instruction.)
    if (static_cast<unsigned>(RISCV::getNamedOperandIdx(
            NewMI->getOpcode(), RISCV::OpName::frm)) != NewMI->getNumOperands())
      continue;
    MachineInstrBuilder MIB(MF, NewMI);
    MIB.add(FRM);
    // Dynamic rounding reads the FRM control register; record that use.
    if (FRM.getImm() == RISCVFPRndMode::DYN)
      MIB.addUse(RISCV::FRM, RegState::Implicit);
  }
}
2025 
isFADD(unsigned Opc)2026 static bool isFADD(unsigned Opc) {
2027   switch (Opc) {
2028   default:
2029     return false;
2030   case RISCV::FADD_H:
2031   case RISCV::FADD_S:
2032   case RISCV::FADD_D:
2033     return true;
2034   }
2035 }
2036 
isFSUB(unsigned Opc)2037 static bool isFSUB(unsigned Opc) {
2038   switch (Opc) {
2039   default:
2040     return false;
2041   case RISCV::FSUB_H:
2042   case RISCV::FSUB_S:
2043   case RISCV::FSUB_D:
2044     return true;
2045   }
2046 }
2047 
isFMUL(unsigned Opc)2048 static bool isFMUL(unsigned Opc) {
2049   switch (Opc) {
2050   default:
2051     return false;
2052   case RISCV::FMUL_H:
2053   case RISCV::FMUL_S:
2054   case RISCV::FMUL_D:
2055     return true;
2056   }
2057 }
2058 
// Return true if Inst is an RVV pseudo (any LMUL, masked or unmasked) whose
// operation is associative and commutative; with Invert set, test the
// inverse opcode instead. Currently only VADD_VV and VMUL_VV qualify.
bool RISCVInstrInfo::isVectorAssociativeAndCommutative(const MachineInstr &Inst,
                                                       bool Invert) const {
// Expands one pseudo opcode into its seven LMUL variants.
#define OPCODE_LMUL_CASE(OPC)                                                  \
  case RISCV::OPC##_M1:                                                        \
  case RISCV::OPC##_M2:                                                        \
  case RISCV::OPC##_M4:                                                        \
  case RISCV::OPC##_M8:                                                        \
  case RISCV::OPC##_MF2:                                                       \
  case RISCV::OPC##_MF4:                                                       \
  case RISCV::OPC##_MF8

// Same, for the masked variants.
#define OPCODE_LMUL_MASK_CASE(OPC)                                             \
  case RISCV::OPC##_M1_MASK:                                                   \
  case RISCV::OPC##_M2_MASK:                                                   \
  case RISCV::OPC##_M4_MASK:                                                   \
  case RISCV::OPC##_M8_MASK:                                                   \
  case RISCV::OPC##_MF2_MASK:                                                  \
  case RISCV::OPC##_MF4_MASK:                                                  \
  case RISCV::OPC##_MF8_MASK

  unsigned Opcode = Inst.getOpcode();
  if (Invert) {
    // Test the inverse operation; bail out if there is none.
    if (auto InvOpcode = getInverseOpcode(Opcode))
      Opcode = *InvOpcode;
    else
      return false;
  }

  // clang-format off
  switch (Opcode) {
  default:
    return false;
  OPCODE_LMUL_CASE(PseudoVADD_VV):
  OPCODE_LMUL_MASK_CASE(PseudoVADD_VV):
  OPCODE_LMUL_CASE(PseudoVMUL_VV):
  OPCODE_LMUL_MASK_CASE(PseudoVMUL_VV):
    return true;
  }
  // clang-format on

#undef OPCODE_LMUL_MASK_CASE
#undef OPCODE_LMUL_CASE
}
2102 
// Return true if two RVV pseudos Root and Prev can be reassociated with each
// other: same (or inverse) opcode, and identical passthru, SEW, V0 mask
// value, tail/mask policy, VL, and rounding-mode operands.
// Prev is expected to appear before Root in the same basic block.
bool RISCVInstrInfo::areRVVInstsReassociable(const MachineInstr &Root,
                                             const MachineInstr &Prev) const {
  if (!areOpcodesEqualOrInverse(Root.getOpcode(), Prev.getOpcode()))
    return false;

  assert(Root.getMF() == Prev.getMF());
  const MachineRegisterInfo *MRI = &Root.getMF()->getRegInfo();
  const TargetRegisterInfo *TRI = MRI->getTargetRegisterInfo();

  // Make sure vtype operands are also the same.
  const MCInstrDesc &Desc = get(Root.getOpcode());
  const uint64_t TSFlags = Desc.TSFlags;

  // True if Root and Prev carry the same immediate at operand OpIdx.
  auto checkImmOperand = [&](unsigned OpIdx) {
    return Root.getOperand(OpIdx).getImm() == Prev.getOperand(OpIdx).getImm();
  };

  // True if Root and Prev carry the same register at operand OpIdx.
  auto checkRegOperand = [&](unsigned OpIdx) {
    return Root.getOperand(OpIdx).getReg() == Prev.getOperand(OpIdx).getReg();
  };

  // PassThru
  // TODO: Potentially we can loosen the condition to consider Root to be
  // associable with Prev if Root has NoReg as passthru. In which case we
  // also need to loosen the condition on vector policies between these.
  if (!checkRegOperand(1))
    return false;

  // SEW
  if (RISCVII::hasSEWOp(TSFlags) &&
      !checkImmOperand(RISCVII::getSEWOpNum(Desc)))
    return false;

  // Mask: walk backwards from Root looking for the V0 defs feeding Root and
  // Prev; both must ultimately come from the same virtual register.
  if (RISCVII::usesMaskPolicy(TSFlags)) {
    const MachineBasicBlock *MBB = Root.getParent();
    const MachineBasicBlock::const_reverse_iterator It1(&Root);
    const MachineBasicBlock::const_reverse_iterator It2(&Prev);
    Register MI1VReg;

    bool SeenMI2 = false;
    for (auto End = MBB->rend(), It = It1; It != End; ++It) {
      if (It == It2) {
        SeenMI2 = true;
        if (!MI1VReg.isValid())
          // There is no V0 def between Root and Prev; they're sharing the
          // same V0.
          break;
      }

      if (It->modifiesRegister(RISCV::V0, TRI)) {
        Register SrcReg = It->getOperand(1).getReg();
        // If it's not VReg it'll be more difficult to track its defs, so
        // bailing out here just to be safe.
        if (!SrcReg.isVirtual())
          return false;

        if (!MI1VReg.isValid()) {
          // This is the V0 def for Root.
          MI1VReg = SrcReg;
          continue;
        }

        // Some random mask updates.
        if (!SeenMI2)
          continue;

        // This is the V0 def for Prev; check if it's the same as that of
        // Root.
        if (MI1VReg != SrcReg)
          return false;
        else
          break;
      }
    }

    // If we haven't encountered Prev, it's likely that this function was
    // called in a wrong way (e.g. Root is before Prev).
    assert(SeenMI2 && "Prev is expected to appear before Root");
  }

  // Tail / Mask policies
  if (RISCVII::hasVecPolicyOp(TSFlags) &&
      !checkImmOperand(RISCVII::getVecPolicyOpNum(Desc)))
    return false;

  // VL: may be a register (e.g. VLMAX) or an immediate; both the kind and
  // the value must match.
  if (RISCVII::hasVLOp(TSFlags)) {
    unsigned OpIdx = RISCVII::getVLOpNum(Desc);
    const MachineOperand &Op1 = Root.getOperand(OpIdx);
    const MachineOperand &Op2 = Prev.getOperand(OpIdx);
    if (Op1.getType() != Op2.getType())
      return false;
    switch (Op1.getType()) {
    case MachineOperand::MO_Register:
      if (Op1.getReg() != Op2.getReg())
        return false;
      break;
    case MachineOperand::MO_Immediate:
      if (Op1.getImm() != Op2.getImm())
        return false;
      break;
    default:
      llvm_unreachable("Unrecognized VL operand type");
    }
  }

  // Rounding modes (the operand immediately before VL when present).
  if (RISCVII::hasRoundModeOp(TSFlags) &&
      !checkImmOperand(RISCVII::getVLOpNum(Desc) - 1))
    return false;

  return true;
}
2217 
// Most of our RVV pseudos have passthru operand, so the real operands
// start from index = 2.
// Find whether one of Inst's two source operands is defined by a sibling
// RVV instruction that can be reassociated with Inst. Sets Commuted when
// only the second source's def qualifies (operands must then be swapped).
bool RISCVInstrInfo::hasReassociableVectorSibling(const MachineInstr &Inst,
                                                  bool &Commuted) const {
  const MachineBasicBlock *MBB = Inst.getParent();
  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();
  assert(RISCVII::isFirstDefTiedToFirstUse(get(Inst.getOpcode())) &&
         "Expect the present of passthrough operand.");
  // Defs of the two real source operands (indices 2 and 3; 1 is passthru).
  MachineInstr *MI1 = MRI.getUniqueVRegDef(Inst.getOperand(2).getReg());
  MachineInstr *MI2 = MRI.getUniqueVRegDef(Inst.getOperand(3).getReg());

  // If only one operand has the same or inverse opcode and it's the second
  // source operand, the operands must be commuted.
  Commuted = !areRVVInstsReassociable(Inst, *MI1) &&
             areRVVInstsReassociable(Inst, *MI2);
  if (Commuted)
    std::swap(MI1, MI2);

  // The chosen sibling must itself be (inverse-)associative/commutative,
  // have reassociable operands, and feed only Inst.
  return areRVVInstsReassociable(Inst, *MI1) &&
         (isVectorAssociativeAndCommutative(*MI1) ||
          isVectorAssociativeAndCommutative(*MI1, /* Invert */ true)) &&
         hasReassociableOperands(*MI1, MBB) &&
         MRI.hasOneNonDBGUse(MI1->getOperand(0).getReg());
}
2242 
// Vector-aware override of TargetInstrInfo::hasReassociableOperands: for RVV
// pseudos the real sources live at operand indices 2 and 3 (1 is passthru);
// scalar instructions defer to the base-class implementation.
bool RISCVInstrInfo::hasReassociableOperands(
    const MachineInstr &Inst, const MachineBasicBlock *MBB) const {
  if (!isVectorAssociativeAndCommutative(Inst) &&
      !isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
    return TargetInstrInfo::hasReassociableOperands(Inst, MBB);

  const MachineOperand &Op1 = Inst.getOperand(2);
  const MachineOperand &Op2 = Inst.getOperand(3);
  const MachineRegisterInfo &MRI = MBB->getParent()->getRegInfo();

  // We need virtual register definitions for the operands that we will
  // reassociate.
  MachineInstr *MI1 = nullptr;
  MachineInstr *MI2 = nullptr;
  if (Op1.isReg() && Op1.getReg().isVirtual())
    MI1 = MRI.getUniqueVRegDef(Op1.getReg());
  if (Op2.isReg() && Op2.getReg().isVirtual())
    MI2 = MRI.getUniqueVRegDef(Op2.getReg());

  // And at least one operand must be defined in MBB.
  return MI1 && MI2 && (MI1->getParent() == MBB || MI2->getParent() == MBB);
}
2265 
getReassociateOperandIndices(const MachineInstr & Root,unsigned Pattern,std::array<unsigned,5> & OperandIndices) const2266 void RISCVInstrInfo::getReassociateOperandIndices(
2267     const MachineInstr &Root, unsigned Pattern,
2268     std::array<unsigned, 5> &OperandIndices) const {
2269   TargetInstrInfo::getReassociateOperandIndices(Root, Pattern, OperandIndices);
2270   if (RISCV::getRVVMCOpcode(Root.getOpcode())) {
2271     // Skip the passthrough operand, so increment all indices by one.
2272     for (unsigned I = 0; I < 5; ++I)
2273       ++OperandIndices[I];
2274   }
2275 }
2276 
// Decide whether Inst has a sibling it can be reassociated with. RVV pseudos
// go through the vector-specific check; scalar instructions use the base
// implementation plus a check that the floating-point rounding modes match.
bool RISCVInstrInfo::hasReassociableSibling(const MachineInstr &Inst,
                                            bool &Commuted) const {
  if (isVectorAssociativeAndCommutative(Inst) ||
      isVectorAssociativeAndCommutative(Inst, /*Invert=*/true))
    return hasReassociableVectorSibling(Inst, Commuted);

  if (!TargetInstrInfo::hasReassociableSibling(Inst, Commuted))
    return false;

  const MachineRegisterInfo &MRI = Inst.getMF()->getRegInfo();
  // The base class decided which source operand holds the sibling; pick the
  // matching operand index (2 when commuted, else 1).
  unsigned OperandIdx = Commuted ? 2 : 1;
  const MachineInstr &Sibling =
      *MRI.getVRegDef(Inst.getOperand(OperandIdx).getReg());

  int16_t InstFrmOpIdx =
      RISCV::getNamedOperandIdx(Inst.getOpcode(), RISCV::OpName::frm);
  int16_t SiblingFrmOpIdx =
      RISCV::getNamedOperandIdx(Sibling.getOpcode(), RISCV::OpName::frm);

  // OK when neither instruction has an frm operand, or both agree on it.
  return (InstFrmOpIdx < 0 && SiblingFrmOpIdx < 0) ||
         RISCV::hasEqualFRM(Inst, Sibling);
}
2299 
// Return true if Inst (or, with Invert set, its inverse opcode) performs an
// associative and commutative operation, making it a candidate for
// machine-combiner reassociation.
bool RISCVInstrInfo::isAssociativeAndCommutative(const MachineInstr &Inst,
                                                 bool Invert) const {
  if (isVectorAssociativeAndCommutative(Inst, Invert))
    return true;

  unsigned Opc = Inst.getOpcode();
  if (Invert) {
    auto InverseOpcode = getInverseOpcode(Opc);
    if (!InverseOpcode)
      return false;
    Opc = *InverseOpcode;
  }

  // FP add/mul are only reassociable when fast-math flags permit it.
  if (isFADD(Opc) || isFMUL(Opc))
    return Inst.getFlag(MachineInstr::MIFlag::FmReassoc) &&
           Inst.getFlag(MachineInstr::MIFlag::FmNsz);

  switch (Opc) {
  default:
    return false;
  case RISCV::ADD:
  case RISCV::ADDW:
  case RISCV::AND:
  case RISCV::OR:
  case RISCV::XOR:
  // From RISC-V ISA spec, if both the high and low bits of the same product
  // are required, then the recommended code sequence is:
  //
  // MULH[[S]U] rdh, rs1, rs2
  // MUL        rdl, rs1, rs2
  // (source register specifiers must be in same order and rdh cannot be the
  //  same as rs1 or rs2)
  //
  // Microarchitectures can then fuse these into a single multiply operation
  // instead of performing two separate multiplies.
  // MachineCombiner may reassociate MUL operands and lose the fusion
  // opportunity.
  case RISCV::MUL:
  case RISCV::MULW:
  case RISCV::MIN:
  case RISCV::MINU:
  case RISCV::MAX:
  case RISCV::MAXU:
  // FP min/max are associative/commutative regardless of fast-math flags.
  case RISCV::FMIN_H:
  case RISCV::FMIN_S:
  case RISCV::FMIN_D:
  case RISCV::FMAX_H:
  case RISCV::FMAX_S:
  case RISCV::FMAX_D:
    return true;
  }

  return false;
}
2354 
// Return the opcode that computes the inverse operation of \p Opcode
// (add <-> sub, for scalar FP, scalar integer, and RVV vector pseudos), or
// std::nullopt when no inverse exists. Used by isAssociativeAndCommutative
// when the MachineCombiner asks about the inverted form of an instruction.
std::optional<unsigned>
RISCVInstrInfo::getInverseOpcode(unsigned Opcode) const {
// Expand to the cases for all seven LMUL variants (M1..M8, MF2..MF8) of an
// unmasked RVV pseudo, each returning the matching variant of its inverse.
#define RVV_OPC_LMUL_CASE(OPC, INV)                                            \
  case RISCV::OPC##_M1:                                                        \
    return RISCV::INV##_M1;                                                    \
  case RISCV::OPC##_M2:                                                        \
    return RISCV::INV##_M2;                                                    \
  case RISCV::OPC##_M4:                                                        \
    return RISCV::INV##_M4;                                                    \
  case RISCV::OPC##_M8:                                                        \
    return RISCV::INV##_M8;                                                    \
  case RISCV::OPC##_MF2:                                                       \
    return RISCV::INV##_MF2;                                                   \
  case RISCV::OPC##_MF4:                                                       \
    return RISCV::INV##_MF4;                                                   \
  case RISCV::OPC##_MF8:                                                       \
    return RISCV::INV##_MF8

// Same as above, but for the masked (_MASK) variants of the pseudos.
#define RVV_OPC_LMUL_MASK_CASE(OPC, INV)                                       \
  case RISCV::OPC##_M1_MASK:                                                   \
    return RISCV::INV##_M1_MASK;                                               \
  case RISCV::OPC##_M2_MASK:                                                   \
    return RISCV::INV##_M2_MASK;                                               \
  case RISCV::OPC##_M4_MASK:                                                   \
    return RISCV::INV##_M4_MASK;                                               \
  case RISCV::OPC##_M8_MASK:                                                   \
    return RISCV::INV##_M8_MASK;                                               \
  case RISCV::OPC##_MF2_MASK:                                                  \
    return RISCV::INV##_MF2_MASK;                                              \
  case RISCV::OPC##_MF4_MASK:                                                  \
    return RISCV::INV##_MF4_MASK;                                              \
  case RISCV::OPC##_MF8_MASK:                                                  \
    return RISCV::INV##_MF8_MASK

  switch (Opcode) {
  default:
    return std::nullopt;
  // Scalar FP add/sub are inverses of each other, per format (H/S/D).
  case RISCV::FADD_H:
    return RISCV::FSUB_H;
  case RISCV::FADD_S:
    return RISCV::FSUB_S;
  case RISCV::FADD_D:
    return RISCV::FSUB_D;
  case RISCV::FSUB_H:
    return RISCV::FADD_H;
  case RISCV::FSUB_S:
    return RISCV::FADD_S;
  case RISCV::FSUB_D:
    return RISCV::FADD_D;
  // Scalar integer add/sub (and their W forms) are inverses of each other.
  case RISCV::ADD:
    return RISCV::SUB;
  case RISCV::SUB:
    return RISCV::ADD;
  case RISCV::ADDW:
    return RISCV::SUBW;
  case RISCV::SUBW:
    return RISCV::ADDW;
    // clang-format off
  RVV_OPC_LMUL_CASE(PseudoVADD_VV, PseudoVSUB_VV);
  RVV_OPC_LMUL_MASK_CASE(PseudoVADD_VV, PseudoVSUB_VV);
  RVV_OPC_LMUL_CASE(PseudoVSUB_VV, PseudoVADD_VV);
  RVV_OPC_LMUL_MASK_CASE(PseudoVSUB_VV, PseudoVADD_VV);
    // clang-format on
  }

#undef RVV_OPC_LMUL_MASK_CASE
#undef RVV_OPC_LMUL_CASE
}
2423 
canCombineFPFusedMultiply(const MachineInstr & Root,const MachineOperand & MO,bool DoRegPressureReduce)2424 static bool canCombineFPFusedMultiply(const MachineInstr &Root,
2425                                       const MachineOperand &MO,
2426                                       bool DoRegPressureReduce) {
2427   if (!MO.isReg() || !MO.getReg().isVirtual())
2428     return false;
2429   const MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
2430   MachineInstr *MI = MRI.getVRegDef(MO.getReg());
2431   if (!MI || !isFMUL(MI->getOpcode()))
2432     return false;
2433 
2434   if (!Root.getFlag(MachineInstr::MIFlag::FmContract) ||
2435       !MI->getFlag(MachineInstr::MIFlag::FmContract))
2436     return false;
2437 
2438   // Try combining even if fmul has more than one use as it eliminates
2439   // dependency between fadd(fsub) and fmul. However, it can extend liveranges
2440   // for fmul operands, so reject the transformation in register pressure
2441   // reduction mode.
2442   if (DoRegPressureReduce && !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
2443     return false;
2444 
2445   // Do not combine instructions from different basic blocks.
2446   if (Root.getParent() != MI->getParent())
2447     return false;
2448   return RISCV::hasEqualFRM(Root, *MI);
2449 }
2450 
getFPFusedMultiplyPatterns(MachineInstr & Root,SmallVectorImpl<unsigned> & Patterns,bool DoRegPressureReduce)2451 static bool getFPFusedMultiplyPatterns(MachineInstr &Root,
2452                                        SmallVectorImpl<unsigned> &Patterns,
2453                                        bool DoRegPressureReduce) {
2454   unsigned Opc = Root.getOpcode();
2455   bool IsFAdd = isFADD(Opc);
2456   if (!IsFAdd && !isFSUB(Opc))
2457     return false;
2458   bool Added = false;
2459   if (canCombineFPFusedMultiply(Root, Root.getOperand(1),
2460                                 DoRegPressureReduce)) {
2461     Patterns.push_back(IsFAdd ? RISCVMachineCombinerPattern::FMADD_AX
2462                               : RISCVMachineCombinerPattern::FMSUB);
2463     Added = true;
2464   }
2465   if (canCombineFPFusedMultiply(Root, Root.getOperand(2),
2466                                 DoRegPressureReduce)) {
2467     Patterns.push_back(IsFAdd ? RISCVMachineCombinerPattern::FMADD_XA
2468                               : RISCVMachineCombinerPattern::FNMSUB);
2469     Added = true;
2470   }
2471   return Added;
2472 }
2473 
getFPPatterns(MachineInstr & Root,SmallVectorImpl<unsigned> & Patterns,bool DoRegPressureReduce)2474 static bool getFPPatterns(MachineInstr &Root,
2475                           SmallVectorImpl<unsigned> &Patterns,
2476                           bool DoRegPressureReduce) {
2477   return getFPFusedMultiplyPatterns(Root, Patterns, DoRegPressureReduce);
2478 }
2479 
2480 /// Utility routine that checks if \param MO is defined by an
2481 /// \param CombineOpc instruction in the basic block \param MBB
canCombine(const MachineBasicBlock & MBB,const MachineOperand & MO,unsigned CombineOpc)2482 static const MachineInstr *canCombine(const MachineBasicBlock &MBB,
2483                                       const MachineOperand &MO,
2484                                       unsigned CombineOpc) {
2485   const MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo();
2486   const MachineInstr *MI = nullptr;
2487 
2488   if (MO.isReg() && MO.getReg().isVirtual())
2489     MI = MRI.getUniqueVRegDef(MO.getReg());
2490   // And it needs to be in the trace (otherwise, it won't have a depth).
2491   if (!MI || MI->getParent() != &MBB || MI->getOpcode() != CombineOpc)
2492     return nullptr;
2493   // Must only used by the user we combine with.
2494   if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
2495     return nullptr;
2496 
2497   return MI;
2498 }
2499 
2500 /// Utility routine that checks if \param MO is defined by a SLLI in \param
2501 /// MBB that can be combined by splitting across 2 SHXADD instructions. The
2502 /// first SHXADD shift amount is given by \param OuterShiftAmt.
canCombineShiftIntoShXAdd(const MachineBasicBlock & MBB,const MachineOperand & MO,unsigned OuterShiftAmt)2503 static bool canCombineShiftIntoShXAdd(const MachineBasicBlock &MBB,
2504                                       const MachineOperand &MO,
2505                                       unsigned OuterShiftAmt) {
2506   const MachineInstr *ShiftMI = canCombine(MBB, MO, RISCV::SLLI);
2507   if (!ShiftMI)
2508     return false;
2509 
2510   unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
2511   if (InnerShiftAmt < OuterShiftAmt || (InnerShiftAmt - OuterShiftAmt) > 3)
2512     return false;
2513 
2514   return true;
2515 }
2516 
2517 // Returns the shift amount from a SHXADD instruction. Returns 0 if the
2518 // instruction is not a SHXADD.
getSHXADDShiftAmount(unsigned Opc)2519 static unsigned getSHXADDShiftAmount(unsigned Opc) {
2520   switch (Opc) {
2521   default:
2522     return 0;
2523   case RISCV::SH1ADD:
2524     return 1;
2525   case RISCV::SH2ADD:
2526     return 2;
2527   case RISCV::SH3ADD:
2528     return 3;
2529   }
2530 }
2531 
2532 // Returns the shift amount from a SHXADD.UW instruction. Returns 0 if the
2533 // instruction is not a SHXADD.UW.
getSHXADDUWShiftAmount(unsigned Opc)2534 static unsigned getSHXADDUWShiftAmount(unsigned Opc) {
2535   switch (Opc) {
2536   default:
2537     return 0;
2538   case RISCV::SH1ADD_UW:
2539     return 1;
2540   case RISCV::SH2ADD_UW:
2541     return 2;
2542   case RISCV::SH3ADD_UW:
2543     return 3;
2544   }
2545 }
2546 
2547 // Look for opportunities to combine (sh3add Z, (add X, (slli Y, 5))) into
2548 // (sh3add (sh2add Y, Z), X).
getSHXADDPatterns(const MachineInstr & Root,SmallVectorImpl<unsigned> & Patterns)2549 static bool getSHXADDPatterns(const MachineInstr &Root,
2550                               SmallVectorImpl<unsigned> &Patterns) {
2551   unsigned ShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
2552   if (!ShiftAmt)
2553     return false;
2554 
2555   const MachineBasicBlock &MBB = *Root.getParent();
2556 
2557   const MachineInstr *AddMI = canCombine(MBB, Root.getOperand(2), RISCV::ADD);
2558   if (!AddMI)
2559     return false;
2560 
2561   bool Found = false;
2562   if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(1), ShiftAmt)) {
2563     Patterns.push_back(RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP1);
2564     Found = true;
2565   }
2566   if (canCombineShiftIntoShXAdd(MBB, AddMI->getOperand(2), ShiftAmt)) {
2567     Patterns.push_back(RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP2);
2568     Found = true;
2569   }
2570 
2571   return Found;
2572 }
2573 
getCombinerObjective(unsigned Pattern) const2574 CombinerObjective RISCVInstrInfo::getCombinerObjective(unsigned Pattern) const {
2575   switch (Pattern) {
2576   case RISCVMachineCombinerPattern::FMADD_AX:
2577   case RISCVMachineCombinerPattern::FMADD_XA:
2578   case RISCVMachineCombinerPattern::FMSUB:
2579   case RISCVMachineCombinerPattern::FNMSUB:
2580     return CombinerObjective::MustReduceDepth;
2581   default:
2582     return TargetInstrInfo::getCombinerObjective(Pattern);
2583   }
2584 }
2585 
getMachineCombinerPatterns(MachineInstr & Root,SmallVectorImpl<unsigned> & Patterns,bool DoRegPressureReduce) const2586 bool RISCVInstrInfo::getMachineCombinerPatterns(
2587     MachineInstr &Root, SmallVectorImpl<unsigned> &Patterns,
2588     bool DoRegPressureReduce) const {
2589 
2590   if (getFPPatterns(Root, Patterns, DoRegPressureReduce))
2591     return true;
2592 
2593   if (getSHXADDPatterns(Root, Patterns))
2594     return true;
2595 
2596   return TargetInstrInfo::getMachineCombinerPatterns(Root, Patterns,
2597                                                      DoRegPressureReduce);
2598 }
2599 
getFPFusedMultiplyOpcode(unsigned RootOpc,unsigned Pattern)2600 static unsigned getFPFusedMultiplyOpcode(unsigned RootOpc, unsigned Pattern) {
2601   switch (RootOpc) {
2602   default:
2603     llvm_unreachable("Unexpected opcode");
2604   case RISCV::FADD_H:
2605     return RISCV::FMADD_H;
2606   case RISCV::FADD_S:
2607     return RISCV::FMADD_S;
2608   case RISCV::FADD_D:
2609     return RISCV::FMADD_D;
2610   case RISCV::FSUB_H:
2611     return Pattern == RISCVMachineCombinerPattern::FMSUB ? RISCV::FMSUB_H
2612                                                          : RISCV::FNMSUB_H;
2613   case RISCV::FSUB_S:
2614     return Pattern == RISCVMachineCombinerPattern::FMSUB ? RISCV::FMSUB_S
2615                                                          : RISCV::FNMSUB_S;
2616   case RISCV::FSUB_D:
2617     return Pattern == RISCVMachineCombinerPattern::FMSUB ? RISCV::FMSUB_D
2618                                                          : RISCV::FNMSUB_D;
2619   }
2620 }
2621 
getAddendOperandIdx(unsigned Pattern)2622 static unsigned getAddendOperandIdx(unsigned Pattern) {
2623   switch (Pattern) {
2624   default:
2625     llvm_unreachable("Unexpected pattern");
2626   case RISCVMachineCombinerPattern::FMADD_AX:
2627   case RISCVMachineCombinerPattern::FMSUB:
2628     return 2;
2629   case RISCVMachineCombinerPattern::FMADD_XA:
2630   case RISCVMachineCombinerPattern::FNMSUB:
2631     return 1;
2632   }
2633 }
2634 
// Build the fused multiply-add/subtract that replaces the (fadd/fsub, fmul)
// pair <Root, Prev> for the given combiner \p Pattern. The new instruction
// is appended to \p InsInstrs and the replaced instructions to \p DelInstrs;
// the MachineCombiner performs the actual substitution.
static void combineFPFusedMultiply(MachineInstr &Root, MachineInstr &Prev,
                                   unsigned Pattern,
                                   SmallVectorImpl<MachineInstr *> &InsInstrs,
                                   SmallVectorImpl<MachineInstr *> &DelInstrs) {
  MachineFunction *MF = Root.getMF();
  MachineRegisterInfo &MRI = MF->getRegInfo();
  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();

  // Prev is the fmul: operands 1 and 2 are the factors. Root supplies the
  // destination and the addend (whose operand index depends on the pattern).
  MachineOperand &Mul1 = Prev.getOperand(1);
  MachineOperand &Mul2 = Prev.getOperand(2);
  MachineOperand &Dst = Root.getOperand(0);
  MachineOperand &Addend = Root.getOperand(getAddendOperandIdx(Pattern));

  Register DstReg = Dst.getReg();
  unsigned FusedOpc = getFPFusedMultiplyOpcode(Root.getOpcode(), Pattern);
  // Only flags present on both instructions may be kept on the fused op.
  uint32_t IntersectedFlags = Root.getFlags() & Prev.getFlags();
  DebugLoc MergedLoc =
      DILocation::getMergedLocation(Root.getDebugLoc(), Prev.getDebugLoc());

  // Snapshot the kill states before clearing them below.
  bool Mul1IsKill = Mul1.isKill();
  bool Mul2IsKill = Mul2.isKill();
  bool AddendIsKill = Addend.isKill();

  // We need to clear kill flags since we may be extending the live range past
  // a kill. If the mul had kill flags, we can preserve those since we know
  // where the previous range stopped.
  MRI.clearKillFlags(Mul1.getReg());
  MRI.clearKillFlags(Mul2.getReg());

  MachineInstrBuilder MIB =
      BuildMI(*MF, MergedLoc, TII->get(FusedOpc), DstReg)
          .addReg(Mul1.getReg(), getKillRegState(Mul1IsKill))
          .addReg(Mul2.getReg(), getKillRegState(Mul2IsKill))
          .addReg(Addend.getReg(), getKillRegState(AddendIsKill))
          .setMIFlags(IntersectedFlags);

  InsInstrs.push_back(MIB);
  // Only delete the fmul if the fadd/fsub was its sole non-debug user;
  // otherwise it must stay to feed its other uses.
  if (MRI.hasOneNonDBGUse(Prev.getOperand(0).getReg()))
    DelInstrs.push_back(&Prev);
  DelInstrs.push_back(&Root);
}
2676 
// Combine patterns like (sh3add Z, (add X, (slli Y, 5))) to
// (sh3add (sh2add Y, Z), X) if the shift amount can be split across two
// shXadd instructions. The outer shXadd keeps its original opcode.
// \p AddOpIdx selects which ADD operand (1 or 2) is defined by the SLLI.
static void
genShXAddAddShift(MachineInstr &Root, unsigned AddOpIdx,
                  SmallVectorImpl<MachineInstr *> &InsInstrs,
                  SmallVectorImpl<MachineInstr *> &DelInstrs,
                  DenseMap<Register, unsigned> &InstrIdxForVirtReg) {
  MachineFunction *MF = Root.getMF();
  MachineRegisterInfo &MRI = MF->getRegInfo();
  const TargetInstrInfo *TII = MF->getSubtarget().getInstrInfo();

  unsigned OuterShiftAmt = getSHXADDShiftAmount(Root.getOpcode());
  assert(OuterShiftAmt != 0 && "Unexpected opcode");

  // Walk the pattern matched by getSHXADDPatterns: Root's second operand is
  // the ADD, and the ADD's AddOpIdx operand is the SLLI.
  MachineInstr *AddMI = MRI.getUniqueVRegDef(Root.getOperand(2).getReg());
  MachineInstr *ShiftMI =
      MRI.getUniqueVRegDef(AddMI->getOperand(AddOpIdx).getReg());

  unsigned InnerShiftAmt = ShiftMI->getOperand(2).getImm();
  assert(InnerShiftAmt >= OuterShiftAmt && "Unexpected shift amount");

  // The inner instruction absorbs the remaining shift (inner - outer),
  // which canCombineShiftIntoShXAdd guaranteed to be in [0, 3].
  unsigned InnerOpc;
  switch (InnerShiftAmt - OuterShiftAmt) {
  default:
    llvm_unreachable("Unexpected shift amount");
  case 0:
    InnerOpc = RISCV::ADD;
    break;
  case 1:
    InnerOpc = RISCV::SH1ADD;
    break;
  case 2:
    InnerOpc = RISCV::SH2ADD;
    break;
  case 3:
    InnerOpc = RISCV::SH3ADD;
    break;
  }

  // X is the ADD operand not fed by the SLLI (3 - AddOpIdx flips 1 <-> 2),
  // Y is the SLLI source, Z is the outer shXadd's shifted operand.
  const MachineOperand &X = AddMI->getOperand(3 - AddOpIdx);
  const MachineOperand &Y = ShiftMI->getOperand(1);
  const MachineOperand &Z = Root.getOperand(1);

  Register NewVR = MRI.createVirtualRegister(&RISCV::GPRRegClass);

  // Emit (shYadd Y, Z) into NewVR, then (shXadd NewVR, X) into Root's dest.
  auto MIB1 = BuildMI(*MF, MIMetadata(Root), TII->get(InnerOpc), NewVR)
                  .addReg(Y.getReg(), getKillRegState(Y.isKill()))
                  .addReg(Z.getReg(), getKillRegState(Z.isKill()));
  auto MIB2 = BuildMI(*MF, MIMetadata(Root), TII->get(Root.getOpcode()),
                      Root.getOperand(0).getReg())
                  .addReg(NewVR, RegState::Kill)
                  .addReg(X.getReg(), getKillRegState(X.isKill()));

  InstrIdxForVirtReg.insert(std::make_pair(NewVR, 0));
  InsInstrs.push_back(MIB1);
  InsInstrs.push_back(MIB2);
  DelInstrs.push_back(ShiftMI);
  DelInstrs.push_back(AddMI);
  DelInstrs.push_back(&Root);
}
2738 
genAlternativeCodeSequence(MachineInstr & Root,unsigned Pattern,SmallVectorImpl<MachineInstr * > & InsInstrs,SmallVectorImpl<MachineInstr * > & DelInstrs,DenseMap<Register,unsigned> & InstrIdxForVirtReg) const2739 void RISCVInstrInfo::genAlternativeCodeSequence(
2740     MachineInstr &Root, unsigned Pattern,
2741     SmallVectorImpl<MachineInstr *> &InsInstrs,
2742     SmallVectorImpl<MachineInstr *> &DelInstrs,
2743     DenseMap<Register, unsigned> &InstrIdxForVirtReg) const {
2744   MachineRegisterInfo &MRI = Root.getMF()->getRegInfo();
2745   switch (Pattern) {
2746   default:
2747     TargetInstrInfo::genAlternativeCodeSequence(Root, Pattern, InsInstrs,
2748                                                 DelInstrs, InstrIdxForVirtReg);
2749     return;
2750   case RISCVMachineCombinerPattern::FMADD_AX:
2751   case RISCVMachineCombinerPattern::FMSUB: {
2752     MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(1).getReg());
2753     combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
2754     return;
2755   }
2756   case RISCVMachineCombinerPattern::FMADD_XA:
2757   case RISCVMachineCombinerPattern::FNMSUB: {
2758     MachineInstr &Prev = *MRI.getVRegDef(Root.getOperand(2).getReg());
2759     combineFPFusedMultiply(Root, Prev, Pattern, InsInstrs, DelInstrs);
2760     return;
2761   }
2762   case RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP1:
2763     genShXAddAddShift(Root, 1, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2764     return;
2765   case RISCVMachineCombinerPattern::SHXADD_ADD_SLLI_OP2:
2766     genShXAddAddShift(Root, 2, InsInstrs, DelInstrs, InstrIdxForVirtReg);
2767     return;
2768   }
2769 }
2770 
// Machine verifier hook: check \p MI's immediate operands against the ranges
// encoded in their RISCVOp operand types, then validate the RVV operand
// structure (VL/SEW/policy/rounding-mode) described by the TSFlags. On
// failure, \p ErrInfo receives a diagnostic and false is returned.
bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
                                       StringRef &ErrInfo) const {
  MCInstrDesc const &Desc = MI.getDesc();

  // Validate every operand whose declared type is a RISC-V immediate kind.
  for (const auto &[Index, Operand] : enumerate(Desc.operands())) {
    unsigned OpType = Operand.OperandType;
    if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM &&
        OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) {
      const MachineOperand &MO = MI.getOperand(Index);
      if (MO.isReg()) {
        ErrInfo = "Expected a non-register operand.";
        return false;
      }
      // Non-reg, non-imm operands (e.g. globals before lowering) are left
      // unchecked; only concrete immediates are range-validated here.
      if (MO.isImm()) {
        int64_t Imm = MO.getImm();
        bool Ok;
        switch (OpType) {
        default:
          llvm_unreachable("Unexpected operand type");

          // clang-format off
#define CASE_OPERAND_UIMM(NUM)                                                 \
  case RISCVOp::OPERAND_UIMM##NUM:                                             \
    Ok = isUInt<NUM>(Imm);                                                     \
    break;
#define CASE_OPERAND_SIMM(NUM)                                                 \
  case RISCVOp::OPERAND_SIMM##NUM:                                             \
    Ok = isInt<NUM>(Imm);                                                      \
    break;
        CASE_OPERAND_UIMM(1)
        CASE_OPERAND_UIMM(2)
        CASE_OPERAND_UIMM(3)
        CASE_OPERAND_UIMM(4)
        CASE_OPERAND_UIMM(5)
        CASE_OPERAND_UIMM(6)
        CASE_OPERAND_UIMM(7)
        CASE_OPERAND_UIMM(8)
        CASE_OPERAND_UIMM(9)
        CASE_OPERAND_UIMM(10)
        CASE_OPERAND_UIMM(12)
        CASE_OPERAND_UIMM(16)
        CASE_OPERAND_UIMM(20)
        CASE_OPERAND_UIMM(32)
        CASE_OPERAND_UIMM(48)
        CASE_OPERAND_UIMM(64)
          // clang-format on
        // _LSBx suffixes require the low x bits to be zero (shifted ranges).
        case RISCVOp::OPERAND_UIMM2_LSB0:
          Ok = isShiftedUInt<1, 1>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM5_LSB0:
          Ok = isShiftedUInt<4, 1>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM5_NONZERO:
          Ok = isUInt<5>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_UIMM5_GT3:
          Ok = isUInt<5>(Imm) && (Imm > 3);
          break;
        case RISCVOp::OPERAND_UIMM5_PLUS1:
          Ok = (isUInt<5>(Imm) && (Imm != 0)) || (Imm == 32);
          break;
        case RISCVOp::OPERAND_UIMM6_LSB0:
          Ok = isShiftedUInt<5, 1>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM7_LSB00:
          Ok = isShiftedUInt<5, 2>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM7_LSB000:
          Ok = isShiftedUInt<4, 3>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM8_LSB00:
          Ok = isShiftedUInt<6, 2>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM8_LSB000:
          Ok = isShiftedUInt<5, 3>(Imm);
          break;
        case RISCVOp::OPERAND_UIMM8_GE32:
          Ok = isUInt<8>(Imm) && Imm >= 32;
          break;
        case RISCVOp::OPERAND_UIMM9_LSB000:
          Ok = isShiftedUInt<6, 3>(Imm);
          break;
        case RISCVOp::OPERAND_SIMM10_LSB0000_NONZERO:
          Ok = isShiftedInt<6, 4>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_UIMM10_LSB00_NONZERO:
          Ok = isShiftedUInt<8, 2>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_UIMM16_NONZERO:
          Ok = isUInt<16>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_ZERO:
          Ok = Imm == 0;
          break;
        case RISCVOp::OPERAND_THREE:
          Ok = Imm == 3;
          break;
        case RISCVOp::OPERAND_FOUR:
          Ok = Imm == 4;
          break;
          // clang-format off
        CASE_OPERAND_SIMM(5)
        CASE_OPERAND_SIMM(6)
        CASE_OPERAND_SIMM(11)
        CASE_OPERAND_SIMM(12)
        CASE_OPERAND_SIMM(26)
        // clang-format on
        case RISCVOp::OPERAND_SIMM5_PLUS1:
          Ok = (isInt<5>(Imm) && Imm != -16) || Imm == 16;
          break;
        case RISCVOp::OPERAND_SIMM5_NONZERO:
          Ok = isInt<5>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_SIMM6_NONZERO:
          Ok = Imm != 0 && isInt<6>(Imm);
          break;
        case RISCVOp::OPERAND_VTYPEI10:
          Ok = isUInt<10>(Imm);
          break;
        case RISCVOp::OPERAND_VTYPEI11:
          Ok = isUInt<11>(Imm);
          break;
        case RISCVOp::OPERAND_SIMM12_LSB00000:
          Ok = isShiftedInt<7, 5>(Imm);
          break;
        case RISCVOp::OPERAND_SIMM16_NONZERO:
          Ok = isInt<16>(Imm) && (Imm != 0);
          break;
        case RISCVOp::OPERAND_SIMM20_LI:
          Ok = isInt<20>(Imm);
          break;
        case RISCVOp::OPERAND_BARE_SIMM32:
          Ok = isInt<32>(Imm);
          break;
        // Shift amounts are bounded by XLEN, hence subtarget-dependent.
        case RISCVOp::OPERAND_UIMMLOG2XLEN:
          Ok = STI.is64Bit() ? isUInt<6>(Imm) : isUInt<5>(Imm);
          break;
        case RISCVOp::OPERAND_UIMMLOG2XLEN_NONZERO:
          Ok = STI.is64Bit() ? isUInt<6>(Imm) : isUInt<5>(Imm);
          Ok = Ok && Imm != 0;
          break;
        // C.LUI accepts [1, 31] and the sign-extended range [0xfffe0, 0xfffff].
        case RISCVOp::OPERAND_CLUI_IMM:
          Ok = (isUInt<5>(Imm) && Imm != 0) ||
               (Imm >= 0xfffe0 && Imm <= 0xfffff);
          break;
        case RISCVOp::OPERAND_RVKRNUM:
          Ok = Imm >= 0 && Imm <= 10;
          break;
        case RISCVOp::OPERAND_RVKRNUM_0_7:
          Ok = Imm >= 0 && Imm <= 7;
          break;
        case RISCVOp::OPERAND_RVKRNUM_1_10:
          Ok = Imm >= 1 && Imm <= 10;
          break;
        case RISCVOp::OPERAND_RVKRNUM_2_14:
          Ok = Imm >= 2 && Imm <= 14;
          break;
        // Zc push/pop register list encodings.
        case RISCVOp::OPERAND_RLIST:
          Ok = Imm >= RISCVZC::RA && Imm <= RISCVZC::RA_S0_S11;
          break;
        case RISCVOp::OPERAND_RLIST_S0:
          Ok = Imm >= RISCVZC::RA_S0 && Imm <= RISCVZC::RA_S0_S11;
          break;
        case RISCVOp::OPERAND_STACKADJ:
          Ok = Imm >= 0 && Imm <= 48 && Imm % 16 == 0;
          break;
        case RISCVOp::OPERAND_FRMARG:
          Ok = RISCVFPRndMode::isValidRoundingMode(Imm);
          break;
        case RISCVOp::OPERAND_RTZARG:
          Ok = Imm == RISCVFPRndMode::RTZ;
          break;
        case RISCVOp::OPERAND_COND_CODE:
          Ok = Imm >= 0 && Imm < RISCVCC::COND_INVALID;
          break;
        // Policy immediates may only set the TA/MA bits.
        case RISCVOp::OPERAND_VEC_POLICY:
          Ok = (Imm &
                (RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC)) == Imm;
          break;
        // SEW is encoded as log2; 1 << Imm must be a valid SEW value.
        case RISCVOp::OPERAND_SEW:
          Ok = (isUInt<5>(Imm) && RISCVVType::isValidSEW(1 << Imm));
          break;
        case RISCVOp::OPERAND_SEW_MASK:
          Ok = Imm == 0;
          break;
        // Rounding mode operand: VXRM is a 2-bit field, FRM uses the FP
        // rounding-mode encodings.
        case RISCVOp::OPERAND_VEC_RM:
          assert(RISCVII::hasRoundModeOp(Desc.TSFlags));
          if (RISCVII::usesVXRM(Desc.TSFlags))
            Ok = isUInt<2>(Imm);
          else
            Ok = RISCVFPRndMode::isValidRoundingMode(Imm);
          break;
        }
        if (!Ok) {
          ErrInfo = "Invalid immediate";
          return false;
        }
      }
    }
  }

  // Structural checks driven by the RVV TSFlags: a VL operand implies a SEW
  // operand, a policy operand implies a VL operand and a tied passthru.
  const uint64_t TSFlags = Desc.TSFlags;
  if (RISCVII::hasVLOp(TSFlags)) {
    const MachineOperand &Op = MI.getOperand(RISCVII::getVLOpNum(Desc));
    if (!Op.isImm() && !Op.isReg())  {
      ErrInfo = "Invalid operand type for VL operand";
      return false;
    }
    if (Op.isReg() && Op.getReg() != RISCV::NoRegister) {
      const MachineRegisterInfo &MRI = MI.getParent()->getParent()->getRegInfo();
      auto *RC = MRI.getRegClass(Op.getReg());
      if (!RISCV::GPRRegClass.hasSubClassEq(RC)) {
        ErrInfo = "Invalid register class for VL operand";
        return false;
      }
    }
    if (!RISCVII::hasSEWOp(TSFlags)) {
      ErrInfo = "VL operand w/o SEW operand?";
      return false;
    }
  }
  if (RISCVII::hasSEWOp(TSFlags)) {
    unsigned OpIdx = RISCVII::getSEWOpNum(Desc);
    if (!MI.getOperand(OpIdx).isImm()) {
      ErrInfo = "SEW value expected to be an immediate";
      return false;
    }
    uint64_t Log2SEW = MI.getOperand(OpIdx).getImm();
    if (Log2SEW > 31) {
      ErrInfo = "Unexpected SEW value";
      return false;
    }
    // Log2SEW == 0 encodes SEW 8 (mask operations).
    unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
    if (!RISCVVType::isValidSEW(SEW)) {
      ErrInfo = "Unexpected SEW value";
      return false;
    }
  }
  if (RISCVII::hasVecPolicyOp(TSFlags)) {
    unsigned OpIdx = RISCVII::getVecPolicyOpNum(Desc);
    if (!MI.getOperand(OpIdx).isImm()) {
      ErrInfo = "Policy operand expected to be an immediate";
      return false;
    }
    uint64_t Policy = MI.getOperand(OpIdx).getImm();
    if (Policy > (RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC)) {
      ErrInfo = "Invalid Policy Value";
      return false;
    }
    if (!RISCVII::hasVLOp(TSFlags)) {
      ErrInfo = "policy operand w/o VL operand?";
      return false;
    }

    // VecPolicy operands can only exist on instructions with passthru/merge
    // arguments. Note that not all arguments with passthru have vec policy
    // operands- some instructions have implicit policies.
    unsigned UseOpIdx;
    if (!MI.isRegTiedToUseOperand(0, &UseOpIdx)) {
      ErrInfo = "policy operand w/o tied operand?";
      return false;
    }
  }

  // An instruction with a dynamic rounding mode must model its implicit
  // dependence on the FRM control register.
  if (int Idx = RISCVII::getFRMOpNum(Desc);
      Idx >= 0 && MI.getOperand(Idx).getImm() == RISCVFPRndMode::DYN &&
      !MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
    ErrInfo = "dynamic rounding mode should read FRM";
    return false;
  }

  return true;
}
3044 
canFoldIntoAddrMode(const MachineInstr & MemI,Register Reg,const MachineInstr & AddrI,ExtAddrMode & AM) const3045 bool RISCVInstrInfo::canFoldIntoAddrMode(const MachineInstr &MemI, Register Reg,
3046                                          const MachineInstr &AddrI,
3047                                          ExtAddrMode &AM) const {
3048   switch (MemI.getOpcode()) {
3049   default:
3050     return false;
3051   case RISCV::LB:
3052   case RISCV::LBU:
3053   case RISCV::LH:
3054   case RISCV::LH_INX:
3055   case RISCV::LHU:
3056   case RISCV::LW:
3057   case RISCV::LW_INX:
3058   case RISCV::LWU:
3059   case RISCV::LD:
3060   case RISCV::LD_RV32:
3061   case RISCV::FLH:
3062   case RISCV::FLW:
3063   case RISCV::FLD:
3064   case RISCV::SB:
3065   case RISCV::SH:
3066   case RISCV::SH_INX:
3067   case RISCV::SW:
3068   case RISCV::SW_INX:
3069   case RISCV::SD:
3070   case RISCV::SD_RV32:
3071   case RISCV::FSH:
3072   case RISCV::FSW:
3073   case RISCV::FSD:
3074     break;
3075   }
3076 
3077   if (MemI.getOperand(0).getReg() == Reg)
3078     return false;
3079 
3080   if (AddrI.getOpcode() != RISCV::ADDI || !AddrI.getOperand(1).isReg() ||
3081       !AddrI.getOperand(2).isImm())
3082     return false;
3083 
3084   int64_t OldOffset = MemI.getOperand(2).getImm();
3085   int64_t Disp = AddrI.getOperand(2).getImm();
3086   int64_t NewOffset = OldOffset + Disp;
3087   if (!STI.is64Bit())
3088     NewOffset = SignExtend64<32>(NewOffset);
3089 
3090   if (!isInt<12>(NewOffset))
3091     return false;
3092 
3093   AM.BaseReg = AddrI.getOperand(1).getReg();
3094   AM.ScaledReg = 0;
3095   AM.Scale = 0;
3096   AM.Displacement = NewOffset;
3097   AM.Form = ExtAddrMode::Formula::Basic;
3098   return true;
3099 }
3100 
emitLdStWithAddr(MachineInstr & MemI,const ExtAddrMode & AM) const3101 MachineInstr *RISCVInstrInfo::emitLdStWithAddr(MachineInstr &MemI,
3102                                                const ExtAddrMode &AM) const {
3103 
3104   const DebugLoc &DL = MemI.getDebugLoc();
3105   MachineBasicBlock &MBB = *MemI.getParent();
3106 
3107   assert(AM.ScaledReg == 0 && AM.Scale == 0 &&
3108          "Addressing mode not supported for folding");
3109 
3110   return BuildMI(MBB, MemI, DL, get(MemI.getOpcode()))
3111       .addReg(MemI.getOperand(0).getReg(),
3112               MemI.mayLoad() ? RegState::Define : 0)
3113       .addReg(AM.BaseReg)
3114       .addImm(AM.Displacement)
3115       .setMemRefs(MemI.memoperands())
3116       .setMIFlags(MemI.getFlags());
3117 }
3118 
3119 // TODO: At the moment, MIPS introduced paring of instructions operating with
3120 // word or double word. This should be extended with more instructions when more
3121 // vendors support load/store pairing.
isPairableLdStInstOpc(unsigned Opc)3122 bool RISCVInstrInfo::isPairableLdStInstOpc(unsigned Opc) {
3123   switch (Opc) {
3124   default:
3125     return false;
3126   case RISCV::SW:
3127   case RISCV::SD:
3128   case RISCV::LD:
3129   case RISCV::LW:
3130     return true;
3131   }
3132 }
3133 
isLdStSafeToPair(const MachineInstr & LdSt,const TargetRegisterInfo * TRI)3134 bool RISCVInstrInfo::isLdStSafeToPair(const MachineInstr &LdSt,
3135                                       const TargetRegisterInfo *TRI) {
3136   // If this is a volatile load/store, don't mess with it.
3137   if (LdSt.hasOrderedMemoryRef() || LdSt.getNumExplicitOperands() != 3)
3138     return false;
3139 
3140   if (LdSt.getOperand(1).isFI())
3141     return true;
3142 
3143   assert(LdSt.getOperand(1).isReg() && "Expected a reg operand.");
3144   // Can't cluster if the instruction modifies the base register
3145   // or it is update form. e.g. ld x5,8(x5)
3146   if (LdSt.modifiesRegister(LdSt.getOperand(1).getReg(), TRI))
3147     return false;
3148 
3149   if (!LdSt.getOperand(2).isImm())
3150     return false;
3151 
3152   return true;
3153 }
3154 
// Collect the base operand(s), constant offset and access width of a scalar
// load/store. Returns false for anything that is not one of the recognized
// base+offset scalar memory opcodes.
getMemOperandsWithOffsetWidth(const MachineInstr & LdSt,SmallVectorImpl<const MachineOperand * > & BaseOps,int64_t & Offset,bool & OffsetIsScalable,LocationSize & Width,const TargetRegisterInfo * TRI) const3155 bool RISCVInstrInfo::getMemOperandsWithOffsetWidth(
3156     const MachineInstr &LdSt, SmallVectorImpl<const MachineOperand *> &BaseOps,
3157     int64_t &Offset, bool &OffsetIsScalable, LocationSize &Width,
3158     const TargetRegisterInfo *TRI) const {
3159   if (!LdSt.mayLoadOrStore())
3160     return false;
3161 
3162   // Conservatively, only handle scalar loads/stores for now.
  // The list covers integer, FP (FL*/FS*) and GPR-pair/INX variants.
3163   switch (LdSt.getOpcode()) {
3164   case RISCV::LB:
3165   case RISCV::LBU:
3166   case RISCV::SB:
3167   case RISCV::LH:
3168   case RISCV::LH_INX:
3169   case RISCV::LHU:
3170   case RISCV::FLH:
3171   case RISCV::SH:
3172   case RISCV::SH_INX:
3173   case RISCV::FSH:
3174   case RISCV::LW:
3175   case RISCV::LW_INX:
3176   case RISCV::LWU:
3177   case RISCV::FLW:
3178   case RISCV::SW:
3179   case RISCV::SW_INX:
3180   case RISCV::FSW:
3181   case RISCV::LD:
3182   case RISCV::LD_RV32:
3183   case RISCV::FLD:
3184   case RISCV::SD:
3185   case RISCV::SD_RV32:
3186   case RISCV::FSD:
3187     break;
3188   default:
3189     return false;
3190   }
3191   const MachineOperand *BaseOp;
  // Scalar accesses never have a scalable (vlenb-relative) offset.
3192   OffsetIsScalable = false;
3193   if (!getMemOperandWithOffsetWidth(LdSt, BaseOp, Offset, Width, TRI))
3194     return false;
3195   BaseOps.push_back(BaseOp);
3196   return true;
3197 }
3198 
3199 // TODO: This was copied from SIInstrInfo. Could it be lifted to a common
3200 // helper?
// Returns true when the two memory instructions provably share a base
// pointer, either because their first base operands are identical or because
// their single memory operands resolve to the same underlying IR object.
memOpsHaveSameBasePtr(const MachineInstr & MI1,ArrayRef<const MachineOperand * > BaseOps1,const MachineInstr & MI2,ArrayRef<const MachineOperand * > BaseOps2)3201 static bool memOpsHaveSameBasePtr(const MachineInstr &MI1,
3202                                   ArrayRef<const MachineOperand *> BaseOps1,
3203                                   const MachineInstr &MI2,
3204                                   ArrayRef<const MachineOperand *> BaseOps2) {
3205   // Only examine the first "base" operand of each instruction, on the
3206   // assumption that it represents the real base address of the memory access.
3207   // Other operands are typically offsets or indices from this base address.
3208   if (BaseOps1.front()->isIdenticalTo(*BaseOps2.front()))
3209     return true;
3210 
  // Fall back to comparing the IR values behind the (single) memory operands.
3211   if (!MI1.hasOneMemOperand() || !MI2.hasOneMemOperand())
3212     return false;
3213 
3214   auto MO1 = *MI1.memoperands_begin();
3215   auto MO2 = *MI2.memoperands_begin();
3216   if (MO1->getAddrSpace() != MO2->getAddrSpace())
3217     return false;
3218 
3219   auto Base1 = MO1->getValue();
3220   auto Base2 = MO2->getValue();
3221   if (!Base1 || !Base2)
3222     return false;
3223   Base1 = getUnderlyingObject(Base1);
3224   Base2 = getUnderlyingObject(Base2);
3225 
  // Undef bases cannot be meaningfully compared for equality.
3226   if (isa<UndefValue>(Base1) || isa<UndefValue>(Base2))
3227     return false;
3228 
3229   return Base1 == Base2;
3230 }
3231 
shouldClusterMemOps(ArrayRef<const MachineOperand * > BaseOps1,int64_t Offset1,bool OffsetIsScalable1,ArrayRef<const MachineOperand * > BaseOps2,int64_t Offset2,bool OffsetIsScalable2,unsigned ClusterSize,unsigned NumBytes) const3232 bool RISCVInstrInfo::shouldClusterMemOps(
3233     ArrayRef<const MachineOperand *> BaseOps1, int64_t Offset1,
3234     bool OffsetIsScalable1, ArrayRef<const MachineOperand *> BaseOps2,
3235     int64_t Offset2, bool OffsetIsScalable2, unsigned ClusterSize,
3236     unsigned NumBytes) const {
3237   // If the mem ops (to be clustered) do not have the same base ptr, then they
3238   // should not be clustered
3239   if (!BaseOps1.empty() && !BaseOps2.empty()) {
3240     const MachineInstr &FirstLdSt = *BaseOps1.front()->getParent();
3241     const MachineInstr &SecondLdSt = *BaseOps2.front()->getParent();
3242     if (!memOpsHaveSameBasePtr(FirstLdSt, BaseOps1, SecondLdSt, BaseOps2))
3243       return false;
3244   } else if (!BaseOps1.empty() || !BaseOps2.empty()) {
3245     // If only one base op is empty, they do not have the same base ptr
3246     return false;
3247   }
3248 
3249   unsigned CacheLineSize =
3250       BaseOps1.front()->getParent()->getMF()->getSubtarget().getCacheLineSize();
3251   // Assume a cache line size of 64 bytes if no size is set in RISCVSubtarget.
3252   CacheLineSize = CacheLineSize ? CacheLineSize : 64;
3253   // Cluster if the memory operations are on the same or a neighbouring cache
3254   // line, but limit the maximum ClusterSize to avoid creating too much
3255   // additional register pressure.
3256   return ClusterSize <= 4 && std::abs(Offset1 - Offset2) < CacheLineSize;
3257 }
3258 
3259 // Set BaseReg (the base register operand), Offset (the byte offset being
3260 // accessed) and the access Width of the passed instruction that reads/writes
3261 // memory. Returns false if the instruction does not read/write memory or the
3262 // BaseReg/Offset/Width can't be determined. Is not guaranteed to always
3263 // recognise base operands and offsets in all cases.
3264 // TODO: Add an IsScalable bool ref argument (like the equivalent AArch64
3265 // function) and set it as appropriate.
getMemOperandWithOffsetWidth(const MachineInstr & LdSt,const MachineOperand * & BaseReg,int64_t & Offset,LocationSize & Width,const TargetRegisterInfo * TRI) const3266 bool RISCVInstrInfo::getMemOperandWithOffsetWidth(
3267     const MachineInstr &LdSt, const MachineOperand *&BaseReg, int64_t &Offset,
3268     LocationSize &Width, const TargetRegisterInfo *TRI) const {
3269   if (!LdSt.mayLoadOrStore())
3270     return false;
3271 
3272   // Here we assume the standard RISC-V ISA, which uses a base+offset
3273   // addressing mode. You'll need to relax these conditions to support custom
3274   // load/store instructions.
  // Expected shape: value, base (register or frame index), immediate offset.
3275   if (LdSt.getNumExplicitOperands() != 3)
3276     return false;
3277   if ((!LdSt.getOperand(1).isReg() && !LdSt.getOperand(1).isFI()) ||
3278       !LdSt.getOperand(2).isImm())
3279     return false;
3280 
  // A single memory operand is required so the access size is unambiguous.
3281   if (!LdSt.hasOneMemOperand())
3282     return false;
3283 
3284   Width = (*LdSt.memoperands_begin())->getSize();
3285   BaseReg = &LdSt.getOperand(1);
3286   Offset = LdSt.getOperand(2).getImm();
3287   return true;
3288 }
3289 
// Conservative disjointness check used by the scheduler/alias analysis:
// returns true only when both accesses use an identical base operand and
// their [offset, offset+width) ranges provably do not overlap.
areMemAccessesTriviallyDisjoint(const MachineInstr & MIa,const MachineInstr & MIb) const3290 bool RISCVInstrInfo::areMemAccessesTriviallyDisjoint(
3291     const MachineInstr &MIa, const MachineInstr &MIb) const {
3292   assert(MIa.mayLoadOrStore() && "MIa must be a load or store.");
3293   assert(MIb.mayLoadOrStore() && "MIb must be a load or store.");
3294 
  // Side-effecting or ordered (volatile/atomic) accesses are never disjoint.
3295   if (MIa.hasUnmodeledSideEffects() || MIb.hasUnmodeledSideEffects() ||
3296       MIa.hasOrderedMemoryRef() || MIb.hasOrderedMemoryRef())
3297     return false;
3298 
3299   // Retrieve the base register, offset from the base register and width. Width
3300   // is the size of memory that is being loaded/stored (e.g. 1, 2, 4).  If
3301   // base registers are identical, and the offset of a lower memory access +
3302   // the width doesn't overlap the offset of a higher memory access,
3303   // then the memory accesses are different.
3304   const TargetRegisterInfo *TRI = STI.getRegisterInfo();
3305   const MachineOperand *BaseOpA = nullptr, *BaseOpB = nullptr;
3306   int64_t OffsetA = 0, OffsetB = 0;
3307   LocationSize WidthA = LocationSize::precise(0),
3308                WidthB = LocationSize::precise(0);
3309   if (getMemOperandWithOffsetWidth(MIa, BaseOpA, OffsetA, WidthA, TRI) &&
3310       getMemOperandWithOffsetWidth(MIb, BaseOpB, OffsetB, WidthB, TRI)) {
3311     if (BaseOpA->isIdenticalTo(*BaseOpB)) {
3312       int LowOffset = std::min(OffsetA, OffsetB);
3313       int HighOffset = std::max(OffsetA, OffsetB);
3314       LocationSize LowWidth = (LowOffset == OffsetA) ? WidthA : WidthB;
  // Only a known (hasValue) width lets us prove the ranges don't touch.
3315       if (LowWidth.hasValue() &&
3316           LowOffset + (int)LowWidth.getValue() <= HighOffset)
3317         return true;
3318     }
3319   }
3320   return false;
3321 }
3322 
3323 std::pair<unsigned, unsigned>
decomposeMachineOperandsTargetFlags(unsigned TF) const3324 RISCVInstrInfo::decomposeMachineOperandsTargetFlags(unsigned TF) const {
3325   const unsigned Mask = RISCVII::MO_DIRECT_FLAG_MASK;
3326   return std::make_pair(TF & Mask, TF & ~Mask);
3327 }
3328 
3329 ArrayRef<std::pair<unsigned, const char *>>
// Mapping from RISCVII direct operand target flags to the names used when
// (de)serializing MIR. The strings must match what the MIR parser expects.
getSerializableDirectMachineOperandTargetFlags() const3330 RISCVInstrInfo::getSerializableDirectMachineOperandTargetFlags() const {
3331   using namespace RISCVII;
3332   static const std::pair<unsigned, const char *> TargetFlags[] = {
3333       {MO_CALL, "riscv-call"},
3334       {MO_LO, "riscv-lo"},
3335       {MO_HI, "riscv-hi"},
3336       {MO_PCREL_LO, "riscv-pcrel-lo"},
3337       {MO_PCREL_HI, "riscv-pcrel-hi"},
3338       {MO_GOT_HI, "riscv-got-hi"},
3339       {MO_TPREL_LO, "riscv-tprel-lo"},
3340       {MO_TPREL_HI, "riscv-tprel-hi"},
3341       {MO_TPREL_ADD, "riscv-tprel-add"},
3342       {MO_TLS_GOT_HI, "riscv-tls-got-hi"},
3343       {MO_TLS_GD_HI, "riscv-tls-gd-hi"},
3344       {MO_TLSDESC_HI, "riscv-tlsdesc-hi"},
3345       {MO_TLSDESC_LOAD_LO, "riscv-tlsdesc-load-lo"},
3346       {MO_TLSDESC_ADD_LO, "riscv-tlsdesc-add-lo"},
3347       {MO_TLSDESC_CALL, "riscv-tlsdesc-call"}};
3348   return ArrayRef(TargetFlags);
3349 }
isFunctionSafeToOutlineFrom(MachineFunction & MF,bool OutlineFromLinkOnceODRs) const3350 bool RISCVInstrInfo::isFunctionSafeToOutlineFrom(
3351     MachineFunction &MF, bool OutlineFromLinkOnceODRs) const {
3352   const Function &F = MF.getFunction();
3353 
3354   // Can F be deduplicated by the linker? If it can, don't outline from it.
3355   if (!OutlineFromLinkOnceODRs && F.hasLinkOnceODRLinkage())
3356     return false;
3357 
3358   // Don't outline from functions with section markings; the program could
3359   // expect that all the code is in the named section.
3360   if (F.hasSection())
3361     return false;
3362 
3363   // It's safe to outline from MF.
3364   return true;
3365 }
3366 
// Defer to the generic TargetInstrInfo implementation; the RISC-V-specific
// viability checks run later in getOutliningCandidateInfo.
isMBBSafeToOutlineFrom(MachineBasicBlock & MBB,unsigned & Flags) const3367 bool RISCVInstrInfo::isMBBSafeToOutlineFrom(MachineBasicBlock &MBB,
3368                                             unsigned &Flags) const {
3369   // More accurate safety checking is done in getOutliningCandidateInfo.
3370   return TargetInstrInfo::isMBBSafeToOutlineFrom(MBB, Flags);
3371 }
3372 
3373 // Enum values indicating how an outlined call should be constructed.
3374 enum MachineOutlinerConstructionID {
  // The candidate ends in a return, so the call is emitted as a tail call
  // (PseudoTAIL) and the outlined function keeps the return instruction.
3375   MachineOutlinerTailCall,
  // Regular outlined call: "call t0, <fn>" at the call site, with a
  // "jr t0" return appended to the outlined frame (see buildOutlinedFrame).
3376   MachineOutlinerDefault
3377 };
3378 
// Outline by default only from functions marked minsize, where code-size
// savings outweigh the call overhead.
shouldOutlineFromFunctionByDefault(MachineFunction & MF) const3379 bool RISCVInstrInfo::shouldOutlineFromFunctionByDefault(
3380     MachineFunction &MF) const {
3381   return MF.getFunction().hasMinSize();
3382 }
3383 
isCandidatePatchable(const MachineBasicBlock & MBB)3384 static bool isCandidatePatchable(const MachineBasicBlock &MBB) {
3385   const MachineFunction *MF = MBB.getParent();
3386   const Function &F = MF->getFunction();
3387   return F.getFnAttribute("fentry-call").getValueAsBool() ||
3388          F.hasFnAttribute("patchable-function-entry");
3389 }
3390 
isMIReadsReg(const MachineInstr & MI,const TargetRegisterInfo * TRI,MCRegister RegNo)3391 static bool isMIReadsReg(const MachineInstr &MI, const TargetRegisterInfo *TRI,
3392                          MCRegister RegNo) {
3393   return MI.readsRegister(RegNo, TRI) ||
3394          MI.getDesc().hasImplicitUseOfPhysReg(RegNo);
3395 }
3396 
isMIModifiesReg(const MachineInstr & MI,const TargetRegisterInfo * TRI,MCRegister RegNo)3397 static bool isMIModifiesReg(const MachineInstr &MI,
3398                             const TargetRegisterInfo *TRI, MCRegister RegNo) {
3399   return MI.modifiesRegister(RegNo, TRI) ||
3400          MI.getDesc().hasImplicitDefOfPhysReg(RegNo);
3401 }
3402 
// Returns true when this block must NOT be outlined via a tail call: the
// block does not end in a return, the function is patchable, or the block
// reads the register reserved for expanding PseudoTAIL.
cannotInsertTailCall(const MachineBasicBlock & MBB)3403 static bool cannotInsertTailCall(const MachineBasicBlock &MBB) {
3404   if (!MBB.back().isReturn())
3405     return true;
3406   if (isCandidatePatchable(MBB))
3407     return true;
3408 
3409   // If the candidate reads the pre-set register
3410   // that can be used for expanding PseudoTAIL instruction,
3411   // then we cannot insert tail call.
3412   const TargetSubtargetInfo &STI = MBB.getParent()->getSubtarget();
3413   MCRegister TailExpandUseRegNo =
3414       RISCVII::getTailExpandUseRegNo(STI.getFeatureBits());
3415   for (const MachineInstr &MI : MBB) {
3416     if (isMIReadsReg(MI, STI.getRegisterInfo(), TailExpandUseRegNo))
3417       return true;
  // Once the register is redefined, later reads see the new value and the
  // scan can stop.
3418     if (isMIModifiesReg(MI, STI.getRegisterInfo(), TailExpandUseRegNo))
3419       break;
3420   }
3421   return false;
3422 }
3423 
// Predicate for erase_if in getOutliningCandidateInfo: returns true when the
// candidate must be dropped — i.e. it cannot be tail-called and X5 (t0),
// needed to construct the outlined call, is written inside or live across
// the sequence.
analyzeCandidate(outliner::Candidate & C)3424 static bool analyzeCandidate(outliner::Candidate &C) {
3425   // If last instruction is return then we can rely on
3426   // the verification already performed in the getOutliningTypeImpl.
3427   if (C.back().isReturn()) {
3428     assert(!cannotInsertTailCall(*C.getMBB()) &&
3429            "The candidate who uses return instruction must be outlined "
3430            "using tail call");
3431     return false;
3432   }
3433 
3434   // Filter out candidates where the X5 register (t0) can't be used to setup
3435   // the function call.
3436   const TargetRegisterInfo *TRI = C.getMF()->getSubtarget().getRegisterInfo();
3437   if (llvm::any_of(C, [TRI](const MachineInstr &MI) {
3438         return isMIModifiesReg(MI, TRI, RISCV::X5);
3439       }))
3440     return true;
3441 
  // X5 must also be free across the whole sequence at each use site.
3442   return !C.isAvailableAcrossAndOutOfSeq(RISCV::X5, *TRI);
3443 }
3444 
3445 std::optional<std::unique_ptr<outliner::OutlinedFunction>>
// Filter the repeated sequences, pick the call construction (tail call vs.
// "call t0"), and compute the size/overhead model for the outlined function.
// Returns std::nullopt when too few viable candidates remain.
getOutliningCandidateInfo(const MachineModuleInfo & MMI,std::vector<outliner::Candidate> & RepeatedSequenceLocs,unsigned MinRepeats) const3446 RISCVInstrInfo::getOutliningCandidateInfo(
3447     const MachineModuleInfo &MMI,
3448     std::vector<outliner::Candidate> &RepeatedSequenceLocs,
3449     unsigned MinRepeats) const {
3450 
3451   // Analyze each candidate and erase the ones that are not viable.
3452   llvm::erase_if(RepeatedSequenceLocs, analyzeCandidate)
3453 
3454   // If the sequence doesn't have enough candidates left, then we're done.
3455   if (RepeatedSequenceLocs.size() < MinRepeats)
3456     return std::nullopt;
3457 
3458   // Each RepeatedSequenceLoc is identical.
3459   outliner::Candidate &Candidate = RepeatedSequenceLocs[0];
  // With Zca available, a compressed jump costs 2 bytes instead of 4.
3460   unsigned InstrSizeCExt =
3461       Candidate.getMF()->getSubtarget<RISCVSubtarget>().hasStdExtZca() ? 2 : 4;
3462   unsigned CallOverhead = 0, FrameOverhead = 0;
3463 
3464   MachineOutlinerConstructionID MOCI = MachineOutlinerDefault;
3465   if (Candidate.back().isReturn()) {
3466     MOCI = MachineOutlinerTailCall;
3467     // tail call = auipc + jalr in the worst case without linker relaxation.
3468     // FIXME: This code suggests the JALR can be compressed - how?
3469     CallOverhead = 4 + InstrSizeCExt;
3470     // Using tail call we move ret instruction from caller to callee.
3471     FrameOverhead = 0;
3472   } else {
3473     // call t0, function = 8 bytes.
3474     CallOverhead = 8;
3475     // jr t0 = 4 bytes, 2 bytes if compressed instructions are enabled.
3476     FrameOverhead = InstrSizeCExt;
3477   }
3478 
3479   for (auto &C : RepeatedSequenceLocs)
3480     C.setCallInfo(MOCI, CallOverhead);
3481 
  // Total size of one occurrence of the outlined sequence.
3482   unsigned SequenceSize = 0;
3483   for (auto &MI : Candidate)
3484     SequenceSize += getInstSizeInBytes(MI);
3485 
3486   return std::make_unique<outliner::OutlinedFunction>(
3487       RepeatedSequenceLocs, SequenceSize, FrameOverhead, MOCI);
3488 }
3489 
3490 outliner::InstrType
// Classify one instruction for the outliner: Legal (may be outlined),
// Invisible (ignored and stripped later), or Illegal (blocks outlining).
getOutliningTypeImpl(const MachineModuleInfo & MMI,MachineBasicBlock::iterator & MBBI,unsigned Flags) const3491 RISCVInstrInfo::getOutliningTypeImpl(const MachineModuleInfo &MMI,
3492                                      MachineBasicBlock::iterator &MBBI,
3493                                      unsigned Flags) const {
3494   MachineInstr &MI = *MBBI;
3495   MachineBasicBlock *MBB = MI.getParent();
3496   const TargetRegisterInfo *TRI =
3497       MBB->getParent()->getSubtarget().getRegisterInfo();
3498   const auto &F = MI.getMF()->getFunction();
3499 
3500   // We can manually strip out CFI instructions later.
3501   if (MI.isCFIInstruction())
3502     // If current function has exception handling code, we can't outline &
3503     // strip these CFI instructions since it may break .eh_frame section
3504     // needed in unwinding.
3505     return F.needsUnwindTableEntry() ? outliner::InstrType::Illegal
3506                                      : outliner::InstrType::Invisible;
3507 
  // When this block can't be outlined as a tail call, returns and writes to
  // X5 (the register used to construct the default outlined call) are
  // disqualifying.
3508   if (cannotInsertTailCall(*MBB) &&
3509       (MI.isReturn() || isMIModifiesReg(MI, TRI, RISCV::X5)))
3510     return outliner::InstrType::Illegal;
3511 
3512   // Make sure the operands don't reference something unsafe.
3513   for (const auto &MO : MI.operands()) {
3514 
3515     // pcrel-hi and pcrel-lo can't put in separate sections, filter that out
3516     // if any possible.
3517     if (MO.getTargetFlags() == RISCVII::MO_PCREL_LO &&
3518         (MI.getMF()->getTarget().getFunctionSections() || F.hasComdat() ||
3519          F.hasSection() || F.getSectionPrefix()))
3520       return outliner::InstrType::Illegal;
3521   }
3522 
3523   return outliner::InstrType::Legal;
3524 }
3525 
buildOutlinedFrame(MachineBasicBlock & MBB,MachineFunction & MF,const outliner::OutlinedFunction & OF) const3526 void RISCVInstrInfo::buildOutlinedFrame(
3527     MachineBasicBlock &MBB, MachineFunction &MF,
3528     const outliner::OutlinedFunction &OF) const {
3529 
3530   // Strip out any CFI instructions
3531   bool Changed = true;
3532   while (Changed) {
3533     Changed = false;
3534     auto I = MBB.begin();
3535     auto E = MBB.end();
3536     for (; I != E; ++I) {
3537       if (I->isCFIInstruction()) {
3538         I->removeFromParent();
3539         Changed = true;
3540         break;
3541       }
3542     }
3543   }
3544 
3545   if (OF.FrameConstructionID == MachineOutlinerTailCall)
3546     return;
3547 
3548   MBB.addLiveIn(RISCV::X5);
3549 
3550   // Add in a return instruction to the end of the outlined frame.
3551   MBB.insert(MBB.end(), BuildMI(MF, DebugLoc(), get(RISCV::JALR))
3552       .addReg(RISCV::X0, RegState::Define)
3553       .addReg(RISCV::X5)
3554       .addImm(0));
3555 }
3556 
// Insert the call to the outlined function at It: a PseudoTAIL for tail-call
// candidates, otherwise a PseudoCALLReg that clobbers X5 (t0) with the return
// address. Returns an iterator to the inserted instruction.
insertOutlinedCall(Module & M,MachineBasicBlock & MBB,MachineBasicBlock::iterator & It,MachineFunction & MF,outliner::Candidate & C) const3557 MachineBasicBlock::iterator RISCVInstrInfo::insertOutlinedCall(
3558     Module &M, MachineBasicBlock &MBB, MachineBasicBlock::iterator &It,
3559     MachineFunction &MF, outliner::Candidate &C) const {
3560 
  // Tail-call candidates simply jump to the outlined function.
3561   if (C.CallConstructionID == MachineOutlinerTailCall) {
3562     It = MBB.insert(It, BuildMI(MF, DebugLoc(), get(RISCV::PseudoTAIL))
3563                             .addGlobalAddress(M.getNamedValue(MF.getName()),
3564                                               /*Offset=*/0, RISCVII::MO_CALL));
3565     return It;
3566   }
3567 
3568   // Add in a call instruction to the outlined function at the given location.
3569   It = MBB.insert(It,
3570                   BuildMI(MF, DebugLoc(), get(RISCV::PseudoCALLReg), RISCV::X5)
3571                       .addGlobalAddress(M.getNamedValue(MF.getName()), 0,
3572                                         RISCVII::MO_CALL));
3573   return It;
3574 }
3575 
isAddImmediate(const MachineInstr & MI,Register Reg) const3576 std::optional<RegImmPair> RISCVInstrInfo::isAddImmediate(const MachineInstr &MI,
3577                                                          Register Reg) const {
3578   // TODO: Handle cases where Reg is a super- or sub-register of the
3579   // destination register.
3580   const MachineOperand &Op0 = MI.getOperand(0);
3581   if (!Op0.isReg() || Reg != Op0.getReg())
3582     return std::nullopt;
3583 
3584   // Don't consider ADDIW as a candidate because the caller may not be aware
3585   // of its sign extension behaviour.
3586   if (MI.getOpcode() == RISCV::ADDI && MI.getOperand(1).isReg() &&
3587       MI.getOperand(2).isImm())
3588     return RegImmPair{MI.getOperand(1).getReg(), MI.getOperand(2).getImm()};
3589 
3590   return std::nullopt;
3591 }
3592 
3593 // MIR printer helper function to annotate Operands with a comment.
// Produces a human-readable comment for immediate operands of RVV
// instructions (vtype, SEW and policy encodings); other operands fall back to
// the generic TargetInstrInfo comment or an empty string.
createMIROperandComment(const MachineInstr & MI,const MachineOperand & Op,unsigned OpIdx,const TargetRegisterInfo * TRI) const3594 std::string RISCVInstrInfo::createMIROperandComment(
3595     const MachineInstr &MI, const MachineOperand &Op, unsigned OpIdx,
3596     const TargetRegisterInfo *TRI) const {
3597   // Print a generic comment for this operand if there is one.
3598   std::string GenericComment =
3599       TargetInstrInfo::createMIROperandComment(MI, Op, OpIdx, TRI);
3600   if (!GenericComment.empty())
3601     return GenericComment;
3602 
3603   // If not, we must have an immediate operand.
3604   if (!Op.isImm())
3605     return std::string();
3606 
3607   const MCInstrDesc &Desc = MI.getDesc();
3608   if (OpIdx >= Desc.getNumOperands())
3609     return std::string();
3610 
3611   std::string Comment;
3612   raw_string_ostream OS(Comment);
3613 
3614   const MCOperandInfo &OpInfo = Desc.operands()[OpIdx];
3615 
3616   // Print the full VType operand of vsetvli/vsetivli instructions, and the SEW
3617   // operand of vector codegen pseudos.
3618   switch (OpInfo.OperandType) {
3619   case RISCVOp::OPERAND_VTYPEI10:
3620   case RISCVOp::OPERAND_VTYPEI11: {
3621     unsigned Imm = Op.getImm();
3622     RISCVVType::printVType(Imm, OS);
3623     break;
3624   }
3625   case RISCVOp::OPERAND_SEW:
3626   case RISCVOp::OPERAND_SEW_MASK: {
  // The operand stores log2(SEW); a value of 0 encodes SEW of 8.
3627     unsigned Log2SEW = Op.getImm();
3628     unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
3629     assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
3630     OS << "e" << SEW;
3631     break;
3632   }
3633   case RISCVOp::OPERAND_VEC_POLICY:
  // Decode the tail/mask agnostic policy bits, e.g. "ta, mu".
3634     unsigned Policy = Op.getImm();
3635     assert(Policy <= (RISCVVType::TAIL_AGNOSTIC | RISCVVType::MASK_AGNOSTIC) &&
3636            "Invalid Policy Value");
3637     OS << (Policy & RISCVVType::TAIL_AGNOSTIC ? "ta" : "tu") << ", "
3638        << (Policy & RISCVVType::MASK_AGNOSTIC ? "ma" : "mu");
3639     break;
3640   }
3641 
3642   return Comment;
3643 }
3644 
3645 // clang-format off
// Helper macros that expand to the case labels for every LMUL (and masked)
// variant of an RVV pseudo opcode, used by the big switches below (e.g. in
// findCommutedOpIndices).
3646 #define CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL)                                 \
3647   RISCV::Pseudo##OP##_##LMUL
3648 
3649 #define CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)                                   \
3650   RISCV::Pseudo##OP##_##LMUL##_MASK
3651 
3652 #define CASE_RVV_OPCODE_LMUL(OP, LMUL)                                        \
3653   CASE_RVV_OPCODE_UNMASK_LMUL(OP, LMUL):                                      \
3654   case CASE_RVV_OPCODE_MASK_LMUL(OP, LMUL)
3655 
3656 #define CASE_RVV_OPCODE_UNMASK_WIDEN(OP)                                      \
3657   CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF8):                                       \
3658   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF4):                                  \
3659   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, MF2):                                  \
3660   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M1):                                   \
3661   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M2):                                   \
3662   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M4)
3663 
3664 #define CASE_RVV_OPCODE_UNMASK(OP)                                            \
3665   CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
3666   case CASE_RVV_OPCODE_UNMASK_LMUL(OP, M8)
3667 
3668 #define CASE_RVV_OPCODE_MASK_WIDEN(OP)                                        \
3669   CASE_RVV_OPCODE_MASK_LMUL(OP, MF8):                                         \
3670   case CASE_RVV_OPCODE_MASK_LMUL(OP, MF4):                                    \
3671   case CASE_RVV_OPCODE_MASK_LMUL(OP, MF2):                                    \
3672   case CASE_RVV_OPCODE_MASK_LMUL(OP, M1):                                     \
3673   case CASE_RVV_OPCODE_MASK_LMUL(OP, M2):                                     \
3674   case CASE_RVV_OPCODE_MASK_LMUL(OP, M4)
3675 
3676 #define CASE_RVV_OPCODE_MASK(OP)                                              \
3677   CASE_RVV_OPCODE_MASK_WIDEN(OP):                                             \
3678   case CASE_RVV_OPCODE_MASK_LMUL(OP, M8)
3679 
3680 #define CASE_RVV_OPCODE_WIDEN(OP)                                             \
3681   CASE_RVV_OPCODE_UNMASK_WIDEN(OP):                                           \
3682   case CASE_RVV_OPCODE_MASK_WIDEN(OP)
3683 
3684 #define CASE_RVV_OPCODE(OP)                                                   \
3685   CASE_RVV_OPCODE_UNMASK(OP):                                                 \
3686   case CASE_RVV_OPCODE_MASK(OP)
3687 // clang-format on
3688 
3689 // clang-format off
// Analogous case-label macros for the multiply-add (VMA/VFMA) pseudos; the
// VFMA variants additionally carry the element width (SEW) in the name.
3690 #define CASE_VMA_OPCODE_COMMON(OP, TYPE, LMUL)                                 \
3691   RISCV::PseudoV##OP##_##TYPE##_##LMUL
3692 
3693 #define CASE_VMA_OPCODE_LMULS(OP, TYPE)                                        \
3694   CASE_VMA_OPCODE_COMMON(OP, TYPE, MF8):                                       \
3695   case CASE_VMA_OPCODE_COMMON(OP, TYPE, MF4):                                  \
3696   case CASE_VMA_OPCODE_COMMON(OP, TYPE, MF2):                                  \
3697   case CASE_VMA_OPCODE_COMMON(OP, TYPE, M1):                                   \
3698   case CASE_VMA_OPCODE_COMMON(OP, TYPE, M2):                                   \
3699   case CASE_VMA_OPCODE_COMMON(OP, TYPE, M4):                                   \
3700   case CASE_VMA_OPCODE_COMMON(OP, TYPE, M8)
3701 
3702 // VFMA instructions are SEW specific.
3703 #define CASE_VFMA_OPCODE_COMMON(OP, TYPE, LMUL, SEW)                           \
3704   RISCV::PseudoV##OP##_##TYPE##_##LMUL##_##SEW
3705 
3706 #define CASE_VFMA_OPCODE_LMULS_M1(OP, TYPE, SEW)                               \
3707   CASE_VFMA_OPCODE_COMMON(OP, TYPE, M1, SEW):                                  \
3708   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M2, SEW):                             \
3709   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M4, SEW):                             \
3710   case CASE_VFMA_OPCODE_COMMON(OP, TYPE, M8, SEW)
3711 
3712 #define CASE_VFMA_OPCODE_LMULS_MF2(OP, TYPE, SEW)                              \
3713   CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF2, SEW):                                 \
3714   case CASE_VFMA_OPCODE_LMULS_M1(OP, TYPE, SEW)
3715 
3716 #define CASE_VFMA_OPCODE_LMULS_MF4(OP, TYPE, SEW)                              \
3717   CASE_VFMA_OPCODE_COMMON(OP, TYPE, MF4, SEW):                                 \
3718   case CASE_VFMA_OPCODE_LMULS_MF2(OP, TYPE, SEW)
3719 
3720 #define CASE_VFMA_OPCODE_VV(OP)                                                \
3721   CASE_VFMA_OPCODE_LMULS_MF4(OP, VV, E16):                                     \
3722   case CASE_VFMA_OPCODE_LMULS_MF2(OP, VV, E32):                                \
3723   case CASE_VFMA_OPCODE_LMULS_M1(OP, VV, E64)
3724 
3725 #define CASE_VFMA_SPLATS(OP)                                                   \
3726   CASE_VFMA_OPCODE_LMULS_MF4(OP, VFPR16, E16):                                 \
3727   case CASE_VFMA_OPCODE_LMULS_MF2(OP, VFPR32, E32):                            \
3728   case CASE_VFMA_OPCODE_LMULS_M1(OP, VFPR64, E64)
3729 // clang-format on
3730 
findCommutedOpIndices(const MachineInstr & MI,unsigned & SrcOpIdx1,unsigned & SrcOpIdx2) const3731 bool RISCVInstrInfo::findCommutedOpIndices(const MachineInstr &MI,
3732                                            unsigned &SrcOpIdx1,
3733                                            unsigned &SrcOpIdx2) const {
3734   const MCInstrDesc &Desc = MI.getDesc();
3735   if (!Desc.isCommutable())
3736     return false;
3737 
3738   switch (MI.getOpcode()) {
3739   case RISCV::TH_MVEQZ:
3740   case RISCV::TH_MVNEZ:
3741     // We can't commute operands if operand 2 (i.e., rs1 in
3742     // mveqz/mvnez rd,rs1,rs2) is the zero-register (as it is
3743     // not valid as the in/out-operand 1).
3744     if (MI.getOperand(2).getReg() == RISCV::X0)
3745       return false;
3746     // Operands 1 and 2 are commutable, if we switch the opcode.
3747     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 1, 2);
3748   case RISCV::TH_MULA:
3749   case RISCV::TH_MULAW:
3750   case RISCV::TH_MULAH:
3751   case RISCV::TH_MULS:
3752   case RISCV::TH_MULSW:
3753   case RISCV::TH_MULSH:
3754     // Operands 2 and 3 are commutable.
3755     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
3756   case RISCV::PseudoCCMOVGPRNoX0:
3757   case RISCV::PseudoCCMOVGPR:
3758     // Operands 4 and 5 are commutable.
3759     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 4, 5);
3760   case CASE_RVV_OPCODE(VADD_VV):
3761   case CASE_RVV_OPCODE(VAND_VV):
3762   case CASE_RVV_OPCODE(VOR_VV):
3763   case CASE_RVV_OPCODE(VXOR_VV):
3764   case CASE_RVV_OPCODE_MASK(VMSEQ_VV):
3765   case CASE_RVV_OPCODE_MASK(VMSNE_VV):
3766   case CASE_RVV_OPCODE(VMIN_VV):
3767   case CASE_RVV_OPCODE(VMINU_VV):
3768   case CASE_RVV_OPCODE(VMAX_VV):
3769   case CASE_RVV_OPCODE(VMAXU_VV):
3770   case CASE_RVV_OPCODE(VMUL_VV):
3771   case CASE_RVV_OPCODE(VMULH_VV):
3772   case CASE_RVV_OPCODE(VMULHU_VV):
3773   case CASE_RVV_OPCODE_WIDEN(VWADD_VV):
3774   case CASE_RVV_OPCODE_WIDEN(VWADDU_VV):
3775   case CASE_RVV_OPCODE_WIDEN(VWMUL_VV):
3776   case CASE_RVV_OPCODE_WIDEN(VWMULU_VV):
3777   case CASE_RVV_OPCODE_WIDEN(VWMACC_VV):
3778   case CASE_RVV_OPCODE_WIDEN(VWMACCU_VV):
3779   case CASE_RVV_OPCODE_UNMASK(VADC_VVM):
3780   case CASE_RVV_OPCODE(VSADD_VV):
3781   case CASE_RVV_OPCODE(VSADDU_VV):
3782   case CASE_RVV_OPCODE(VAADD_VV):
3783   case CASE_RVV_OPCODE(VAADDU_VV):
3784   case CASE_RVV_OPCODE(VSMUL_VV):
3785     // Operands 2 and 3 are commutable.
3786     return fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, 2, 3);
3787   case CASE_VFMA_SPLATS(FMADD):
3788   case CASE_VFMA_SPLATS(FMSUB):
3789   case CASE_VFMA_SPLATS(FMACC):
3790   case CASE_VFMA_SPLATS(FMSAC):
3791   case CASE_VFMA_SPLATS(FNMADD):
3792   case CASE_VFMA_SPLATS(FNMSUB):
3793   case CASE_VFMA_SPLATS(FNMACC):
3794   case CASE_VFMA_SPLATS(FNMSAC):
3795   case CASE_VFMA_OPCODE_VV(FMACC):
3796   case CASE_VFMA_OPCODE_VV(FMSAC):
3797   case CASE_VFMA_OPCODE_VV(FNMACC):
3798   case CASE_VFMA_OPCODE_VV(FNMSAC):
3799   case CASE_VMA_OPCODE_LMULS(MADD, VX):
3800   case CASE_VMA_OPCODE_LMULS(NMSUB, VX):
3801   case CASE_VMA_OPCODE_LMULS(MACC, VX):
3802   case CASE_VMA_OPCODE_LMULS(NMSAC, VX):
3803   case CASE_VMA_OPCODE_LMULS(MACC, VV):
3804   case CASE_VMA_OPCODE_LMULS(NMSAC, VV): {
3805     // If the tail policy is undisturbed we can't commute.
3806     assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags));
3807     if ((MI.getOperand(RISCVII::getVecPolicyOpNum(MI.getDesc())).getImm() &
3808          1) == 0)
3809       return false;
3810 
3811     // For these instructions we can only swap operand 1 and operand 3 by
3812     // changing the opcode.
3813     unsigned CommutableOpIdx1 = 1;
3814     unsigned CommutableOpIdx2 = 3;
3815     if (!fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, CommutableOpIdx1,
3816                               CommutableOpIdx2))
3817       return false;
3818     return true;
3819   }
3820   case CASE_VFMA_OPCODE_VV(FMADD):
3821   case CASE_VFMA_OPCODE_VV(FMSUB):
3822   case CASE_VFMA_OPCODE_VV(FNMADD):
3823   case CASE_VFMA_OPCODE_VV(FNMSUB):
3824   case CASE_VMA_OPCODE_LMULS(MADD, VV):
3825   case CASE_VMA_OPCODE_LMULS(NMSUB, VV): {
3826     // If the tail policy is undisturbed we can't commute.
3827     assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags));
3828     if ((MI.getOperand(RISCVII::getVecPolicyOpNum(MI.getDesc())).getImm() &
3829          1) == 0)
3830       return false;
3831 
3832     // For these instructions we have more freedom. We can commute with the
3833     // other multiplicand or with the addend/subtrahend/minuend.
3834 
3835     // Any fixed operand must be from source 1, 2 or 3.
3836     if (SrcOpIdx1 != CommuteAnyOperandIndex && SrcOpIdx1 > 3)
3837       return false;
3838     if (SrcOpIdx2 != CommuteAnyOperandIndex && SrcOpIdx2 > 3)
3839       return false;
3840 
3841     // It both ops are fixed one must be the tied source.
3842     if (SrcOpIdx1 != CommuteAnyOperandIndex &&
3843         SrcOpIdx2 != CommuteAnyOperandIndex && SrcOpIdx1 != 1 && SrcOpIdx2 != 1)
3844       return false;
3845 
3846     // Look for two different register operands assumed to be commutable
3847     // regardless of the FMA opcode. The FMA opcode is adjusted later if
3848     // needed.
3849     if (SrcOpIdx1 == CommuteAnyOperandIndex ||
3850         SrcOpIdx2 == CommuteAnyOperandIndex) {
3851       // At least one of operands to be commuted is not specified and
3852       // this method is free to choose appropriate commutable operands.
3853       unsigned CommutableOpIdx1 = SrcOpIdx1;
3854       if (SrcOpIdx1 == SrcOpIdx2) {
3855         // Both of operands are not fixed. Set one of commutable
3856         // operands to the tied source.
3857         CommutableOpIdx1 = 1;
3858       } else if (SrcOpIdx1 == CommuteAnyOperandIndex) {
3859         // Only one of the operands is not fixed.
3860         CommutableOpIdx1 = SrcOpIdx2;
3861       }
3862 
3863       // CommutableOpIdx1 is well defined now. Let's choose another commutable
3864       // operand and assign its index to CommutableOpIdx2.
3865       unsigned CommutableOpIdx2;
3866       if (CommutableOpIdx1 != 1) {
3867         // If we haven't already used the tied source, we must use it now.
3868         CommutableOpIdx2 = 1;
3869       } else {
3870         Register Op1Reg = MI.getOperand(CommutableOpIdx1).getReg();
3871 
3872         // The commuted operands should have different registers.
3873         // Otherwise, the commute transformation does not change anything and
3874         // is useless. We use this as a hint to make our decision.
3875         if (Op1Reg != MI.getOperand(2).getReg())
3876           CommutableOpIdx2 = 2;
3877         else
3878           CommutableOpIdx2 = 3;
3879       }
3880 
3881       // Assign the found pair of commutable indices to SrcOpIdx1 and
3882       // SrcOpIdx2 to return those values.
3883       if (!fixCommutedOpIndices(SrcOpIdx1, SrcOpIdx2, CommutableOpIdx1,
3884                                 CommutableOpIdx2))
3885         return false;
3886     }
3887 
3888     return true;
3889   }
3890   }
3891 
3892   return TargetInstrInfo::findCommutedOpIndices(MI, SrcOpIdx1, SrcOpIdx2);
3893 }
3894 
// clang-format off
// The CASE_*_CHANGE_OPCODE_* macros below are used by
// commuteInstructionImpl to translate a multiply-add pseudo into the
// sibling pseudo whose operand roles are swapped (e.g. MACC <-> MADD).
// Each *_COMMON macro expands to a 'case' label that stores the
// replacement opcode into the local variable 'Opc'.
#define CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, LMUL)                \
  case RISCV::PseudoV##OLDOP##_##TYPE##_##LMUL:                                \
    Opc = RISCV::PseudoV##NEWOP##_##TYPE##_##LMUL;                             \
    break;

// Integer multiply-add pseudos exist for all seven register group sizes
// (LMUL = MF8 .. M8).
#define CASE_VMA_CHANGE_OPCODE_LMULS(OLDOP, NEWOP, TYPE)                       \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF8)                       \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF4)                       \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF2)                       \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M1)                        \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M2)                        \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M4)                        \
  CASE_VMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M8)

// VFMA depends on SEW.
#define CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, LMUL, SEW)          \
  case RISCV::PseudoV##OLDOP##_##TYPE##_##LMUL##_##SEW:                        \
    Opc = RISCV::PseudoV##NEWOP##_##TYPE##_##LMUL##_##SEW;                     \
    break;

// The FP pseudos only exist for the LMULs where the element type fits,
// so the _M1/_MF2/_MF4 variants chain into one another to cover
// M1..M8, MF2..M8 and MF4..M8 respectively.
#define CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, TYPE, SEW)              \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M1, SEW)                  \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M2, SEW)                  \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M4, SEW)                  \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, M8, SEW)

#define CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, TYPE, SEW)             \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF2, SEW)                 \
  CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, TYPE, SEW)

#define CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, TYPE, SEW)             \
  CASE_VFMA_CHANGE_OPCODE_COMMON(OLDOP, NEWOP, TYPE, MF4, SEW)                 \
  CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, TYPE, SEW)

// Vector-vector forms: SEW E16 exists from MF4 up, E32 from MF2 up,
// E64 from M1 up.
#define CASE_VFMA_CHANGE_OPCODE_VV(OLDOP, NEWOP)                               \
  CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VV, E16)                     \
  CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VV, E32)                     \
  CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VV, E64)

// Vector-scalar (splat) forms, keyed by the FP scalar register class.
#define CASE_VFMA_CHANGE_OPCODE_SPLATS(OLDOP, NEWOP)                           \
  CASE_VFMA_CHANGE_OPCODE_LMULS_MF4(OLDOP, NEWOP, VFPR16, E16)                 \
  CASE_VFMA_CHANGE_OPCODE_LMULS_MF2(OLDOP, NEWOP, VFPR32, E32)                 \
  CASE_VFMA_CHANGE_OPCODE_LMULS_M1(OLDOP, NEWOP, VFPR64, E64)
// clang-format on
3940 
// Commute operands OpIdx1 and OpIdx2 of MI. Most of the pseudos handled
// here cannot simply have the two operands swapped in place; instead the
// commute is realized by switching to a sibling opcode whose operand roles
// are reversed (e.g. vmacc <-> vmadd), after which the generic
// TargetInstrInfo implementation performs the actual operand swap.
MachineInstr *RISCVInstrInfo::commuteInstructionImpl(MachineInstr &MI,
                                                     bool NewMI,
                                                     unsigned OpIdx1,
                                                     unsigned OpIdx2) const {
  // Honor the NewMI request up front: either mutate MI in place or work on
  // a fresh clone inserted into the same function. All calls into the
  // generic implementation below then pass NewMI=false so the instruction
  // is not cloned a second time.
  auto cloneIfNew = [NewMI](MachineInstr &MI) -> MachineInstr & {
    if (NewMI)
      return *MI.getParent()->getParent()->CloneMachineInstr(&MI);
    return MI;
  };

  switch (MI.getOpcode()) {
  case RISCV::TH_MVEQZ:
  case RISCV::TH_MVNEZ: {
    // Swapping the two sources of a T-Head conditional move is equivalent
    // to inverting the condition, i.e. flipping between mveqz and mvnez.
    auto &WorkingMI = cloneIfNew(MI);
    WorkingMI.setDesc(get(MI.getOpcode() == RISCV::TH_MVEQZ ? RISCV::TH_MVNEZ
                                                            : RISCV::TH_MVEQZ));
    return TargetInstrInfo::commuteInstructionImpl(WorkingMI, false, OpIdx1,
                                                   OpIdx2);
  }
  case RISCV::PseudoCCMOVGPRNoX0:
  case RISCV::PseudoCCMOVGPR: {
    // CCMOV can be commuted by inverting the condition.
    auto CC = static_cast<RISCVCC::CondCode>(MI.getOperand(3).getImm());
    CC = RISCVCC::getOppositeBranchCondition(CC);
    auto &WorkingMI = cloneIfNew(MI);
    WorkingMI.getOperand(3).setImm(CC);
    return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI*/ false,
                                                   OpIdx1, OpIdx2);
  }
  case CASE_VFMA_SPLATS(FMACC):
  case CASE_VFMA_SPLATS(FMADD):
  case CASE_VFMA_SPLATS(FMSAC):
  case CASE_VFMA_SPLATS(FMSUB):
  case CASE_VFMA_SPLATS(FNMACC):
  case CASE_VFMA_SPLATS(FNMADD):
  case CASE_VFMA_SPLATS(FNMSAC):
  case CASE_VFMA_SPLATS(FNMSUB):
  case CASE_VFMA_OPCODE_VV(FMACC):
  case CASE_VFMA_OPCODE_VV(FMSAC):
  case CASE_VFMA_OPCODE_VV(FNMACC):
  case CASE_VFMA_OPCODE_VV(FNMSAC):
  case CASE_VMA_OPCODE_LMULS(MADD, VX):
  case CASE_VMA_OPCODE_LMULS(NMSUB, VX):
  case CASE_VMA_OPCODE_LMULS(MACC, VX):
  case CASE_VMA_OPCODE_LMULS(NMSAC, VX):
  case CASE_VMA_OPCODE_LMULS(MACC, VV):
  case CASE_VMA_OPCODE_LMULS(NMSAC, VV): {
    // For these opcodes the only meaningful commute is toggling between
    // clobbering the addend/subtrahend/minuend (operand 1) and clobbering
    // one of the multiplicands (operand 3); findCommutedOpIndices only
    // offers that pair, which the asserts below re-check.
    assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
    assert((OpIdx1 == 3 || OpIdx2 == 3) && "Unexpected opcode index");
    // Map the opcode to its role-swapped twin (e.g. MACC <-> MADD).
    unsigned Opc;
    switch (MI.getOpcode()) {
      default:
        llvm_unreachable("Unexpected opcode");
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FMACC, FMADD)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FMADD, FMACC)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FMSAC, FMSUB)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FMSUB, FMSAC)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMACC, FNMADD)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMADD, FNMACC)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSAC, FNMSUB)
      CASE_VFMA_CHANGE_OPCODE_SPLATS(FNMSUB, FNMSAC)
      CASE_VFMA_CHANGE_OPCODE_VV(FMACC, FMADD)
      CASE_VFMA_CHANGE_OPCODE_VV(FMSAC, FMSUB)
      CASE_VFMA_CHANGE_OPCODE_VV(FNMACC, FNMADD)
      CASE_VFMA_CHANGE_OPCODE_VV(FNMSAC, FNMSUB)
      CASE_VMA_CHANGE_OPCODE_LMULS(MACC, MADD, VX)
      CASE_VMA_CHANGE_OPCODE_LMULS(MADD, MACC, VX)
      CASE_VMA_CHANGE_OPCODE_LMULS(NMSAC, NMSUB, VX)
      CASE_VMA_CHANGE_OPCODE_LMULS(NMSUB, NMSAC, VX)
      CASE_VMA_CHANGE_OPCODE_LMULS(MACC, MADD, VV)
      CASE_VMA_CHANGE_OPCODE_LMULS(NMSAC, NMSUB, VV)
    }

    auto &WorkingMI = cloneIfNew(MI);
    WorkingMI.setDesc(get(Opc));
    return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false,
                                                   OpIdx1, OpIdx2);
  }
  case CASE_VFMA_OPCODE_VV(FMADD):
  case CASE_VFMA_OPCODE_VV(FMSUB):
  case CASE_VFMA_OPCODE_VV(FNMADD):
  case CASE_VFMA_OPCODE_VV(FNMSUB):
  case CASE_VMA_OPCODE_LMULS(MADD, VV):
  case CASE_VMA_OPCODE_LMULS(NMSUB, VV): {
    assert((OpIdx1 == 1 || OpIdx2 == 1) && "Unexpected opcode index");
    // If one of the operands is the addend we need to change opcode
    // (e.g. MADD -> MACC). Otherwise we're just swapping 2 of the
    // multiplicands, which needs no opcode change.
    if (OpIdx1 == 3 || OpIdx2 == 3) {
      unsigned Opc;
      switch (MI.getOpcode()) {
        default:
          llvm_unreachable("Unexpected opcode");
        CASE_VFMA_CHANGE_OPCODE_VV(FMADD, FMACC)
        CASE_VFMA_CHANGE_OPCODE_VV(FMSUB, FMSAC)
        CASE_VFMA_CHANGE_OPCODE_VV(FNMADD, FNMACC)
        CASE_VFMA_CHANGE_OPCODE_VV(FNMSUB, FNMSAC)
        CASE_VMA_CHANGE_OPCODE_LMULS(MADD, MACC, VV)
        CASE_VMA_CHANGE_OPCODE_LMULS(NMSUB, NMSAC, VV)
      }

      auto &WorkingMI = cloneIfNew(MI);
      WorkingMI.setDesc(get(Opc));
      return TargetInstrInfo::commuteInstructionImpl(WorkingMI, /*NewMI=*/false,
                                                     OpIdx1, OpIdx2);
    }
    // Let the default code handle it.
    break;
  }
  }

  // Plain operand swap; no opcode adjustment required.
  return TargetInstrInfo::commuteInstructionImpl(MI, NewMI, OpIdx1, OpIdx2);
}
4055 
4056 #undef CASE_VMA_CHANGE_OPCODE_COMMON
4057 #undef CASE_VMA_CHANGE_OPCODE_LMULS
4058 #undef CASE_VFMA_CHANGE_OPCODE_COMMON
4059 #undef CASE_VFMA_CHANGE_OPCODE_LMULS_M1
4060 #undef CASE_VFMA_CHANGE_OPCODE_LMULS_MF2
4061 #undef CASE_VFMA_CHANGE_OPCODE_LMULS_MF4
4062 #undef CASE_VFMA_CHANGE_OPCODE_VV
4063 #undef CASE_VFMA_CHANGE_OPCODE_SPLATS
4064 
4065 #undef CASE_RVV_OPCODE_UNMASK_LMUL
4066 #undef CASE_RVV_OPCODE_MASK_LMUL
4067 #undef CASE_RVV_OPCODE_LMUL
4068 #undef CASE_RVV_OPCODE_UNMASK_WIDEN
4069 #undef CASE_RVV_OPCODE_UNMASK
4070 #undef CASE_RVV_OPCODE_MASK_WIDEN
4071 #undef CASE_RVV_OPCODE_MASK
4072 #undef CASE_RVV_OPCODE_WIDEN
4073 #undef CASE_RVV_OPCODE
4074 
4075 #undef CASE_VMA_OPCODE_COMMON
4076 #undef CASE_VMA_OPCODE_LMULS
4077 #undef CASE_VFMA_OPCODE_COMMON
4078 #undef CASE_VFMA_OPCODE_LMULS_M1
4079 #undef CASE_VFMA_OPCODE_LMULS_MF2
4080 #undef CASE_VFMA_OPCODE_LMULS_MF4
4081 #undef CASE_VFMA_OPCODE_VV
4082 #undef CASE_VFMA_SPLATS
4083 
simplifyInstruction(MachineInstr & MI) const4084 bool RISCVInstrInfo::simplifyInstruction(MachineInstr &MI) const {
4085   switch (MI.getOpcode()) {
4086   default:
4087     break;
4088   case RISCV::ADD:
4089   case RISCV::OR:
4090   case RISCV::XOR:
4091     // Normalize (so we hit the next if clause).
4092     // add/[x]or rd, zero, rs => add/[x]or rd, rs, zero
4093     if (MI.getOperand(1).getReg() == RISCV::X0)
4094       commuteInstruction(MI);
4095     // add/[x]or rd, rs, zero => addi rd, rs, 0
4096     if (MI.getOperand(2).getReg() == RISCV::X0) {
4097       MI.getOperand(2).ChangeToImmediate(0);
4098       MI.setDesc(get(RISCV::ADDI));
4099       return true;
4100     }
4101     // xor rd, rs, rs => addi rd, zero, 0
4102     if (MI.getOpcode() == RISCV::XOR &&
4103         MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
4104       MI.getOperand(1).setReg(RISCV::X0);
4105       MI.getOperand(2).ChangeToImmediate(0);
4106       MI.setDesc(get(RISCV::ADDI));
4107       return true;
4108     }
4109     break;
4110   case RISCV::ORI:
4111   case RISCV::XORI:
4112     // [x]ori rd, zero, N => addi rd, zero, N
4113     if (MI.getOperand(1).getReg() == RISCV::X0) {
4114       MI.setDesc(get(RISCV::ADDI));
4115       return true;
4116     }
4117     break;
4118   case RISCV::SUB:
4119     // sub rd, rs, zero => addi rd, rs, 0
4120     if (MI.getOperand(2).getReg() == RISCV::X0) {
4121       MI.getOperand(2).ChangeToImmediate(0);
4122       MI.setDesc(get(RISCV::ADDI));
4123       return true;
4124     }
4125     break;
4126   case RISCV::SUBW:
4127     // subw rd, rs, zero => addiw rd, rs, 0
4128     if (MI.getOperand(2).getReg() == RISCV::X0) {
4129       MI.getOperand(2).ChangeToImmediate(0);
4130       MI.setDesc(get(RISCV::ADDIW));
4131       return true;
4132     }
4133     break;
4134   case RISCV::ADDW:
4135     // Normalize (so we hit the next if clause).
4136     // addw rd, zero, rs => addw rd, rs, zero
4137     if (MI.getOperand(1).getReg() == RISCV::X0)
4138       commuteInstruction(MI);
4139     // addw rd, rs, zero => addiw rd, rs, 0
4140     if (MI.getOperand(2).getReg() == RISCV::X0) {
4141       MI.getOperand(2).ChangeToImmediate(0);
4142       MI.setDesc(get(RISCV::ADDIW));
4143       return true;
4144     }
4145     break;
4146   case RISCV::SH1ADD:
4147   case RISCV::SH1ADD_UW:
4148   case RISCV::SH2ADD:
4149   case RISCV::SH2ADD_UW:
4150   case RISCV::SH3ADD:
4151   case RISCV::SH3ADD_UW:
4152     // shNadd[.uw] rd, zero, rs => addi rd, rs, 0
4153     if (MI.getOperand(1).getReg() == RISCV::X0) {
4154       MI.removeOperand(1);
4155       MI.addOperand(MachineOperand::CreateImm(0));
4156       MI.setDesc(get(RISCV::ADDI));
4157       return true;
4158     }
4159     // shNadd[.uw] rd, rs, zero => slli[.uw] rd, rs, N
4160     if (MI.getOperand(2).getReg() == RISCV::X0) {
4161       MI.removeOperand(2);
4162       unsigned Opc = MI.getOpcode();
4163       if (Opc == RISCV::SH1ADD_UW || Opc == RISCV::SH2ADD_UW ||
4164           Opc == RISCV::SH3ADD_UW) {
4165         MI.addOperand(MachineOperand::CreateImm(getSHXADDUWShiftAmount(Opc)));
4166         MI.setDesc(get(RISCV::SLLI_UW));
4167         return true;
4168       }
4169       MI.addOperand(MachineOperand::CreateImm(getSHXADDShiftAmount(Opc)));
4170       MI.setDesc(get(RISCV::SLLI));
4171       return true;
4172     }
4173     break;
4174   case RISCV::AND:
4175   case RISCV::MUL:
4176   case RISCV::MULH:
4177   case RISCV::MULHSU:
4178   case RISCV::MULHU:
4179   case RISCV::MULW:
4180     // and rd, zero, rs => addi rd, zero, 0
4181     // mul* rd, zero, rs => addi rd, zero, 0
4182     // and rd, rs, zero => addi rd, zero, 0
4183     // mul* rd, rs, zero => addi rd, zero, 0
4184     if (MI.getOperand(1).getReg() == RISCV::X0 ||
4185         MI.getOperand(2).getReg() == RISCV::X0) {
4186       MI.getOperand(1).setReg(RISCV::X0);
4187       MI.getOperand(2).ChangeToImmediate(0);
4188       MI.setDesc(get(RISCV::ADDI));
4189       return true;
4190     }
4191     break;
4192   case RISCV::ANDI:
4193     // andi rd, zero, C => addi rd, zero, 0
4194     if (MI.getOperand(1).getReg() == RISCV::X0) {
4195       MI.getOperand(2).setImm(0);
4196       MI.setDesc(get(RISCV::ADDI));
4197       return true;
4198     }
4199     break;
4200   case RISCV::SLL:
4201   case RISCV::SRL:
4202   case RISCV::SRA:
4203     // shift rd, zero, rs => addi rd, zero, 0
4204     if (MI.getOperand(1).getReg() == RISCV::X0) {
4205       MI.getOperand(2).ChangeToImmediate(0);
4206       MI.setDesc(get(RISCV::ADDI));
4207       return true;
4208     }
4209     // shift rd, rs, zero => addi rd, rs, 0
4210     if (MI.getOperand(2).getReg() == RISCV::X0) {
4211       MI.getOperand(2).ChangeToImmediate(0);
4212       MI.setDesc(get(RISCV::ADDI));
4213       return true;
4214     }
4215     break;
4216   case RISCV::SLLW:
4217   case RISCV::SRLW:
4218   case RISCV::SRAW:
4219     // shiftw rd, zero, rs => addi rd, zero, 0
4220     if (MI.getOperand(1).getReg() == RISCV::X0) {
4221       MI.getOperand(2).ChangeToImmediate(0);
4222       MI.setDesc(get(RISCV::ADDI));
4223       return true;
4224     }
4225     break;
4226   case RISCV::SLLI:
4227   case RISCV::SRLI:
4228   case RISCV::SRAI:
4229   case RISCV::SLLIW:
4230   case RISCV::SRLIW:
4231   case RISCV::SRAIW:
4232   case RISCV::SLLI_UW:
4233     // shiftimm rd, zero, N => addi rd, zero, 0
4234     if (MI.getOperand(1).getReg() == RISCV::X0) {
4235       MI.getOperand(2).setImm(0);
4236       MI.setDesc(get(RISCV::ADDI));
4237       return true;
4238     }
4239     break;
4240   case RISCV::SLTU:
4241   case RISCV::ADD_UW:
4242     // sltu rd, zero, zero => addi rd, zero, 0
4243     // add.uw rd, zero, zero => addi rd, zero, 0
4244     if (MI.getOperand(1).getReg() == RISCV::X0 &&
4245         MI.getOperand(2).getReg() == RISCV::X0) {
4246       MI.getOperand(2).ChangeToImmediate(0);
4247       MI.setDesc(get(RISCV::ADDI));
4248       return true;
4249     }
4250     // add.uw rd, zero, rs => addi rd, rs, 0
4251     if (MI.getOpcode() == RISCV::ADD_UW &&
4252         MI.getOperand(1).getReg() == RISCV::X0) {
4253       MI.removeOperand(1);
4254       MI.addOperand(MachineOperand::CreateImm(0));
4255       MI.setDesc(get(RISCV::ADDI));
4256     }
4257     break;
4258   case RISCV::SLTIU:
4259     // sltiu rd, zero, NZC => addi rd, zero, 1
4260     // sltiu rd, zero, 0 => addi rd, zero, 0
4261     if (MI.getOperand(1).getReg() == RISCV::X0) {
4262       MI.getOperand(2).setImm(MI.getOperand(2).getImm() != 0);
4263       MI.setDesc(get(RISCV::ADDI));
4264       return true;
4265     }
4266     break;
4267   case RISCV::SEXT_H:
4268   case RISCV::SEXT_B:
4269   case RISCV::ZEXT_H_RV32:
4270   case RISCV::ZEXT_H_RV64:
4271     // sext.[hb] rd, zero => addi rd, zero, 0
4272     // zext.h rd, zero => addi rd, zero, 0
4273     if (MI.getOperand(1).getReg() == RISCV::X0) {
4274       MI.addOperand(MachineOperand::CreateImm(0));
4275       MI.setDesc(get(RISCV::ADDI));
4276       return true;
4277     }
4278     break;
4279   case RISCV::MIN:
4280   case RISCV::MINU:
4281   case RISCV::MAX:
4282   case RISCV::MAXU:
4283     // min|max rd, rs, rs => addi rd, rs, 0
4284     if (MI.getOperand(1).getReg() == MI.getOperand(2).getReg()) {
4285       MI.getOperand(2).ChangeToImmediate(0);
4286       MI.setDesc(get(RISCV::ADDI));
4287       return true;
4288     }
4289     break;
4290   case RISCV::BEQ:
4291   case RISCV::BNE:
4292     // b{eq,ne} zero, rs, imm => b{eq,ne} rs, zero, imm
4293     if (MI.getOperand(0).getReg() == RISCV::X0) {
4294       MachineOperand MO0 = MI.getOperand(0);
4295       MI.removeOperand(0);
4296       MI.insert(MI.operands_begin() + 1, {MO0});
4297     }
4298     break;
4299   case RISCV::BLTU:
4300     // bltu zero, rs, imm => bne rs, zero, imm
4301     if (MI.getOperand(0).getReg() == RISCV::X0) {
4302       MachineOperand MO0 = MI.getOperand(0);
4303       MI.removeOperand(0);
4304       MI.insert(MI.operands_begin() + 1, {MO0});
4305       MI.setDesc(get(RISCV::BNE));
4306     }
4307     break;
4308   case RISCV::BGEU:
4309     // bgeu zero, rs, imm => beq rs, zero, imm
4310     if (MI.getOperand(0).getReg() == RISCV::X0) {
4311       MachineOperand MO0 = MI.getOperand(0);
4312       MI.removeOperand(0);
4313       MI.insert(MI.operands_begin() + 1, {MO0});
4314       MI.setDesc(get(RISCV::BEQ));
4315     }
4316     break;
4317   }
4318   return false;
4319 }
4320 
// clang-format off
// The CASE_WIDEOP_* macros below cover the integer widening ops
// (vwadd.wv etc.) whose _TIED pseudos tie the wide destination to the
// wide source operand; convertToThreeAddress uses them to switch to the
// untied three-address form.
#define CASE_WIDEOP_OPCODE_COMMON(OP, LMUL)                                    \
  RISCV::PseudoV##OP##_##LMUL##_TIED

// Widening ops produce an LMUL*2 result, so M8 sources do not exist.
#define CASE_WIDEOP_OPCODE_LMULS(OP)                                           \
  CASE_WIDEOP_OPCODE_COMMON(OP, MF8):                                          \
  case CASE_WIDEOP_OPCODE_COMMON(OP, MF4):                                     \
  case CASE_WIDEOP_OPCODE_COMMON(OP, MF2):                                     \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M1):                                      \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M2):                                      \
  case CASE_WIDEOP_OPCODE_COMMON(OP, M4)

// Expands to a 'case' that stores the untied opcode into the local
// variable 'NewOpc'.
#define CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, LMUL)                             \
  case RISCV::PseudoV##OP##_##LMUL##_TIED:                                     \
    NewOpc = RISCV::PseudoV##OP##_##LMUL;                                      \
    break;

#define CASE_WIDEOP_CHANGE_OPCODE_LMULS(OP)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF8)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2)                                    \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1)                                     \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2)                                     \
  CASE_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4)

// FP Widening Ops may be SEW aware. Create SEW aware cases for these cases.
#define CASE_FP_WIDEOP_OPCODE_COMMON(OP, LMUL, SEW)                            \
  RISCV::PseudoV##OP##_##LMUL##_##SEW##_TIED

// Note: no trailing line-continuation on the last line of the two LMULS
// macros below — a stray '\' would splice the following source line
// (previously the '// clang-format on' comment) into the macro definition.
#define CASE_FP_WIDEOP_OPCODE_LMULS(OP)                                        \
  CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF4, E16):                                  \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF2, E16):                             \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, MF2, E32):                             \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M1, E16):                              \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M1, E32):                              \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M2, E16):                              \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M2, E32):                              \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M4, E16):                              \
  case CASE_FP_WIDEOP_OPCODE_COMMON(OP, M4, E32)

#define CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, LMUL, SEW)                     \
  case RISCV::PseudoV##OP##_##LMUL##_##SEW##_TIED:                             \
    NewOpc = RISCV::PseudoV##OP##_##LMUL##_##SEW;                              \
    break;

#define CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(OP)                                 \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF4, E16)                            \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2, E16)                            \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, MF2, E32)                            \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1, E16)                             \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M1, E32)                             \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E16)                             \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M2, E32)                             \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E16)                             \
  CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON(OP, M4, E32)
// clang-format on
4377 
// Convert a tied widening-op pseudo (PseudoV*_TIED, whose wide destination
// is also the first source) into its untied three-address form, giving the
// register allocator more freedom. LiveVariables / LiveIntervals are
// updated in place when provided. Returns the new instruction, or nullptr
// if the conversion is not legal (e.g. tail-undisturbed policy).
MachineInstr *RISCVInstrInfo::convertToThreeAddress(MachineInstr &MI,
                                                    LiveVariables *LV,
                                                    LiveIntervals *LIS) const {
  MachineInstrBuilder MIB;
  switch (MI.getOpcode()) {
  default:
    return nullptr;
  case CASE_FP_WIDEOP_OPCODE_LMULS(FWADD_WV):
  case CASE_FP_WIDEOP_OPCODE_LMULS(FWSUB_WV): {
    assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags) &&
           MI.getNumExplicitOperands() == 7 &&
           "Expect 7 explicit operands rd, rs2, rs1, rm, vl, sew, policy");
    // If the tail policy is undisturbed we can't convert.
    if ((MI.getOperand(RISCVII::getVecPolicyOpNum(MI.getDesc())).getImm() &
         1) == 0)
      return nullptr;
    // clang-format off
    unsigned NewOpc;
    switch (MI.getOpcode()) {
    default:
      llvm_unreachable("Unexpected opcode");
    CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWADD_WV)
    CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS(FWSUB_WV)
    }
    // clang-format on

    // Rebuild with the untied opcode. The untied form takes an explicit
    // passthru operand, which is filled with an undef copy of the
    // destination register (legal because the tail is agnostic).
    MachineBasicBlock &MBB = *MI.getParent();
    MIB = BuildMI(MBB, MI, MI.getDebugLoc(), get(NewOpc))
              .add(MI.getOperand(0))
              .addReg(MI.getOperand(0).getReg(), RegState::Undef)
              .add(MI.getOperand(1))
              .add(MI.getOperand(2))
              .add(MI.getOperand(3))
              .add(MI.getOperand(4))
              .add(MI.getOperand(5))
              .add(MI.getOperand(6));
    break;
  }
  case CASE_WIDEOP_OPCODE_LMULS(WADD_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WADDU_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WSUB_WV):
  case CASE_WIDEOP_OPCODE_LMULS(WSUBU_WV): {
    // If the tail policy is undisturbed we can't convert.
    // Integer widening ops have 6 explicit operands (no rounding mode).
    assert(RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags) &&
           MI.getNumExplicitOperands() == 6);
    if ((MI.getOperand(RISCVII::getVecPolicyOpNum(MI.getDesc())).getImm() &
         1) == 0)
      return nullptr;

    // clang-format off
    unsigned NewOpc;
    switch (MI.getOpcode()) {
    default:
      llvm_unreachable("Unexpected opcode");
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADD_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WADDU_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WSUB_WV)
    CASE_WIDEOP_CHANGE_OPCODE_LMULS(WSUBU_WV)
    }
    // clang-format on

    // Same rebuild as above, minus the rounding-mode operand.
    MachineBasicBlock &MBB = *MI.getParent();
    MIB = BuildMI(MBB, MI, MI.getDebugLoc(), get(NewOpc))
              .add(MI.getOperand(0))
              .addReg(MI.getOperand(0).getReg(), RegState::Undef)
              .add(MI.getOperand(1))
              .add(MI.getOperand(2))
              .add(MI.getOperand(3))
              .add(MI.getOperand(4))
              .add(MI.getOperand(5));
    break;
  }
  }
  MIB.copyImplicitOps(MI);

  // Transfer kill flags from the old instruction to the new one.
  if (LV) {
    unsigned NumOps = MI.getNumOperands();
    for (unsigned I = 1; I < NumOps; ++I) {
      MachineOperand &Op = MI.getOperand(I);
      if (Op.isReg() && Op.isKill())
        LV->replaceKillInstruction(Op.getReg(), MI, *MIB);
    }
  }

  if (LIS) {
    SlotIndex Idx = LIS->ReplaceMachineInstrInMaps(MI, *MIB);

    if (MI.getOperand(0).isEarlyClobber()) {
      // Use operand 1 was tied to early-clobber def operand 0, so its live
      // interval could have ended at an early-clobber slot. Now they are not
      // tied we need to update it to the normal register slot.
      // NOTE(review): S is dereferenced without a null check — presumably a
      // segment containing Idx always exists here since operand 1 is live at
      // this instruction; confirm getSegmentContaining cannot return null.
      LiveInterval &LI = LIS->getInterval(MI.getOperand(1).getReg());
      LiveRange::Segment *S = LI.getSegmentContaining(Idx);
      if (S->end == Idx.getRegSlot(true))
        S->end = Idx.getRegSlot();
    }
  }

  return MIB;
}
4478 
4479 #undef CASE_WIDEOP_OPCODE_COMMON
4480 #undef CASE_WIDEOP_OPCODE_LMULS
4481 #undef CASE_WIDEOP_CHANGE_OPCODE_COMMON
4482 #undef CASE_WIDEOP_CHANGE_OPCODE_LMULS
4483 #undef CASE_FP_WIDEOP_OPCODE_COMMON
4484 #undef CASE_FP_WIDEOP_OPCODE_LMULS
4485 #undef CASE_FP_WIDEOP_CHANGE_OPCODE_COMMON
4486 #undef CASE_FP_WIDEOP_CHANGE_OPCODE_LMULS
4487 
/// Multiply the value in DestReg by the constant Amount, in place: DestReg is
/// both the source and the destination, and is killed/redefined by the
/// emitted sequence. Instructions are inserted before II with debug location
/// DL and flag Flag; scratch values live in freshly created virtual GPRs.
/// Strategies, tried in order: single shift (power of two), Zba SHXADD+SLLI,
/// shift+add / shift+sub for Amount = 2^k +/- 1, materialized constant + MUL
/// (Zmmul), and finally a generic shift-and-accumulate over the set bits.
void RISCVInstrInfo::mulImm(MachineFunction &MF, MachineBasicBlock &MBB,
                            MachineBasicBlock::iterator II, const DebugLoc &DL,
                            Register DestReg, uint32_t Amount,
                            MachineInstr::MIFlag Flag) const {
  MachineRegisterInfo &MRI = MF.getRegInfo();
  if (llvm::has_single_bit<uint32_t>(Amount)) {
    // Power of two: a single SLLI, or nothing at all for Amount == 1.
    uint32_t ShiftAmount = Log2_32(Amount);
    if (ShiftAmount == 0)
      return;
    BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
        .addReg(DestReg, RegState::Kill)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
  } else if (STI.hasStdExtZba() &&
             ((Amount % 3 == 0 && isPowerOf2_64(Amount / 3)) ||
              (Amount % 5 == 0 && isPowerOf2_64(Amount / 5)) ||
              (Amount % 9 == 0 && isPowerOf2_64(Amount / 9)))) {
    // We can use Zba SHXADD+SLLI instructions for multiply in some cases.
    // SHnADD computes (rs1 << n) + rs2, so x*9 = SH3ADD(x, x), x*5 =
    // SH2ADD(x, x), x*3 = SH1ADD(x, x); the remaining power-of-two factor is
    // folded into a preceding SLLI.
    unsigned Opc;
    uint32_t ShiftAmount;
    if (Amount % 9 == 0) {
      Opc = RISCV::SH3ADD;
      ShiftAmount = Log2_64(Amount / 9);
    } else if (Amount % 5 == 0) {
      Opc = RISCV::SH2ADD;
      ShiftAmount = Log2_64(Amount / 5);
    } else if (Amount % 3 == 0) {
      Opc = RISCV::SH1ADD;
      ShiftAmount = Log2_64(Amount / 3);
    } else {
      llvm_unreachable("implied by if-clause");
    }
    // First scale by the power-of-two factor (skipped when it is 1) ...
    if (ShiftAmount)
      BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
          .addReg(DestReg, RegState::Kill)
          .addImm(ShiftAmount)
          .setMIFlag(Flag);
    // ... then apply the 3/5/9 multiply with both operands = DestReg.
    BuildMI(MBB, II, DL, get(Opc), DestReg)
        .addReg(DestReg, RegState::Kill)
        .addReg(DestReg)
        .setMIFlag(Flag);
  } else if (llvm::has_single_bit<uint32_t>(Amount - 1)) {
    // Amount = 2^k + 1: DestReg = (DestReg << k) + DestReg.
    Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    uint32_t ShiftAmount = Log2_32(Amount - 1);
    BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
        .addReg(DestReg)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
    BuildMI(MBB, II, DL, get(RISCV::ADD), DestReg)
        .addReg(ScaledRegister, RegState::Kill)
        .addReg(DestReg, RegState::Kill)
        .setMIFlag(Flag);
  } else if (llvm::has_single_bit<uint32_t>(Amount + 1)) {
    // Amount = 2^k - 1: DestReg = (DestReg << k) - DestReg.
    Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    uint32_t ShiftAmount = Log2_32(Amount + 1);
    BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister)
        .addReg(DestReg)
        .addImm(ShiftAmount)
        .setMIFlag(Flag);
    BuildMI(MBB, II, DL, get(RISCV::SUB), DestReg)
        .addReg(ScaledRegister, RegState::Kill)
        .addReg(DestReg, RegState::Kill)
        .setMIFlag(Flag);
  } else if (STI.hasStdExtZmmul()) {
    // General case with a hardware multiplier: materialize Amount and MUL.
    Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass);
    movImm(MBB, II, DL, N, Amount, Flag);
    BuildMI(MBB, II, DL, get(RISCV::MUL), DestReg)
        .addReg(DestReg, RegState::Kill)
        .addReg(N, RegState::Kill)
        .setMIFlag(Flag);
  } else {
    // No multiplier: walk the set bits of Amount low-to-high, keeping the
    // running partial product in DestReg (shifted incrementally) and summing
    // the partial terms into Acc.
    Register Acc;
    uint32_t PrevShiftAmount = 0;
    for (uint32_t ShiftAmount = 0; Amount >> ShiftAmount; ShiftAmount++) {
      if (Amount & (1U << ShiftAmount)) {
        // Shift DestReg by the delta since the previous set bit.
        if (ShiftAmount)
          BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg)
              .addReg(DestReg, RegState::Kill)
              .addImm(ShiftAmount - PrevShiftAmount)
              .setMIFlag(Flag);
        // Accumulate this term unless it is the last (highest) set bit; the
        // final term stays in DestReg and is added below.
        if (Amount >> (ShiftAmount + 1)) {
          // If we don't have an accmulator yet, create it and copy DestReg.
          if (!Acc) {
            Acc = MRI.createVirtualRegister(&RISCV::GPRRegClass);
            BuildMI(MBB, II, DL, get(TargetOpcode::COPY), Acc)
                .addReg(DestReg)
                .setMIFlag(Flag);
          } else {
            BuildMI(MBB, II, DL, get(RISCV::ADD), Acc)
                .addReg(Acc, RegState::Kill)
                .addReg(DestReg)
                .setMIFlag(Flag);
          }
        }
        PrevShiftAmount = ShiftAmount;
      }
    }
    assert(Acc && "Expected valid accumulator");
    BuildMI(MBB, II, DL, get(RISCV::ADD), DestReg)
        .addReg(DestReg, RegState::Kill)
        .addReg(Acc, RegState::Kill)
        .setMIFlag(Flag);
  }
}
4592 
4593 ArrayRef<std::pair<MachineMemOperand::Flags, const char *>>
getSerializableMachineMemOperandTargetFlags() const4594 RISCVInstrInfo::getSerializableMachineMemOperandTargetFlags() const {
4595   static const std::pair<MachineMemOperand::Flags, const char *> TargetFlags[] =
4596       {{MONontemporalBit0, "riscv-nontemporal-domain-bit-0"},
4597        {MONontemporalBit1, "riscv-nontemporal-domain-bit-1"}};
4598   return ArrayRef(TargetFlags);
4599 }
4600 
getTailDuplicateSize(CodeGenOptLevel OptLevel) const4601 unsigned RISCVInstrInfo::getTailDuplicateSize(CodeGenOptLevel OptLevel) const {
4602   return OptLevel >= CodeGenOptLevel::Aggressive
4603              ? STI.getTailDupAggressiveThreshold()
4604              : 2;
4605 }
4606 
isRVVSpill(const MachineInstr & MI)4607 bool RISCV::isRVVSpill(const MachineInstr &MI) {
4608   // RVV lacks any support for immediate addressing for stack addresses, so be
4609   // conservative.
4610   unsigned Opcode = MI.getOpcode();
4611   if (!RISCVVPseudosTable::getPseudoInfo(Opcode) &&
4612       !getLMULForRVVWholeLoadStore(Opcode) && !isRVVSpillForZvlsseg(Opcode))
4613     return false;
4614   return true;
4615 }
4616 
4617 std::optional<std::pair<unsigned, unsigned>>
isRVVSpillForZvlsseg(unsigned Opcode)4618 RISCV::isRVVSpillForZvlsseg(unsigned Opcode) {
4619   switch (Opcode) {
4620   default:
4621     return std::nullopt;
4622   case RISCV::PseudoVSPILL2_M1:
4623   case RISCV::PseudoVRELOAD2_M1:
4624     return std::make_pair(2u, 1u);
4625   case RISCV::PseudoVSPILL2_M2:
4626   case RISCV::PseudoVRELOAD2_M2:
4627     return std::make_pair(2u, 2u);
4628   case RISCV::PseudoVSPILL2_M4:
4629   case RISCV::PseudoVRELOAD2_M4:
4630     return std::make_pair(2u, 4u);
4631   case RISCV::PseudoVSPILL3_M1:
4632   case RISCV::PseudoVRELOAD3_M1:
4633     return std::make_pair(3u, 1u);
4634   case RISCV::PseudoVSPILL3_M2:
4635   case RISCV::PseudoVRELOAD3_M2:
4636     return std::make_pair(3u, 2u);
4637   case RISCV::PseudoVSPILL4_M1:
4638   case RISCV::PseudoVRELOAD4_M1:
4639     return std::make_pair(4u, 1u);
4640   case RISCV::PseudoVSPILL4_M2:
4641   case RISCV::PseudoVRELOAD4_M2:
4642     return std::make_pair(4u, 2u);
4643   case RISCV::PseudoVSPILL5_M1:
4644   case RISCV::PseudoVRELOAD5_M1:
4645     return std::make_pair(5u, 1u);
4646   case RISCV::PseudoVSPILL6_M1:
4647   case RISCV::PseudoVRELOAD6_M1:
4648     return std::make_pair(6u, 1u);
4649   case RISCV::PseudoVSPILL7_M1:
4650   case RISCV::PseudoVRELOAD7_M1:
4651     return std::make_pair(7u, 1u);
4652   case RISCV::PseudoVSPILL8_M1:
4653   case RISCV::PseudoVRELOAD8_M1:
4654     return std::make_pair(8u, 1u);
4655   }
4656 }
4657 
hasEqualFRM(const MachineInstr & MI1,const MachineInstr & MI2)4658 bool RISCV::hasEqualFRM(const MachineInstr &MI1, const MachineInstr &MI2) {
4659   int16_t MI1FrmOpIdx =
4660       RISCV::getNamedOperandIdx(MI1.getOpcode(), RISCV::OpName::frm);
4661   int16_t MI2FrmOpIdx =
4662       RISCV::getNamedOperandIdx(MI2.getOpcode(), RISCV::OpName::frm);
4663   if (MI1FrmOpIdx < 0 || MI2FrmOpIdx < 0)
4664     return false;
4665   MachineOperand FrmOp1 = MI1.getOperand(MI1FrmOpIdx);
4666   MachineOperand FrmOp2 = MI2.getOperand(MI2FrmOpIdx);
4667   return FrmOp1.getImm() == FrmOp2.getImm();
4668 }
4669 
/// For a vector instruction taking a scalar (x-register or immediate)
/// operand, return how many low bits of that scalar actually affect the
/// result, given log2 of the element width (Log2SEW). Shift-style opcodes
/// demand only lg2(SEW) bits (lg2(2*SEW) for the narrowing/widening forms);
/// the remaining listed opcodes demand a full SEW bits. Returns std::nullopt
/// for opcodes not covered by the table. Section numbers refer to the RVV
/// 1.0 specification.
std::optional<unsigned>
RISCV::getVectorLowDemandedScalarBits(unsigned Opcode, unsigned Log2SEW) {
  switch (Opcode) {
  default:
    return std::nullopt;

  // 11.6. Vector Single-Width Shift Instructions
  case RISCV::VSLL_VX:
  case RISCV::VSRL_VX:
  case RISCV::VSRA_VX:
  // 12.4. Vector Single-Width Scaling Shift Instructions
  case RISCV::VSSRL_VX:
  case RISCV::VSSRA_VX:
  // Zvbb
  case RISCV::VROL_VX:
  case RISCV::VROR_VX:
    // Only the low lg2(SEW) bits of the shift-amount value are used.
    return Log2SEW;

  // 11.7 Vector Narrowing Integer Right Shift Instructions
  case RISCV::VNSRL_WX:
  case RISCV::VNSRA_WX:
  // 12.5. Vector Narrowing Fixed-Point Clip Instructions
  case RISCV::VNCLIPU_WX:
  case RISCV::VNCLIP_WX:
  // Zvbb
  case RISCV::VWSLL_VX:
    // Only the low lg2(2*SEW) bits of the shift-amount value are used.
    return Log2SEW + 1;

  // 11.1. Vector Single-Width Integer Add and Subtract
  case RISCV::VADD_VX:
  case RISCV::VSUB_VX:
  case RISCV::VRSUB_VX:
  // 11.2. Vector Widening Integer Add/Subtract
  case RISCV::VWADDU_VX:
  case RISCV::VWSUBU_VX:
  case RISCV::VWADD_VX:
  case RISCV::VWSUB_VX:
  case RISCV::VWADDU_WX:
  case RISCV::VWSUBU_WX:
  case RISCV::VWADD_WX:
  case RISCV::VWSUB_WX:
  // 11.4. Vector Integer Add-with-Carry / Subtract-with-Borrow Instructions
  case RISCV::VADC_VXM:
  case RISCV::VADC_VIM:
  case RISCV::VMADC_VXM:
  case RISCV::VMADC_VIM:
  case RISCV::VMADC_VX:
  case RISCV::VSBC_VXM:
  case RISCV::VMSBC_VXM:
  case RISCV::VMSBC_VX:
  // 11.5 Vector Bitwise Logical Instructions
  case RISCV::VAND_VX:
  case RISCV::VOR_VX:
  case RISCV::VXOR_VX:
  // 11.8. Vector Integer Compare Instructions
  case RISCV::VMSEQ_VX:
  case RISCV::VMSNE_VX:
  case RISCV::VMSLTU_VX:
  case RISCV::VMSLT_VX:
  case RISCV::VMSLEU_VX:
  case RISCV::VMSLE_VX:
  case RISCV::VMSGTU_VX:
  case RISCV::VMSGT_VX:
  // 11.9. Vector Integer Min/Max Instructions
  case RISCV::VMINU_VX:
  case RISCV::VMIN_VX:
  case RISCV::VMAXU_VX:
  case RISCV::VMAX_VX:
  // 11.10. Vector Single-Width Integer Multiply Instructions
  case RISCV::VMUL_VX:
  case RISCV::VMULH_VX:
  case RISCV::VMULHU_VX:
  case RISCV::VMULHSU_VX:
  // 11.11. Vector Integer Divide Instructions
  case RISCV::VDIVU_VX:
  case RISCV::VDIV_VX:
  case RISCV::VREMU_VX:
  case RISCV::VREM_VX:
  // 11.12. Vector Widening Integer Multiply Instructions
  case RISCV::VWMUL_VX:
  case RISCV::VWMULU_VX:
  case RISCV::VWMULSU_VX:
  // 11.13. Vector Single-Width Integer Multiply-Add Instructions
  case RISCV::VMACC_VX:
  case RISCV::VNMSAC_VX:
  case RISCV::VMADD_VX:
  case RISCV::VNMSUB_VX:
  // 11.14. Vector Widening Integer Multiply-Add Instructions
  case RISCV::VWMACCU_VX:
  case RISCV::VWMACC_VX:
  case RISCV::VWMACCSU_VX:
  case RISCV::VWMACCUS_VX:
  // 11.15. Vector Integer Merge Instructions
  case RISCV::VMERGE_VXM:
  // 11.16. Vector Integer Move Instructions
  case RISCV::VMV_V_X:
  // 12.1. Vector Single-Width Saturating Add and Subtract
  case RISCV::VSADDU_VX:
  case RISCV::VSADD_VX:
  case RISCV::VSSUBU_VX:
  case RISCV::VSSUB_VX:
  // 12.2. Vector Single-Width Averaging Add and Subtract
  case RISCV::VAADDU_VX:
  case RISCV::VAADD_VX:
  case RISCV::VASUBU_VX:
  case RISCV::VASUB_VX:
  // 12.3. Vector Single-Width Fractional Multiply with Rounding and Saturation
  case RISCV::VSMUL_VX:
  // 16.1. Integer Scalar Move Instructions
  case RISCV::VMV_S_X:
  // Zvbb
  case RISCV::VANDN_VX:
    return 1U << Log2SEW;
  }
}
4787 
getRVVMCOpcode(unsigned RVVPseudoOpcode)4788 unsigned RISCV::getRVVMCOpcode(unsigned RVVPseudoOpcode) {
4789   const RISCVVPseudosTable::PseudoInfo *RVV =
4790       RISCVVPseudosTable::getPseudoInfo(RVVPseudoOpcode);
4791   if (!RVV)
4792     return 0;
4793   return RVV->BaseInstr;
4794 }
4795 
getDestLog2EEW(const MCInstrDesc & Desc,unsigned Log2SEW)4796 unsigned RISCV::getDestLog2EEW(const MCInstrDesc &Desc, unsigned Log2SEW) {
4797   unsigned DestEEW =
4798       (Desc.TSFlags & RISCVII::DestEEWMask) >> RISCVII::DestEEWShift;
4799   // EEW = 1
4800   if (DestEEW == 0)
4801     return 0;
4802   // EEW = SEW * n
4803   unsigned Scaled = Log2SEW + (DestEEW - 1);
4804   assert(Scaled >= 3 && Scaled <= 6);
4805   return Scaled;
4806 }
4807 
4808 /// Given two VL operands, do we know that LHS <= RHS?
isVLKnownLE(const MachineOperand & LHS,const MachineOperand & RHS)4809 bool RISCV::isVLKnownLE(const MachineOperand &LHS, const MachineOperand &RHS) {
4810   if (LHS.isReg() && RHS.isReg() && LHS.getReg().isVirtual() &&
4811       LHS.getReg() == RHS.getReg())
4812     return true;
4813   if (RHS.isImm() && RHS.getImm() == RISCV::VLMaxSentinel)
4814     return true;
4815   if (LHS.isImm() && LHS.getImm() == RISCV::VLMaxSentinel)
4816     return false;
4817   if (!LHS.isImm() || !RHS.isImm())
4818     return false;
4819   return LHS.getImm() <= RHS.getImm();
4820 }
4821 
namespace {
/// Pipeliner loop info for single-basic-block RISC-V loops. Carries the
/// (possibly null) defining instructions of the two branch-condition operands
/// and the branch condition itself, pre-normalized by
/// analyzeLoopForPipelining() so that it can be emitted as
/// "if (Cond) goto epilogue".
class RISCVPipelinerLoopInfo : public TargetInstrInfo::PipelinerLoopInfo {
  // Defs of the condition's two register operands; nullptr when an operand is
  // not a virtual register.
  const MachineInstr *LHS;
  const MachineInstr *RHS;
  // Copy of the normalized branch condition from analyzeBranch().
  SmallVector<MachineOperand, 3> Cond;

public:
  RISCVPipelinerLoopInfo(const MachineInstr *LHS, const MachineInstr *RHS,
                         const SmallVectorImpl<MachineOperand> &Cond)
      : LHS(LHS), RHS(RHS), Cond(Cond.begin(), Cond.end()) {}

  bool shouldIgnoreForPipelining(const MachineInstr *MI) const override {
    // Make the instructions for loop control be placed in stage 0.
    // The predecessors of LHS/RHS are considered by the caller.
    if (LHS && MI == LHS)
      return true;
    if (RHS && MI == RHS)
      return true;
    return false;
  }

  std::optional<bool> createTripCountGreaterCondition(
      int TC, MachineBasicBlock &MBB,
      SmallVectorImpl<MachineOperand> &CondParam) override {
    // A branch instruction will be inserted as "if (Cond) goto epilogue".
    // Cond is normalized for such use.
    // The predecessors of the branch are assumed to have already been inserted.
    CondParam = Cond;
    return {};
  }

  // No bookkeeping needed when the preheader changes or the trip count is
  // adjusted: the condition operands already capture everything we track.
  void setPreheader(MachineBasicBlock *NewPreheader) override {}

  void adjustTripCount(int TripCountAdjust) override {}
};
} // namespace
4858 
4859 std::unique_ptr<TargetInstrInfo::PipelinerLoopInfo>
analyzeLoopForPipelining(MachineBasicBlock * LoopBB) const4860 RISCVInstrInfo::analyzeLoopForPipelining(MachineBasicBlock *LoopBB) const {
4861   MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
4862   SmallVector<MachineOperand, 4> Cond;
4863   if (analyzeBranch(*LoopBB, TBB, FBB, Cond, /*AllowModify=*/false))
4864     return nullptr;
4865 
4866   // Infinite loops are not supported
4867   if (TBB == LoopBB && FBB == LoopBB)
4868     return nullptr;
4869 
4870   // Must be conditional branch
4871   if (FBB == nullptr)
4872     return nullptr;
4873 
4874   assert((TBB == LoopBB || FBB == LoopBB) &&
4875          "The Loop must be a single-basic-block loop");
4876 
4877   // Normalization for createTripCountGreaterCondition()
4878   if (TBB == LoopBB)
4879     reverseBranchCondition(Cond);
4880 
4881   const MachineRegisterInfo &MRI = LoopBB->getParent()->getRegInfo();
4882   auto FindRegDef = [&MRI](MachineOperand &Op) -> const MachineInstr * {
4883     if (!Op.isReg())
4884       return nullptr;
4885     Register Reg = Op.getReg();
4886     if (!Reg.isVirtual())
4887       return nullptr;
4888     return MRI.getVRegDef(Reg);
4889   };
4890 
4891   const MachineInstr *LHS = FindRegDef(Cond[1]);
4892   const MachineInstr *RHS = FindRegDef(Cond[2]);
4893   if (LHS && LHS->isPHI())
4894     return nullptr;
4895   if (RHS && RHS->isPHI())
4896     return nullptr;
4897 
4898   return std::make_unique<RISCVPipelinerLoopInfo>(LHS, RHS, Cond);
4899 }
4900 
4901 // FIXME: We should remove this if we have a default generic scheduling model.
isHighLatencyDef(int Opc) const4902 bool RISCVInstrInfo::isHighLatencyDef(int Opc) const {
4903   unsigned RVVMCOpcode = RISCV::getRVVMCOpcode(Opc);
4904   Opc = RVVMCOpcode ? RVVMCOpcode : Opc;
4905   switch (Opc) {
4906   default:
4907     return false;
4908   // Integer div/rem.
4909   case RISCV::DIV:
4910   case RISCV::DIVW:
4911   case RISCV::DIVU:
4912   case RISCV::DIVUW:
4913   case RISCV::REM:
4914   case RISCV::REMW:
4915   case RISCV::REMU:
4916   case RISCV::REMUW:
4917   // Floating-point div/sqrt.
4918   case RISCV::FDIV_H:
4919   case RISCV::FDIV_S:
4920   case RISCV::FDIV_D:
4921   case RISCV::FDIV_H_INX:
4922   case RISCV::FDIV_S_INX:
4923   case RISCV::FDIV_D_INX:
4924   case RISCV::FDIV_D_IN32X:
4925   case RISCV::FSQRT_H:
4926   case RISCV::FSQRT_S:
4927   case RISCV::FSQRT_D:
4928   case RISCV::FSQRT_H_INX:
4929   case RISCV::FSQRT_S_INX:
4930   case RISCV::FSQRT_D_INX:
4931   case RISCV::FSQRT_D_IN32X:
4932   // Vector integer div/rem
4933   case RISCV::VDIV_VV:
4934   case RISCV::VDIV_VX:
4935   case RISCV::VDIVU_VV:
4936   case RISCV::VDIVU_VX:
4937   case RISCV::VREM_VV:
4938   case RISCV::VREM_VX:
4939   case RISCV::VREMU_VV:
4940   case RISCV::VREMU_VX:
4941   // Vector floating-point div/sqrt.
4942   case RISCV::VFDIV_VV:
4943   case RISCV::VFDIV_VF:
4944   case RISCV::VFRDIV_VF:
4945   case RISCV::VFSQRT_V:
4946   case RISCV::VFRSQRT7_V:
4947     return true;
4948   }
4949 }
4950