xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass performs various vector pseudo peephole optimisations after
10 // instruction selection.
11 //
12 // Currently it converts vmerge.vvm to vmv.v.v
13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14 // ->
15 // PseudoVMV_V_V %false, %true, %vl, %sew
16 //
17 // And masked pseudos to unmasked pseudos
18 // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19 // ->
20 // PseudoVADD_V_V %passthru %a, %b, %vl, sew, policy
21 //
22 // It also converts AVLs to VLMAX where possible
23 // %vl = VLENB * something
24 // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25 // ->
26 // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27 //
28 //===----------------------------------------------------------------------===//
29 
30 #include "RISCV.h"
31 #include "RISCVISelDAGToDAG.h"
32 #include "RISCVSubtarget.h"
33 #include "llvm/CodeGen/MachineFunctionPass.h"
34 #include "llvm/CodeGen/MachineRegisterInfo.h"
35 #include "llvm/CodeGen/TargetInstrInfo.h"
36 #include "llvm/CodeGen/TargetRegisterInfo.h"
37 
38 using namespace llvm;
39 
40 #define DEBUG_TYPE "riscv-vector-peephole"
41 
42 namespace {
43 
44 class RISCVVectorPeephole : public MachineFunctionPass {
45 public:
46   static char ID;
47   const TargetInstrInfo *TII;
48   MachineRegisterInfo *MRI;
49   const TargetRegisterInfo *TRI;
RISCVVectorPeephole()50   RISCVVectorPeephole() : MachineFunctionPass(ID) {}
51 
52   bool runOnMachineFunction(MachineFunction &MF) override;
getRequiredProperties() const53   MachineFunctionProperties getRequiredProperties() const override {
54     return MachineFunctionProperties().set(
55         MachineFunctionProperties::Property::IsSSA);
56   }
57 
getPassName() const58   StringRef getPassName() const override { return "RISC-V Fold Masks"; }
59 
60 private:
61   bool convertToVLMAX(MachineInstr &MI) const;
62   bool convertToUnmasked(MachineInstr &MI) const;
63   bool convertVMergeToVMv(MachineInstr &MI) const;
64 
65   bool isAllOnesMask(const MachineInstr *MaskDef) const;
66 
67   /// Maps uses of V0 to the corresponding def of V0.
68   DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
69 };
70 
71 } // namespace
72 
73 char RISCVVectorPeephole::ID = 0;
74 
75 INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
76                 false)
77 
78 // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
79 // to the VLMAX sentinel value.
convertToVLMAX(MachineInstr & MI) const80 bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
81   if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
82       !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
83     return false;
84   MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
85   if (!VL.isReg())
86     return false;
87   MachineInstr *Def = MRI->getVRegDef(VL.getReg());
88   if (!Def)
89     return false;
90 
91   // Fixed-point value, denominator=8
92   uint64_t ScaleFixed = 8;
93   // Check if the VLENB was potentially scaled with slli/srli
94   if (Def->getOpcode() == RISCV::SLLI) {
95     assert(Def->getOperand(2).getImm() < 64);
96     ScaleFixed <<= Def->getOperand(2).getImm();
97     Def = MRI->getVRegDef(Def->getOperand(1).getReg());
98   } else if (Def->getOpcode() == RISCV::SRLI) {
99     assert(Def->getOperand(2).getImm() < 64);
100     ScaleFixed >>= Def->getOperand(2).getImm();
101     Def = MRI->getVRegDef(Def->getOperand(1).getReg());
102   }
103 
104   if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
105     return false;
106 
107   auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
108   // Fixed-point value, denominator=8
109   unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
110   unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
111   // A Log2SEW of 0 is an operation on mask registers only
112   unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
113   assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
114   assert(8 * LMULFixed / SEW > 0);
115 
116   // AVL = (VLENB * Scale)
117   //
118   // VLMAX = (VLENB * 8 * LMUL) / SEW
119   //
120   // AVL == VLMAX
121   // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
122   // -> Scale == (8 * LMUL) / SEW
123   if (ScaleFixed != 8 * LMULFixed / SEW)
124     return false;
125 
126   VL.ChangeToImmediate(RISCV::VLMaxSentinel);
127 
128   return true;
129 }
130 
isAllOnesMask(const MachineInstr * MaskDef) const131 bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
132   assert(MaskDef && MaskDef->isCopy() &&
133          MaskDef->getOperand(0).getReg() == RISCV::V0);
134   Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
135   if (!SrcReg.isVirtual())
136     return false;
137   MaskDef = MRI->getVRegDef(SrcReg);
138   if (!MaskDef)
139     return false;
140 
141   // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
142   // undefined behaviour if it's the wrong bitwidth, so we could choose to
143   // assume that it's all-ones? Same applies to its VL.
144   switch (MaskDef->getOpcode()) {
145   case RISCV::PseudoVMSET_M_B1:
146   case RISCV::PseudoVMSET_M_B2:
147   case RISCV::PseudoVMSET_M_B4:
148   case RISCV::PseudoVMSET_M_B8:
149   case RISCV::PseudoVMSET_M_B16:
150   case RISCV::PseudoVMSET_M_B32:
151   case RISCV::PseudoVMSET_M_B64:
152     return true;
153   default:
154     return false;
155   }
156 }
157 
158 // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
159 // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
convertVMergeToVMv(MachineInstr & MI) const160 bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const {
161 #define CASE_VMERGE_TO_VMV(lmul)                                               \
162   case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
163     NewOpc = RISCV::PseudoVMV_V_V_##lmul;                                      \
164     break;
165   unsigned NewOpc;
166   switch (MI.getOpcode()) {
167   default:
168     return false;
169     CASE_VMERGE_TO_VMV(MF8)
170     CASE_VMERGE_TO_VMV(MF4)
171     CASE_VMERGE_TO_VMV(MF2)
172     CASE_VMERGE_TO_VMV(M1)
173     CASE_VMERGE_TO_VMV(M2)
174     CASE_VMERGE_TO_VMV(M4)
175     CASE_VMERGE_TO_VMV(M8)
176   }
177 
178   Register MergeReg = MI.getOperand(1).getReg();
179   Register FalseReg = MI.getOperand(2).getReg();
180   // Check merge == false (or merge == undef)
181   if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
182                                            TRI->lookThruCopyLike(FalseReg, MRI))
183     return false;
184 
185   assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
186   if (!isAllOnesMask(V0Defs.lookup(&MI)))
187     return false;
188 
189   MI.setDesc(TII->get(NewOpc));
190   MI.removeOperand(1);  // Merge operand
191   MI.tieOperands(0, 1); // Tie false to dest
192   MI.removeOperand(3);  // Mask operand
193   MI.addOperand(
194       MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
195 
196   // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
197   // register class for the destination and merge operands e.g. VRNoV0 -> VR
198   MRI->recomputeRegClass(MI.getOperand(0).getReg());
199   MRI->recomputeRegClass(MI.getOperand(1).getReg());
200   return true;
201 }
202 
convertToUnmasked(MachineInstr & MI) const203 bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
204   const RISCV::RISCVMaskedPseudoInfo *I =
205       RISCV::getMaskedPseudoInfo(MI.getOpcode());
206   if (!I)
207     return false;
208 
209   if (!isAllOnesMask(V0Defs.lookup(&MI)))
210     return false;
211 
212   // There are two classes of pseudos in the table - compares and
213   // everything else.  See the comment on RISCVMaskedPseudo for details.
214   const unsigned Opc = I->UnmaskedPseudo;
215   const MCInstrDesc &MCID = TII->get(Opc);
216   [[maybe_unused]] const bool HasPolicyOp =
217       RISCVII::hasVecPolicyOp(MCID.TSFlags);
218   const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
219 #ifndef NDEBUG
220   const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
221   assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
222              RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
223          "Masked and unmasked pseudos are inconsistent");
224   assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
225 #endif
226   (void)HasPolicyOp;
227 
228   MI.setDesc(MCID);
229 
230   // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
231   unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
232   MI.removeOperand(MaskOpIdx);
233 
234   // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
235   // so try and relax it to vr.
236   MRI->recomputeRegClass(MI.getOperand(0).getReg());
237   unsigned PassthruOpIdx = MI.getNumExplicitDefs();
238   if (HasPassthru) {
239     if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
240       MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
241   } else
242     MI.removeOperand(PassthruOpIdx);
243 
244   return true;
245 }
246 
runOnMachineFunction(MachineFunction & MF)247 bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
248   if (skipFunction(MF.getFunction()))
249     return false;
250 
251   // Skip if the vector extension is not enabled.
252   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
253   if (!ST.hasVInstructions())
254     return false;
255 
256   TII = ST.getInstrInfo();
257   MRI = &MF.getRegInfo();
258   TRI = MRI->getTargetRegisterInfo();
259 
260   bool Changed = false;
261 
262   // Masked pseudos coming out of isel will have their mask operand in the form:
263   //
264   // $v0:vr = COPY %mask:vr
265   // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
266   //
267   // Because $v0 isn't in SSA, keep track of its definition at each use so we
268   // can check mask operands.
269   for (const MachineBasicBlock &MBB : MF) {
270     const MachineInstr *CurrentV0Def = nullptr;
271     for (const MachineInstr &MI : MBB) {
272       if (MI.readsRegister(RISCV::V0, TRI))
273         V0Defs[&MI] = CurrentV0Def;
274 
275       if (MI.definesRegister(RISCV::V0, TRI))
276         CurrentV0Def = &MI;
277     }
278   }
279 
280   for (MachineBasicBlock &MBB : MF) {
281     for (MachineInstr &MI : MBB) {
282       Changed |= convertToVLMAX(MI);
283       Changed |= convertToUnmasked(MI);
284       Changed |= convertVMergeToVMv(MI);
285     }
286   }
287 
288   return Changed;
289 }
290 
createRISCVVectorPeepholePass()291 FunctionPass *llvm::createRISCVVectorPeepholePass() {
292   return new RISCVVectorPeephole();
293 }
294