//===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // This file implements the machine function pass to insert read/write of CSR-s // of the RISC-V instructions. // // Currently the pass implements: // -Writing and saving frm before an RVV floating-point instruction with a // static rounding mode and restores the value after. // //===----------------------------------------------------------------------===// #include "MCTargetDesc/RISCVBaseInfo.h" #include "RISCV.h" #include "RISCVSubtarget.h" #include "llvm/CodeGen/MachineFunctionPass.h" using namespace llvm; #define DEBUG_TYPE "riscv-insert-read-write-csr" #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass" static cl::opt DisableFRMInsertOpt("riscv-disable-frm-insert-opt", cl::init(false), cl::Hidden, cl::desc("Disable optimized frm insertion.")); namespace { class RISCVInsertReadWriteCSR : public MachineFunctionPass { const TargetInstrInfo *TII; public: static char ID; RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) {} bool runOnMachineFunction(MachineFunction &MF) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); MachineFunctionPass::getAnalysisUsage(AU); } StringRef getPassName() const override { return RISCV_INSERT_READ_WRITE_CSR_NAME; } private: bool emitWriteRoundingMode(MachineBasicBlock &MBB); bool emitWriteRoundingModeOpt(MachineBasicBlock &MBB); }; } // end anonymous namespace char RISCVInsertReadWriteCSR::ID = 0; INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) // TODO: Use more accurate rounding mode at the start of MBB. bool RISCVInsertReadWriteCSR::emitWriteRoundingModeOpt(MachineBasicBlock &MBB) { bool Changed = false; MachineInstr *LastFRMChanger = nullptr; unsigned CurrentRM = RISCVFPRndMode::DYN; Register SavedFRM; for (MachineInstr &MI : MBB) { if (MI.getOpcode() == RISCV::SwapFRMImm || MI.getOpcode() == RISCV::WriteFRMImm) { CurrentRM = MI.getOperand(0).getImm(); SavedFRM = Register(); continue; } if (MI.getOpcode() == RISCV::WriteFRM) { CurrentRM = RISCVFPRndMode::DYN; SavedFRM = Register(); continue; } if (MI.isCall() || MI.isInlineAsm() || MI.readsRegister(RISCV::FRM, /*TRI=*/nullptr)) { // Restore FRM before unknown operations. if (SavedFRM.isValid()) BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRM)) .addReg(SavedFRM); CurrentRM = RISCVFPRndMode::DYN; SavedFRM = Register(); continue; } assert(!MI.modifiesRegister(RISCV::FRM, /*TRI=*/nullptr) && "Expected that MI could not modify FRM."); int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); if (FRMIdx < 0) continue; unsigned InstrRM = MI.getOperand(FRMIdx).getImm(); LastFRMChanger = &MI; // Make MI implicit use FRM. MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, /*IsImp*/ true)); Changed = true; // Skip if MI uses same rounding mode as FRM. if (InstrRM == CurrentRM) continue; if (!SavedFRM.isValid()) { // Save current FRM value to SavedFRM. MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM) .addImm(InstrRM); } else { // Don't need to save current FRM when SavedFRM having value. BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm)) .addImm(InstrRM); } CurrentRM = InstrRM; } // Restore FRM if needed. if (SavedFRM.isValid()) { assert(LastFRMChanger && "Expected valid pointer."); MachineInstrBuilder MIB = BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) .addReg(SavedFRM); MBB.insertAfter(LastFRMChanger, MIB); } return Changed; } // This function also swaps frm and restores it when encountering an RVV // floating point instruction with a static rounding mode. bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { bool Changed = false; for (MachineInstr &MI : MBB) { int FRMIdx = RISCVII::getFRMOpNum(MI.getDesc()); if (FRMIdx < 0) continue; unsigned FRMImm = MI.getOperand(FRMIdx).getImm(); // The value is a hint to this pass to not alter the frm value. if (FRMImm == RISCVFPRndMode::DYN) continue; Changed = true; // Save MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), SavedFRM) .addImm(FRMImm); MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, /*IsImp*/ true)); // Restore MachineInstrBuilder MIB = BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) .addReg(SavedFRM); MBB.insertAfter(MI, MIB); } return Changed; } bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) { // Skip if the vector extension is not enabled. const RISCVSubtarget &ST = MF.getSubtarget(); if (!ST.hasVInstructions()) return false; TII = ST.getInstrInfo(); bool Changed = false; for (MachineBasicBlock &MBB : MF) { if (DisableFRMInsertOpt) Changed |= emitWriteRoundingMode(MBB); else Changed |= emitWriteRoundingModeOpt(MBB); } return Changed; } FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() { return new RISCVInsertReadWriteCSR(); }