1 //===-- RISCVInsertReadWriteCSR.cpp - Insert Read/Write of RISC-V CSR -----===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // This file implements the machine function pass to insert read/write of CSR-s 9 // of the RISC-V instructions. 10 // 11 // Currently the pass implements naive insertion of a write to vxrm before an 12 // RVV fixed-point instruction. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "MCTargetDesc/RISCVBaseInfo.h" 17 #include "RISCV.h" 18 #include "RISCVSubtarget.h" 19 #include "llvm/CodeGen/MachineFunctionPass.h" 20 using namespace llvm; 21 22 #define DEBUG_TYPE "riscv-insert-read-write-csr" 23 #define RISCV_INSERT_READ_WRITE_CSR_NAME "RISC-V Insert Read/Write CSR Pass" 24 25 namespace { 26 27 class RISCVInsertReadWriteCSR : public MachineFunctionPass { 28 const TargetInstrInfo *TII; 29 30 public: 31 static char ID; 32 33 RISCVInsertReadWriteCSR() : MachineFunctionPass(ID) { 34 initializeRISCVInsertReadWriteCSRPass(*PassRegistry::getPassRegistry()); 35 } 36 37 bool runOnMachineFunction(MachineFunction &MF) override; 38 39 void getAnalysisUsage(AnalysisUsage &AU) const override { 40 AU.setPreservesCFG(); 41 MachineFunctionPass::getAnalysisUsage(AU); 42 } 43 44 StringRef getPassName() const override { 45 return RISCV_INSERT_READ_WRITE_CSR_NAME; 46 } 47 48 private: 49 bool emitWriteRoundingMode(MachineBasicBlock &MBB); 50 }; 51 52 } // end anonymous namespace 53 54 char RISCVInsertReadWriteCSR::ID = 0; 55 56 INITIALIZE_PASS(RISCVInsertReadWriteCSR, DEBUG_TYPE, 57 RISCV_INSERT_READ_WRITE_CSR_NAME, false, false) 58 59 // Returns the index to the rounding mode immediate value if any, otherwise the 60 // function will return None. 61 static std::optional<unsigned> getRoundModeIdx(const MachineInstr &MI) { 62 uint64_t TSFlags = MI.getDesc().TSFlags; 63 if (!RISCVII::hasRoundModeOp(TSFlags)) 64 return std::nullopt; 65 66 // The operand order 67 // ------------------------------------- 68 // | n-1 (if any) | n-2 | n-3 | n-4 | 69 // | policy | sew | vl | rm | 70 // ------------------------------------- 71 return MI.getNumExplicitOperands() - RISCVII::hasVecPolicyOp(TSFlags) - 3; 72 } 73 74 // This function inserts a write to vxrm when encountering an RVV fixed-point 75 // instruction. 76 bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { 77 bool Changed = false; 78 for (MachineInstr &MI : MBB) { 79 if (auto RoundModeIdx = getRoundModeIdx(MI)) { 80 if (RISCVII::usesVXRM(MI.getDesc().TSFlags)) { 81 unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm(); 82 83 Changed = true; 84 85 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) 86 .addImm(VXRMImm); 87 MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, 88 /*IsImp*/ true)); 89 } else { // FRM 90 unsigned FRMImm = MI.getOperand(*RoundModeIdx).getImm(); 91 92 // The value is a hint to this pass to not alter the frm value. 93 if (FRMImm == RISCVFPRndMode::DYN) 94 continue; 95 96 Changed = true; 97 98 // Save 99 MachineRegisterInfo *MRI = &MBB.getParent()->getRegInfo(); 100 Register SavedFRM = MRI->createVirtualRegister(&RISCV::GPRRegClass); 101 BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::SwapFRMImm), 102 SavedFRM) 103 .addImm(FRMImm); 104 MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, 105 /*IsImp*/ true)); 106 // Restore 107 MachineInstrBuilder MIB = 108 BuildMI(*MBB.getParent(), {}, TII->get(RISCV::WriteFRM)) 109 .addReg(SavedFRM); 110 MBB.insertAfter(MI, MIB); 111 } 112 } 113 } 114 return Changed; 115 } 116 117 bool RISCVInsertReadWriteCSR::runOnMachineFunction(MachineFunction &MF) { 118 // Skip if the vector extension is not enabled. 119 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 120 if (!ST.hasVInstructions()) 121 return false; 122 123 TII = ST.getInstrInfo(); 124 125 bool Changed = false; 126 127 for (MachineBasicBlock &MBB : MF) 128 Changed |= emitWriteRoundingMode(MBB); 129 130 return Changed; 131 } 132 133 FunctionPass *llvm::createRISCVInsertReadWriteCSRPass() { 134 return new RISCVInsertReadWriteCSR(); 135 } 136