xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp (revision 7ab1a32cd43cbae61ad4dd435d6a482bbf61cb52)
1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
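// This file implements the machine function pass that inserts reads/writes of
// RISC-V CSRs required by RISC-V instructions.
//
// Currently the pass implements:
// - Writing and saving frm before an RVV floating-point instruction with a
//   static rounding mode, and restoring the saved value afterwards. By default,
//   frm writes that would not change the current mode are elided within a
//   basic block; -riscv-disable-frm-insert-opt selects the plain
//   per-instruction save/restore instead.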
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "MCTargetDesc/RISCVBaseInfo.h"
18 #include "RISCV.h"
19 #include "RISCVSubtarget.h"
20 #include "llvm/CodeGen/MachineFunctionPass.h"
21 using namespace llvm;
22 
23 #define DEBUG_TYPE "riscv-insert-read-write-csr"
24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass"
25 
26 static cl::opt<bool>
27     DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false),
28                         cl::Hidden,
29                         cl::desc("Disable optimized frm insertion."));
30 
31 namespace {
32 
33 class RISCVInsertReadWriteCSR : public MachineFunctionPass {
34   const TargetInstrInfo *TII;
35 
36 public:
37   static char ID;
38 
39   RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {}
40 
41   bool runOnMachineFunction(MachineFunction &MF) override;
42 
43   void getAnalysisUsage(AnalysisUsage &AU) const override {
44     AU.setPreservesCFG();
45     MachineFunctionPass::getAnalysisUsage(AU);
46   }
47 
48   StringRef getPassName() const override {
49     return RISCV_INSERT_READ_WRITE_CSR_NAME;
50   }
51 
52 private:
53   bool emitWriteRoundingMode(MachineBasicBlock &MBB);
54   bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB);
55 };
56 
57 } // end anonymous namespace
58 
59 char RISCVInsertReadWriteCSR::ID = 0;
60 
61 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE,
62                 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false)
63 
64 // TODO: Use more accurate rounding mode at the start of MBB.
65 bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) {
66   bool Changed = false;
67   MachineInstr *LastFRMChanger = nullptr;
68   unsigned CurrentRM = RISCVFPRndMode::DYN;
69   Register SavedFRM;
70 
71   for (MachineInstr &MI : MBB) {
72     if (MI.getOpcode() == RISCV::SwapFRMImm ||
73         MI.getOpcode() == RISCV::WriteFRMImm) {
74       CurrentRM = MI.getOperand(0).getImm();
75       SavedFRM = Register();
76       continue;
77     }
78 
79     if (MI.getOpcode() == RISCV::WriteFRM) {
80       CurrentRM = RISCVFPRndMode::DYN;
81       SavedFRM = Register();
82       continue;
83     }
84 
85     if (MI.isCall() || MI.isInlineAsm() ||
86         MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) {
87       // Restore FRM before unknown operations.
88       if (SavedFRM.isValid())
89         BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM))
90             .addReg(SavedFRM);
91       CurrentRM = RISCVFPRndMode::DYN;
92       SavedFRM = Register();
93       continue;
94     }
95 
96     assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) &&
97            "Expected that MI could not modify FRM.");
98 
99     int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
100     if (FRMIdx < 0)
101       continue;
102     unsigned InstrRM = MI.getOperand(FRMIdx).getImm();
103 
104     LastFRMChanger = &MI;
105 
106     // Make MI implicit use FRM.
107     MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
108                                             /*IsImp*/ true));
109     Changed = true;
110 
111     // Skip if MI uses same rounding mode as FRM.
112     if (InstrRM == CurrentRM)
113       continue;
114 
115     if (!SavedFRM.isValid()) {
116       // Save current FRM value to SavedFRM.
117       MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
118       SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
119       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM)
120           .addImm(InstrRM);
121     } else {
122       // Don't need to save current FRM when SavedFRM having value.
123       BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm))
124           .addImm(InstrRM);
125     }
126     CurrentRM = InstrRM;
127   }
128 
129   // Restore FRM if needed.
130   if (SavedFRM.isValid()) {
131     assert(LastFRMChanger && "Expected valid pointer.");
132     MachineInstrBuilder MIB =
133         BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
134             .addReg(SavedFRM);
135     MBB.insertAfter(LastFRMChanger, MIB);
136   }
137 
138   return Changed;
139 }
140 
141 // This function also swaps frm and restores it when encountering an RVV
142 // floating point instruction with a static rounding mode.
143 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) {
144   bool Changed = false;
145   for (MachineInstr &MI : MBB) {
146     int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc());
147     if (FRMIdx < 0)
148       continue;
149 
150     unsigned FRMImm = MI.getOperand(FRMIdx).getImm();
151 
152     // The value is a hint to this pass to not alter the frm value.
153     if (FRMImm == RISCVFPRndMode::DYN)
154       continue;
155 
156     Changed = true;
157 
158     // Save
159     MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo();
160     Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass);
161     BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm),
162             SavedFRM)
163         .addImm(FRMImm);
164     MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false,
165                                             /*IsImp*/ true));
166     // Restore
167     MachineInstrBuilder MIB =
168         BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM))
169             .addReg(SavedFRM);
170     MBB.insertAfter(MI, MIB);
171   }
172   return Changed;
173 }
174 
175 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) {
176   // Skip if the vector extension is not enabled.
177   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
178   if (!ST.hasVInstructions())
179     return false;
180 
181   TII = ST.getInstrInfo();
182 
183   bool Changed = false;
184 
185   for (MachineBasicBlock &MBB : MF) {
186     if (DisableFRMInsertOpt)
187       Changed |= emitWriteRoundingMode(MBB);
188     else
189       Changed |= emitWriteRoundingModeOpt(MBB);
190   }
191 
192   return Changed;
193 }
194 
195 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() {
196   return new RISCVInsertReadWriteCSR();
197 }
198