1*0fca6ea1SDimitry Andric //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2*0fca6ea1SDimitry Andric //
3*0fca6ea1SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*0fca6ea1SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5*0fca6ea1SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*0fca6ea1SDimitry Andric //
7*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
8*0fca6ea1SDimitry Andric //
9*0fca6ea1SDimitry Andric // This pass performs various vector pseudo peephole optimisations after
10*0fca6ea1SDimitry Andric // instruction selection.
11*0fca6ea1SDimitry Andric //
12*0fca6ea1SDimitry Andric // Currently it converts vmerge.vvm to vmv.v.v
13*0fca6ea1SDimitry Andric // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14*0fca6ea1SDimitry Andric // ->
15*0fca6ea1SDimitry Andric // PseudoVMV_V_V %false, %true, %vl, %sew
16*0fca6ea1SDimitry Andric //
17*0fca6ea1SDimitry Andric // And masked pseudos to unmasked pseudos
18*0fca6ea1SDimitry Andric // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19*0fca6ea1SDimitry Andric // ->
//   PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
21*0fca6ea1SDimitry Andric //
22*0fca6ea1SDimitry Andric // It also converts AVLs to VLMAX where possible
23*0fca6ea1SDimitry Andric // %vl = VLENB * something
24*0fca6ea1SDimitry Andric // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25*0fca6ea1SDimitry Andric // ->
26*0fca6ea1SDimitry Andric // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27*0fca6ea1SDimitry Andric //
28*0fca6ea1SDimitry Andric //===----------------------------------------------------------------------===//
29*0fca6ea1SDimitry Andric
30*0fca6ea1SDimitry Andric #include "RISCV.h"
31*0fca6ea1SDimitry Andric #include "RISCVISelDAGToDAG.h"
32*0fca6ea1SDimitry Andric #include "RISCVSubtarget.h"
33*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineFunctionPass.h"
34*0fca6ea1SDimitry Andric #include "llvm/CodeGen/MachineRegisterInfo.h"
35*0fca6ea1SDimitry Andric #include "llvm/CodeGen/TargetInstrInfo.h"
36*0fca6ea1SDimitry Andric #include "llvm/CodeGen/TargetRegisterInfo.h"
37*0fca6ea1SDimitry Andric
38*0fca6ea1SDimitry Andric using namespace llvm;
39*0fca6ea1SDimitry Andric
40*0fca6ea1SDimitry Andric #define DEBUG_TYPE "riscv-vector-peephole"
41*0fca6ea1SDimitry Andric
42*0fca6ea1SDimitry Andric namespace {
43*0fca6ea1SDimitry Andric
44*0fca6ea1SDimitry Andric class RISCVVectorPeephole : public MachineFunctionPass {
45*0fca6ea1SDimitry Andric public:
46*0fca6ea1SDimitry Andric static char ID;
47*0fca6ea1SDimitry Andric const TargetInstrInfo *TII;
48*0fca6ea1SDimitry Andric MachineRegisterInfo *MRI;
49*0fca6ea1SDimitry Andric const TargetRegisterInfo *TRI;
RISCVVectorPeephole()50*0fca6ea1SDimitry Andric RISCVVectorPeephole() : MachineFunctionPass(ID) {}
51*0fca6ea1SDimitry Andric
52*0fca6ea1SDimitry Andric bool runOnMachineFunction(MachineFunction &MF) override;
getRequiredProperties() const53*0fca6ea1SDimitry Andric MachineFunctionProperties getRequiredProperties() const override {
54*0fca6ea1SDimitry Andric return MachineFunctionProperties().set(
55*0fca6ea1SDimitry Andric MachineFunctionProperties::Property::IsSSA);
56*0fca6ea1SDimitry Andric }
57*0fca6ea1SDimitry Andric
getPassName() const58*0fca6ea1SDimitry Andric StringRef getPassName() const override { return "RISC-V Fold Masks"; }
59*0fca6ea1SDimitry Andric
60*0fca6ea1SDimitry Andric private:
61*0fca6ea1SDimitry Andric bool convertToVLMAX(MachineInstr &MI) const;
62*0fca6ea1SDimitry Andric bool convertToUnmasked(MachineInstr &MI) const;
63*0fca6ea1SDimitry Andric bool convertVMergeToVMv(MachineInstr &MI) const;
64*0fca6ea1SDimitry Andric
65*0fca6ea1SDimitry Andric bool isAllOnesMask(const MachineInstr *MaskDef) const;
66*0fca6ea1SDimitry Andric
67*0fca6ea1SDimitry Andric /// Maps uses of V0 to the corresponding def of V0.
68*0fca6ea1SDimitry Andric DenseMap<const MachineInstr *, const MachineInstr *> V0Defs;
69*0fca6ea1SDimitry Andric };
70*0fca6ea1SDimitry Andric
71*0fca6ea1SDimitry Andric } // namespace
72*0fca6ea1SDimitry Andric
73*0fca6ea1SDimitry Andric char RISCVVectorPeephole::ID = 0;
74*0fca6ea1SDimitry Andric
75*0fca6ea1SDimitry Andric INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
76*0fca6ea1SDimitry Andric false)
77*0fca6ea1SDimitry Andric
78*0fca6ea1SDimitry Andric // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert it
79*0fca6ea1SDimitry Andric // to the VLMAX sentinel value.
convertToVLMAX(MachineInstr & MI) const80*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
81*0fca6ea1SDimitry Andric if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
82*0fca6ea1SDimitry Andric !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
83*0fca6ea1SDimitry Andric return false;
84*0fca6ea1SDimitry Andric MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
85*0fca6ea1SDimitry Andric if (!VL.isReg())
86*0fca6ea1SDimitry Andric return false;
87*0fca6ea1SDimitry Andric MachineInstr *Def = MRI->getVRegDef(VL.getReg());
88*0fca6ea1SDimitry Andric if (!Def)
89*0fca6ea1SDimitry Andric return false;
90*0fca6ea1SDimitry Andric
91*0fca6ea1SDimitry Andric // Fixed-point value, denominator=8
92*0fca6ea1SDimitry Andric uint64_t ScaleFixed = 8;
93*0fca6ea1SDimitry Andric // Check if the VLENB was potentially scaled with slli/srli
94*0fca6ea1SDimitry Andric if (Def->getOpcode() == RISCV::SLLI) {
95*0fca6ea1SDimitry Andric assert(Def->getOperand(2).getImm() < 64);
96*0fca6ea1SDimitry Andric ScaleFixed <<= Def->getOperand(2).getImm();
97*0fca6ea1SDimitry Andric Def = MRI->getVRegDef(Def->getOperand(1).getReg());
98*0fca6ea1SDimitry Andric } else if (Def->getOpcode() == RISCV::SRLI) {
99*0fca6ea1SDimitry Andric assert(Def->getOperand(2).getImm() < 64);
100*0fca6ea1SDimitry Andric ScaleFixed >>= Def->getOperand(2).getImm();
101*0fca6ea1SDimitry Andric Def = MRI->getVRegDef(Def->getOperand(1).getReg());
102*0fca6ea1SDimitry Andric }
103*0fca6ea1SDimitry Andric
104*0fca6ea1SDimitry Andric if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
105*0fca6ea1SDimitry Andric return false;
106*0fca6ea1SDimitry Andric
107*0fca6ea1SDimitry Andric auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
108*0fca6ea1SDimitry Andric // Fixed-point value, denominator=8
109*0fca6ea1SDimitry Andric unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
110*0fca6ea1SDimitry Andric unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
111*0fca6ea1SDimitry Andric // A Log2SEW of 0 is an operation on mask registers only
112*0fca6ea1SDimitry Andric unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
113*0fca6ea1SDimitry Andric assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
114*0fca6ea1SDimitry Andric assert(8 * LMULFixed / SEW > 0);
115*0fca6ea1SDimitry Andric
116*0fca6ea1SDimitry Andric // AVL = (VLENB * Scale)
117*0fca6ea1SDimitry Andric //
118*0fca6ea1SDimitry Andric // VLMAX = (VLENB * 8 * LMUL) / SEW
119*0fca6ea1SDimitry Andric //
120*0fca6ea1SDimitry Andric // AVL == VLMAX
121*0fca6ea1SDimitry Andric // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
122*0fca6ea1SDimitry Andric // -> Scale == (8 * LMUL) / SEW
123*0fca6ea1SDimitry Andric if (ScaleFixed != 8 * LMULFixed / SEW)
124*0fca6ea1SDimitry Andric return false;
125*0fca6ea1SDimitry Andric
126*0fca6ea1SDimitry Andric VL.ChangeToImmediate(RISCV::VLMaxSentinel);
127*0fca6ea1SDimitry Andric
128*0fca6ea1SDimitry Andric return true;
129*0fca6ea1SDimitry Andric }
130*0fca6ea1SDimitry Andric
isAllOnesMask(const MachineInstr * MaskDef) const131*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
132*0fca6ea1SDimitry Andric assert(MaskDef && MaskDef->isCopy() &&
133*0fca6ea1SDimitry Andric MaskDef->getOperand(0).getReg() == RISCV::V0);
134*0fca6ea1SDimitry Andric Register SrcReg = TRI->lookThruCopyLike(MaskDef->getOperand(1).getReg(), MRI);
135*0fca6ea1SDimitry Andric if (!SrcReg.isVirtual())
136*0fca6ea1SDimitry Andric return false;
137*0fca6ea1SDimitry Andric MaskDef = MRI->getVRegDef(SrcReg);
138*0fca6ea1SDimitry Andric if (!MaskDef)
139*0fca6ea1SDimitry Andric return false;
140*0fca6ea1SDimitry Andric
141*0fca6ea1SDimitry Andric // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
142*0fca6ea1SDimitry Andric // undefined behaviour if it's the wrong bitwidth, so we could choose to
143*0fca6ea1SDimitry Andric // assume that it's all-ones? Same applies to its VL.
144*0fca6ea1SDimitry Andric switch (MaskDef->getOpcode()) {
145*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B1:
146*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B2:
147*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B4:
148*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B8:
149*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B16:
150*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B32:
151*0fca6ea1SDimitry Andric case RISCV::PseudoVMSET_M_B64:
152*0fca6ea1SDimitry Andric return true;
153*0fca6ea1SDimitry Andric default:
154*0fca6ea1SDimitry Andric return false;
155*0fca6ea1SDimitry Andric }
156*0fca6ea1SDimitry Andric }
157*0fca6ea1SDimitry Andric
158*0fca6ea1SDimitry Andric // Transform (VMERGE_VVM_<LMUL> false, false, true, allones, vl, sew) to
159*0fca6ea1SDimitry Andric // (VMV_V_V_<LMUL> false, true, vl, sew). It may decrease uses of VMSET.
convertVMergeToVMv(MachineInstr & MI) const160*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertVMergeToVMv(MachineInstr &MI) const {
161*0fca6ea1SDimitry Andric #define CASE_VMERGE_TO_VMV(lmul) \
162*0fca6ea1SDimitry Andric case RISCV::PseudoVMERGE_VVM_##lmul: \
163*0fca6ea1SDimitry Andric NewOpc = RISCV::PseudoVMV_V_V_##lmul; \
164*0fca6ea1SDimitry Andric break;
165*0fca6ea1SDimitry Andric unsigned NewOpc;
166*0fca6ea1SDimitry Andric switch (MI.getOpcode()) {
167*0fca6ea1SDimitry Andric default:
168*0fca6ea1SDimitry Andric return false;
169*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(MF8)
170*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(MF4)
171*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(MF2)
172*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(M1)
173*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(M2)
174*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(M4)
175*0fca6ea1SDimitry Andric CASE_VMERGE_TO_VMV(M8)
176*0fca6ea1SDimitry Andric }
177*0fca6ea1SDimitry Andric
178*0fca6ea1SDimitry Andric Register MergeReg = MI.getOperand(1).getReg();
179*0fca6ea1SDimitry Andric Register FalseReg = MI.getOperand(2).getReg();
180*0fca6ea1SDimitry Andric // Check merge == false (or merge == undef)
181*0fca6ea1SDimitry Andric if (MergeReg != RISCV::NoRegister && TRI->lookThruCopyLike(MergeReg, MRI) !=
182*0fca6ea1SDimitry Andric TRI->lookThruCopyLike(FalseReg, MRI))
183*0fca6ea1SDimitry Andric return false;
184*0fca6ea1SDimitry Andric
185*0fca6ea1SDimitry Andric assert(MI.getOperand(4).isReg() && MI.getOperand(4).getReg() == RISCV::V0);
186*0fca6ea1SDimitry Andric if (!isAllOnesMask(V0Defs.lookup(&MI)))
187*0fca6ea1SDimitry Andric return false;
188*0fca6ea1SDimitry Andric
189*0fca6ea1SDimitry Andric MI.setDesc(TII->get(NewOpc));
190*0fca6ea1SDimitry Andric MI.removeOperand(1); // Merge operand
191*0fca6ea1SDimitry Andric MI.tieOperands(0, 1); // Tie false to dest
192*0fca6ea1SDimitry Andric MI.removeOperand(3); // Mask operand
193*0fca6ea1SDimitry Andric MI.addOperand(
194*0fca6ea1SDimitry Andric MachineOperand::CreateImm(RISCVII::TAIL_UNDISTURBED_MASK_UNDISTURBED));
195*0fca6ea1SDimitry Andric
196*0fca6ea1SDimitry Andric // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
197*0fca6ea1SDimitry Andric // register class for the destination and merge operands e.g. VRNoV0 -> VR
198*0fca6ea1SDimitry Andric MRI->recomputeRegClass(MI.getOperand(0).getReg());
199*0fca6ea1SDimitry Andric MRI->recomputeRegClass(MI.getOperand(1).getReg());
200*0fca6ea1SDimitry Andric return true;
201*0fca6ea1SDimitry Andric }
202*0fca6ea1SDimitry Andric
convertToUnmasked(MachineInstr & MI) const203*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
204*0fca6ea1SDimitry Andric const RISCV::RISCVMaskedPseudoInfo *I =
205*0fca6ea1SDimitry Andric RISCV::getMaskedPseudoInfo(MI.getOpcode());
206*0fca6ea1SDimitry Andric if (!I)
207*0fca6ea1SDimitry Andric return false;
208*0fca6ea1SDimitry Andric
209*0fca6ea1SDimitry Andric if (!isAllOnesMask(V0Defs.lookup(&MI)))
210*0fca6ea1SDimitry Andric return false;
211*0fca6ea1SDimitry Andric
212*0fca6ea1SDimitry Andric // There are two classes of pseudos in the table - compares and
213*0fca6ea1SDimitry Andric // everything else. See the comment on RISCVMaskedPseudo for details.
214*0fca6ea1SDimitry Andric const unsigned Opc = I->UnmaskedPseudo;
215*0fca6ea1SDimitry Andric const MCInstrDesc &MCID = TII->get(Opc);
216*0fca6ea1SDimitry Andric [[maybe_unused]] const bool HasPolicyOp =
217*0fca6ea1SDimitry Andric RISCVII::hasVecPolicyOp(MCID.TSFlags);
218*0fca6ea1SDimitry Andric const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
219*0fca6ea1SDimitry Andric #ifndef NDEBUG
220*0fca6ea1SDimitry Andric const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
221*0fca6ea1SDimitry Andric assert(RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ==
222*0fca6ea1SDimitry Andric RISCVII::hasVecPolicyOp(MCID.TSFlags) &&
223*0fca6ea1SDimitry Andric "Masked and unmasked pseudos are inconsistent");
224*0fca6ea1SDimitry Andric assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
225*0fca6ea1SDimitry Andric #endif
226*0fca6ea1SDimitry Andric (void)HasPolicyOp;
227*0fca6ea1SDimitry Andric
228*0fca6ea1SDimitry Andric MI.setDesc(MCID);
229*0fca6ea1SDimitry Andric
230*0fca6ea1SDimitry Andric // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
231*0fca6ea1SDimitry Andric unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
232*0fca6ea1SDimitry Andric MI.removeOperand(MaskOpIdx);
233*0fca6ea1SDimitry Andric
234*0fca6ea1SDimitry Andric // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
235*0fca6ea1SDimitry Andric // so try and relax it to vr.
236*0fca6ea1SDimitry Andric MRI->recomputeRegClass(MI.getOperand(0).getReg());
237*0fca6ea1SDimitry Andric unsigned PassthruOpIdx = MI.getNumExplicitDefs();
238*0fca6ea1SDimitry Andric if (HasPassthru) {
239*0fca6ea1SDimitry Andric if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
240*0fca6ea1SDimitry Andric MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
241*0fca6ea1SDimitry Andric } else
242*0fca6ea1SDimitry Andric MI.removeOperand(PassthruOpIdx);
243*0fca6ea1SDimitry Andric
244*0fca6ea1SDimitry Andric return true;
245*0fca6ea1SDimitry Andric }
246*0fca6ea1SDimitry Andric
runOnMachineFunction(MachineFunction & MF)247*0fca6ea1SDimitry Andric bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
248*0fca6ea1SDimitry Andric if (skipFunction(MF.getFunction()))
249*0fca6ea1SDimitry Andric return false;
250*0fca6ea1SDimitry Andric
251*0fca6ea1SDimitry Andric // Skip if the vector extension is not enabled.
252*0fca6ea1SDimitry Andric const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
253*0fca6ea1SDimitry Andric if (!ST.hasVInstructions())
254*0fca6ea1SDimitry Andric return false;
255*0fca6ea1SDimitry Andric
256*0fca6ea1SDimitry Andric TII = ST.getInstrInfo();
257*0fca6ea1SDimitry Andric MRI = &MF.getRegInfo();
258*0fca6ea1SDimitry Andric TRI = MRI->getTargetRegisterInfo();
259*0fca6ea1SDimitry Andric
260*0fca6ea1SDimitry Andric bool Changed = false;
261*0fca6ea1SDimitry Andric
262*0fca6ea1SDimitry Andric // Masked pseudos coming out of isel will have their mask operand in the form:
263*0fca6ea1SDimitry Andric //
264*0fca6ea1SDimitry Andric // $v0:vr = COPY %mask:vr
265*0fca6ea1SDimitry Andric // %x:vr = Pseudo_MASK %a:vr, %b:br, $v0:vr
266*0fca6ea1SDimitry Andric //
267*0fca6ea1SDimitry Andric // Because $v0 isn't in SSA, keep track of its definition at each use so we
268*0fca6ea1SDimitry Andric // can check mask operands.
269*0fca6ea1SDimitry Andric for (const MachineBasicBlock &MBB : MF) {
270*0fca6ea1SDimitry Andric const MachineInstr *CurrentV0Def = nullptr;
271*0fca6ea1SDimitry Andric for (const MachineInstr &MI : MBB) {
272*0fca6ea1SDimitry Andric if (MI.readsRegister(RISCV::V0, TRI))
273*0fca6ea1SDimitry Andric V0Defs[&MI] = CurrentV0Def;
274*0fca6ea1SDimitry Andric
275*0fca6ea1SDimitry Andric if (MI.definesRegister(RISCV::V0, TRI))
276*0fca6ea1SDimitry Andric CurrentV0Def = &MI;
277*0fca6ea1SDimitry Andric }
278*0fca6ea1SDimitry Andric }
279*0fca6ea1SDimitry Andric
280*0fca6ea1SDimitry Andric for (MachineBasicBlock &MBB : MF) {
281*0fca6ea1SDimitry Andric for (MachineInstr &MI : MBB) {
282*0fca6ea1SDimitry Andric Changed |= convertToVLMAX(MI);
283*0fca6ea1SDimitry Andric Changed |= convertToUnmasked(MI);
284*0fca6ea1SDimitry Andric Changed |= convertVMergeToVMv(MI);
285*0fca6ea1SDimitry Andric }
286*0fca6ea1SDimitry Andric }
287*0fca6ea1SDimitry Andric
288*0fca6ea1SDimitry Andric return Changed;
289*0fca6ea1SDimitry Andric }
290*0fca6ea1SDimitry Andric
createRISCVVectorPeepholePass()291*0fca6ea1SDimitry Andric FunctionPass *llvm::createRISCVVectorPeepholePass() {
292*0fca6ea1SDimitry Andric return new RISCVVectorPeephole();
293*0fca6ea1SDimitry Andric }
294