1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file implements the machine function pass to insert read/write of CSR-s 9 // of the RISC-V instructions. 10 // 11 // Currently the pass implements: 12 // -Writing and saving frm before an RVV floating-point instruction with a 13 // static rounding mode and restores the value after. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "MCTargetDesc/RISCVBaseInfo.h" 18 #include "RISCV.h" 19 #include "RISCVSubtarget.h" 20 #include "llvm/CodeGen/MachineFunctionPass.h" 21 using namespace llvm; 22 23 #define DEBUG_TYPE "riscv-insert-read-write-csr" 24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass" 25 26 static cl::opt<bool> 27 DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false), 28 cl::Hidden, 29 cl::desc("Disable optimized frm insertion.")); 30 31 namespace { 32 33 class RISCVInsertReadWriteCSR : public MachineFunctionPass { 34 const TargetInstrInfo *TII; 35 36 public: 37 static char ID; 38 39 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {} 40 41 bool runOnMachineFunction(MachineFunction &MF) override; 42 43 void getAnalysisUsage(AnalysisUsage &AU) const override { 44 AU.setPreservesCFG(); 45 MachineFunctionPass::getAnalysisUsage(AU); 46 } 47 48 StringRef getPassName() const override { 49 return RISCV_INSERT_READ_WRITE_CSR_NAME; 50 } 51 52 private: 53 bool emitWriteRoundingMode(MachineBasicBlock &MBB); 54 bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB); 55 }; 56 57 } // end anonymous namespace 58 59 char RISCVInsertReadWriteCSR::ID = 0; 60 61 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, 62 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) 63 64 // TODO: Use more accurate rounding mode at the start of MBB. 65 bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) { 66 bool Changed = false; 67 MachineInstr *LastFRMChanger = nullptr; 68 unsigned CurrentRM = RISCVFPRndMode::DYN; 69 Register SavedFRM; 70 71 for (MachineInstr &MI : MBB) { 72 if (MI.getOpcode() == RISCV::SwapFRMImm || 73 MI.getOpcode() == RISCV::WriteFRMImm) { 74 CurrentRM = MI.getOperand(0).getImm(); 75 SavedFRM = Register(); 76 continue; 77 } 78 79 if (MI.getOpcode() == RISCV::WriteFRM) { 80 CurrentRM = RISCVFPRndMode::DYN; 81 SavedFRM = Register(); 82 continue; 83 } 84 85 if (MI.isCall() || MI.isInlineAsm() || 86 MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) { 87 // Restore FRM before unknown operations. 88 if (SavedFRM.isValid()) 89 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM)) 90 .addReg(SavedFRM); 91 CurrentRM = RISCVFPRndMode::DYN; 92 SavedFRM = Register(); 93 continue; 94 } 95 96 assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) && 97 "Expected that MI could not modify FRM."); 98 99 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); 100 if (FRMIdx < 0) 101 continue; 102 unsigned InstrRM = MI.getOperand(FRMIdx).getImm(); 103 104 LastFRMChanger = &MI; 105 106 // Make MI implicit use FRM. 107 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, 108 /*IsImp*/ true)); 109 Changed = true; 110 111 // Skip if MI uses same rounding mode as FRM. 112 if (InstrRM == CurrentRM) 113 continue; 114 115 if (!SavedFRM.isValid()) { 116 // Save current FRM value to SavedFRM. 117 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); 118 SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); 119 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM) 120 .addImm(InstrRM); 121 } else { 122 // Don't need to save current FRM when SavedFRM having value. 123 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm)) 124 .addImm(InstrRM); 125 } 126 CurrentRM = InstrRM; 127 } 128 129 // Restore FRM if needed. 130 if (SavedFRM.isValid()) { 131 assert(LastFRMChanger && "Expected valid pointer."); 132 MachineInstrBuilder MIB = 133 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) 134 .addReg(SavedFRM); 135 MBB.insertAfter(LastFRMChanger, MIB); 136 } 137 138 return Changed; 139 } 140 141 // This function also swaps frm and restores it when encountering an RVV 142 // floating point instruction with a static rounding mode. 143 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { 144 bool Changed = false; 145 for (MachineInstr &MI : MBB) { 146 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); 147 if (FRMIdx < 0) 148 continue; 149 150 unsigned FRMImm = MI.getOperand(FRMIdx).getImm(); 151 152 // The value is a hint to this pass to not alter the frm value. 153 if (FRMImm == RISCVFPRndMode::DYN) 154 continue; 155 156 Changed = true; 157 158 // Save 159 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); 160 Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); 161 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), 162 SavedFRM) 163 .addImm(FRMImm); 164 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, 165 /*IsImp*/ true)); 166 // Restore 167 MachineInstrBuilder MIB = 168 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) 169 .addReg(SavedFRM); 170 MBB.insertAfter(MI, MIB); 171 } 172 return Changed; 173 } 174 175 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) { 176 // Skip if the vector extension is not enabled. 177 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 178 if (!ST.hasVInstructions()) 179 return false; 180 181 TII = ST.getInstrInfo(); 182 183 bool Changed = false; 184 185 for (MachineBasicBlock &MBB : MF) { 186 if (DisableFRMInsertOpt) 187 Changed |= emitWriteRoundingMode(MBB); 188 else 189 Changed |= emitWriteRoundingModeOpt(MBB); 190 } 191 192 return Changed; 193 } 194 195 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() { 196 return new RISCVInsertReadWriteCSR(); 197 } 198