1 //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass performs various vector pseudo peephole optimisations after 10 // instruction selection. 11 // 12 // Currently it converts vmerge.vvm to vmv.v.v 13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew 14 // -> 15 // PseudoVMV_V_V %false, %true, %vl, %sew 16 // 17 // And masked pseudos to unmasked pseudos 18 // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy 19 // -> 20 // PseudoVADD_V_V %passthru %a, %b, %vl, sew, policy 21 // 22 // It also converts AVLs to VLMAX where possible 23 // %vl = VLENB * something 24 // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy 25 // -> 26 // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy 27 // 28 //===----------------------------------------------------------------------===// 29 30 #include "RISCV.h" 31 #include "RISCVISelDAGToDAG.h" 32 #include "RISCVSubtarget.h" 33 #include "llvm/CodeGen/MachineFunctionPass.h" 34 #include "llvm/CodeGen/MachineRegisterInfo.h" 35 #include "llvm/CodeGen/TargetInstrInfo.h" 36 #include "llvm/CodeGen/TargetRegisterInfo.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "riscv-vector-peephole" 41 42 namespace { 43 44 class RISCVVectorPeephole : public MachineFunctionPass { 45 public: 46 static char ID; 47 const TargetInstrInfo *TII; 48 MachineRegisterInfo *MRI; 49 const TargetRegisterInfo *TRI; 50 RISCVVectorPeephole() : MachineFunctionPass(ID) {} 51 52 bool runOnMachineFunction(MachineFunction &MF) override; 53 MachineFunctionProperties getRequiredProperties() const override { 54 return MachineFunctionProperties().set( 55 MachineFunctionProperties::Property::IsSSA); 56 } 57 58 StringRef getPassName() const override { return "RISC-V Fold Masks"; } 59 60 private: 61 bool convertToVLMAX(MachineInstr &MI) const; 62 bool convertToUnmasked(MachineInstr &MI) const; 63 bool convertVMergeToVMv(MachineInstr &MI) const; 64 65 bool isAllOnesMask(const MachineInstr *MaskDef) const; 66 67 /// Maps uses of V0 to the corresponding def of V0. 68 DenseMap<const MachineInstr *, const MachineInstr *> V0Defs; 69 }; 70 71 } // namespace 72 73 char RISCVVectorPeephole::ID = 0; 74 75 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false, 76 false) 77 78 // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it 79 // to the VLMAX sentinel value. 80 bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const { 81 if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) || 82 !RISCVII::hasSEWOp(MI.getDesc().TSFlags)) 83 return false; 84 MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc())); 85 if (!VL.isReg()) 86 return false; 87 MachineInstr *Def = MRI->getVRegDef(VL.getReg()); 88 if (!Def) 89 return false; 90 91 // Fixed-point value, denominator=8 92 uint64_t ScaleFixed = 8; 93 // Check if the VLENB was potentially scaled with slli/srli 94 if (Def->getOpcode() == RISCV::SLLI) { 95 assert(Def->getOperand(2).getImm() < 64); 96 ScaleFixed <<= Def->getOperand(2).getImm(); 97 Def = MRI->getVRegDef(Def->getOperand(1).getReg()); 98 } else if (Def->getOpcode() == RISCV::SRLI) { 99 assert(Def->getOperand(2).getImm() < 64); 100 ScaleFixed >>= Def->getOperand(2).getImm(); 101 Def = MRI->getVRegDef(Def->getOperand(1).getReg()); 102 } 103 104 if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB) 105 return false; 106 107 auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags)); 108 // Fixed-point value, denominator=8 109 unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first; 110 unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm(); 111 // A Log2SEW of 0 is an operation on mask registers only 112 unsigned SEW = Log2SEW ? 1 << Log2SEW : 8; 113 assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW"); 114 assert(8 * LMULFixed / SEW > 0); 115 116 // AVL = (VLENB * Scale) 117 // 118 // VLMAX = (VLENB * 8 * LMUL) / SEW 119 // 120 // AVL == VLMAX 121 // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW 122 // -> Scale == (8 * LMUL) / SEW 123 if (ScaleFixed != 8 * LMULFixed / SEW) 124 return false; 125 126 VL.ChangeToImmediate(RISCV::VLMaxSentinel); 127 128 return true; 129 } 130 131 bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const { 132 assert(MaskDef && MaskDef->isCopy() && 133 MaskDef->getOperand(0).getReg() == RISCV::V0); 134 Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI); 135 if (!SrcReg.isVirtual()) 136 return false; 137 MaskDef = MRI->getVRegDef(SrcReg); 138 if (!MaskDef) 139 return false; 140 141 // TODO: Check that the VMSET is the expected bitwidth? The pseudo has 142 // undefined behaviour if it's the wrong bitwidth, so we could choose to 143 // assume that it's all-ones? Same applies to its VL. 144 switch (MaskDef->getOpcode()) { 145 case RISCV::PseudoVMSET_M_B1: 146 case RISCV::PseudoVMSET_M_B2: 147 case RISCV::PseudoVMSET_M_B4: 148 case RISCV::PseudoVMSET_M_B8: 149 case RISCV::PseudoVMSET_M_B16: 150 case RISCV::PseudoVMSET_M_B32: 151 case RISCV::PseudoVMSET_M_B64: 152 return true; 153 default: 154 return false; 155 } 156 } 157 158 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to 159 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET. 160 bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const { 161 #define CASE_VMERGE_TO_VMV(lmul) \ 162 case RISCV::PseudoVMERGE_VVM_##lmul: \ 163 NewOpc = RISCV::PseudoVMV_V_V_##lmul; \ 164 break; 165 unsigned NewOpc; 166 switch (MI.getOpcode()) { 167 default: 168 return false; 169 CASE_VMERGE_TO_VMV(MF8) 170 CASE_VMERGE_TO_VMV(MF4) 171 CASE_VMERGE_TO_VMV(MF2) 172 CASE_VMERGE_TO_VMV(M1) 173 CASE_VMERGE_TO_VMV(M2) 174 CASE_VMERGE_TO_VMV(M4) 175 CASE_VMERGE_TO_VMV(M8) 176 } 177 178 Register MergeReg = MI.getOperand(1).getReg(); 179 Register FalseReg = MI.getOperand(2).getReg(); 180 // Check merge == false (or merge == undef) 181 if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) != 182 TRI->lookThruCopyLike(FalseReg, MRI)) 183 return false; 184 185 assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0); 186 if (!isAllOnesMask(V0Defs.lookup(&MI))) 187 return false; 188 189 MI.setDesc(TII->get(NewOpc)); 190 MI.removeOperand(1); // Merge operand 191 MI.tieOperands(0, 1); // Tie false to dest 192 MI.removeOperand(3); // Mask operand 193 MI.addOperand( 194 MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED)); 195 196 // vmv.v.v doesn't have a mask operand, so we may be able to inflate the 197 // register class for the destination and merge operands e.g. VRNoV0 -> VR 198 MRI->recomputeRegClass(MI.getOperand(0).getReg()); 199 MRI->recomputeRegClass(MI.getOperand(1).getReg()); 200 return true; 201 } 202 203 bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const { 204 const RISCV::RISCVMaskedPseudoInfo *I = 205 RISCV::getMaskedPseudoInfo(MI.getOpcode()); 206 if (!I) 207 return false; 208 209 if (!isAllOnesMask(V0Defs.lookup(&MI))) 210 return false; 211 212 // There are two classes of pseudos in the table - compares and 213 // everything else. See the comment on RISCVMaskedPseudo for details. 214 const unsigned Opc = I->UnmaskedPseudo; 215 const MCInstrDesc &MCID = TII->get(Opc); 216 [[maybe_unused]] const bool HasPolicyOp = 217 RISCVII::hasVecPolicyOp(MCID.TSFlags); 218 const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID); 219 #ifndef NDEBUG 220 const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode()); 221 assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) == 222 RISCVII::hasVecPolicyOp(MCID.TSFlags) && 223 "Masked and unmasked pseudos are inconsistent"); 224 assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure"); 225 #endif 226 (void)HasPolicyOp; 227 228 MI.setDesc(MCID); 229 230 // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs? 231 unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs(); 232 MI.removeOperand(MaskOpIdx); 233 234 // The unmasked pseudo will no longer be constrained to the vrnov0 reg class, 235 // so try and relax it to vr. 236 MRI->recomputeRegClass(MI.getOperand(0).getReg()); 237 unsigned PassthruOpIdx = MI.getNumExplicitDefs(); 238 if (HasPassthru) { 239 if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister) 240 MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg()); 241 } else 242 MI.removeOperand(PassthruOpIdx); 243 244 return true; 245 } 246 247 bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) { 248 if (skipFunction(MF.getFunction())) 249 return false; 250 251 // Skip if the vector extension is not enabled. 252 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 253 if (!ST.hasVInstructions()) 254 return false; 255 256 TII = ST.getInstrInfo(); 257 MRI = &MF.getRegInfo(); 258 TRI = MRI->getTargetRegisterInfo(); 259 260 bool Changed = false; 261 262 // Masked pseudos coming out of isel will have their mask operand in the form: 263 // 264 // $v0:vr = COPY %mask:vr 265 // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr 266 // 267 // Because $v0 isn't in SSA, keep track of its definition at each use so we 268 // can check mask operands. 269 for (const MachineBasicBlock &MBB : MF) { 270 const MachineInstr *CurrentV0Def = nullptr; 271 for (const MachineInstr &MI : MBB) { 272 if (MI.readsRegister(RISCV::V0, TRI)) 273 V0Defs[&MI] = CurrentV0Def; 274 275 if (MI.definesRegister(RISCV::V0, TRI)) 276 CurrentV0Def = &MI; 277 } 278 } 279 280 for (MachineBasicBlock &MBB : MF) { 281 for (MachineInstr &MI : MBB) { 282 Changed |= convertToVLMAX(MI); 283 Changed |= convertToUnmasked(MI); 284 Changed |= convertVMergeToVMv(MI); 285 } 286 } 287 288 return Changed; 289 } 290 291 FunctionPass *llvm::createRISCVVectorPeepholePass() { 292 return new RISCVVectorPeephole(); 293 } 294