1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file implements the machine function pass to insert read/write of CSR-s 9 // of the RISC-V instructions. 10 // 11 // Currently the pass implements: 12 // -Writing and saving frm before an RVV floating-point instruction with a 13 // static rounding mode and restores the value after. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "MCTargetDesc/RISCVBaseInfo.h" 18 #include "RISCV.h" 19 #include "RISCVSubtarget.h" 20 #include "llvm/CodeGen/MachineFunctionPass.h" 21 using namespace llvm; 22 23 #define DEBUG_TYPE "riscv-insert-read-write-csr" 24 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass" 25 26 namespace { 27 28 class RISCVInsertReadWriteCSR : public MachineFunctionPass { 29 const TargetInstrInfo *TII; 30 31 public: 32 static char ID; 33 34 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {} 35 36 bool runOnMachineFunction(MachineFunction &MF) override; 37 38 void getAnalysisUsage(AnalysisUsage &AU) const override { 39 AU.setPreservesCFG(); 40 MachineFunctionPass::getAnalysisUsage(AU); 41 } 42 43 StringRef getPassName() const override { 44 return RISCV_INSERT_READ_WRITE_CSR_NAME; 45 } 46 47 private: 48 bool emitWriteRoundingMode(MachineBasicBlock &MBB); 49 }; 50 51 } // end anonymous namespace 52 53 char RISCVInsertReadWriteCSR::ID = 0; 54 55 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, 56 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) 57 58 // This function also swaps frm and restores it when encountering an RVV 59 // floating point instruction with a static rounding mode. 60 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { 61 bool Changed = false; 62 for (MachineInstr &MI : MBB) { 63 int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); 64 if (FRMIdx < 0) 65 continue; 66 67 unsigned FRMImm = MI.getOperand(FRMIdx).getImm(); 68 69 // The value is a hint to this pass to not alter the frm value. 70 if (FRMImm == RISCVFPRndMode::DYN) 71 continue; 72 73 Changed = true; 74 75 // Save 76 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); 77 Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); 78 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), 79 SavedFRM) 80 .addImm(FRMImm); 81 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, 82 /*IsImp*/ true)); 83 // Restore 84 MachineInstrBuilder MIB = 85 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) 86 .addReg(SavedFRM); 87 MBB.insertAfter(MI, MIB); 88 } 89 return Changed; 90 } 91 92 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) { 93 // Skip if the vector extension is not enabled. 94 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 95 if (!ST.hasVInstructions()) 96 return false; 97 98 TII = ST.getInstrInfo(); 99 100 bool Changed = false; 101 102 for (MachineBasicBlock &MBB : MF) 103 Changed |= emitWriteRoundingMode(MBB); 104 105 return Changed; 106 } 107 108 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() { 109 return new RISCVInsertReadWriteCSR(); 110 } 111