xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVVectorPeephole.cpp (revision e64bea71c21eb42e97aa615188ba91f6cce0d36d)
1 //===- RISCVVectorPeephole.cpp - MI Vector Pseudo Peepholes ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass performs various vector pseudo peephole optimisations after
10 // instruction selection.
11 //
12 // Currently it converts vmerge.vvm to vmv.v.v
13 // PseudoVMERGE_VVM %false, %false, %true, %allonesmask, %vl, %sew
14 // ->
15 // PseudoVMV_V_V %false, %true, %vl, %sew
16 //
17 // And masked pseudos to unmasked pseudos
18 // PseudoVADD_V_V_MASK %passthru, %a, %b, %allonesmask, %vl, sew, policy
19 // ->
20 // PseudoVADD_V_V %passthru %a, %b, %vl, sew, policy
21 //
22 // It also converts AVLs to VLMAX where possible
23 // %vl = VLENB * something
24 // PseudoVADD_V_V %passthru, %a, %b, %vl, sew, policy
25 // ->
26 // PseudoVADD_V_V %passthru, %a, %b, -1, sew, policy
27 //
28 //===----------------------------------------------------------------------===//
29 
30 #include "RISCV.h"
31 #include "RISCVSubtarget.h"
32 #include "llvm/CodeGen/MachineFunctionPass.h"
33 #include "llvm/CodeGen/MachineRegisterInfo.h"
34 #include "llvm/CodeGen/TargetInstrInfo.h"
35 #include "llvm/CodeGen/TargetRegisterInfo.h"
36 
37 using namespace llvm;
38 
39 #define DEBUG_TYPE "riscv-vector-peephole"
40 
namespace {

// Post-instruction-selection peephole pass over RISC-V vector pseudos.
// See the file header comment for the list of transforms performed.
class RISCVVectorPeephole : public MachineFunctionPass {
public:
  static char ID;
  // Cached per-function pointers, expected to be initialized in
  // runOnMachineFunction before any of the helpers below run.
  const TargetInstrInfo *TII;
  MachineRegisterInfo *MRI;
  const TargetRegisterInfo *TRI;
  const RISCVSubtarget *ST;
  RISCVVectorPeephole() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  // The peepholes chase virtual-register def chains (getVRegDef), so the
  // function must still be in SSA form.
  MachineFunctionProperties getRequiredProperties() const override {
    return MachineFunctionProperties().setIsSSA();
  }

  StringRef getPassName() const override {
    return "RISC-V Vector Peephole Optimization";
  }

private:
  // Each helper implements one peephole on a single MachineInstr and
  // returns true if it changed the instruction.
  bool tryToReduceVL(MachineInstr &MI) const;
  bool convertToVLMAX(MachineInstr &MI) const;
  bool convertToWholeRegister(MachineInstr &MI) const;
  bool convertToUnmasked(MachineInstr &MI) const;
  bool convertAllOnesVMergeToVMv(MachineInstr &MI) const;
  bool convertSameMaskVMergeToVMv(MachineInstr &MI);
  bool foldUndefPassthruVMV_V_V(MachineInstr &MI);
  bool foldVMV_V_V(MachineInstr &MI);
  bool foldVMergeToMask(MachineInstr &MI) const;

  // Shared queries used by the peepholes above.
  bool hasSameEEW(const MachineInstr &User, const MachineInstr &Src) const;
  bool isAllOnesMask(const MachineInstr *MaskDef) const;
  std::optional<unsigned> getConstant(const MachineOperand &VL) const;
  bool ensureDominates(const MachineOperand &Use, MachineInstr &Src) const;
  bool isKnownSameDefs(Register A, Register B) const;
};

} // namespace
80 
char RISCVVectorPeephole::ID = 0;

// NOTE(review): the legacy-PM description string "RISC-V Fold Masks" predates
// this pass's rename and disagrees with getPassName() above — confirm nothing
// keys off the old string before bringing it in line.
INITIALIZE_PASS(RISCVVectorPeephole, DEBUG_TYPE, "RISC-V Fold Masks", false,
                false)
85 
86 /// Given \p User that has an input operand with EEW=SEW, which uses the dest
87 /// operand of \p Src with an unknown EEW, return true if their EEWs match.
hasSameEEW(const MachineInstr & User,const MachineInstr & Src) const88 bool RISCVVectorPeephole::hasSameEEW(const MachineInstr &User,
89                                      const MachineInstr &Src) const {
90   unsigned UserLog2SEW =
91       User.getOperand(RISCVII::getSEWOpNum(User.getDesc())).getImm();
92   unsigned SrcLog2SEW =
93       Src.getOperand(RISCVII::getSEWOpNum(Src.getDesc())).getImm();
94   unsigned SrcLog2EEW = RISCV::getDestLog2EEW(
95       TII->get(RISCV::getRVVMCOpcode(Src.getOpcode())), SrcLog2SEW);
96   return SrcLog2EEW == UserLog2SEW;
97 }
98 
// Attempt to reduce the VL of an instruction whose sole use is feeding a
// instruction with a narrower VL.  This currently works backwards from the
// user instruction (which might have a smaller VL).
bool RISCVVectorPeephole::tryToReduceVL(MachineInstr &MI) const {
  // Note that the goal here is a bit multifaceted.
  // 1) For store's reducing the VL of the value being stored may help to
  //    reduce VL toggles.  This is somewhat of an artifact of the fact we
  //    promote arithmetic instructions but VL predicate stores.
  // 2) For vmv.v.v reducing VL eagerly on the source instruction allows us
  //    to share code with the foldVMV_V_V transform below.
  //
  // Note that to the best of our knowledge, reducing VL is generally not
  // a significant win on real hardware unless we can also reduce LMUL which
  // this code doesn't try to do.
  //
  // TODO: We can handle a bunch more instructions here, and probably
  // recurse backwards through operands too.
  //
  // SrcIndices holds the operand index/indices of MI whose defining
  // instruction we will try to shrink. Operand 0 is the stored value for the
  // VSE cases below.
  SmallVector<unsigned, 2> SrcIndices = {0};
  switch (RISCV::getRVVMCOpcode(MI.getOpcode())) {
  default:
    return false;
  case RISCV::VSE8_V:
  case RISCV::VSE16_V:
  case RISCV::VSE32_V:
  case RISCV::VSE64_V:
    break;
  case RISCV::VMV_V_V:
    // Operand 2 is the copied source (0 is the def, 1 the passthru).
    SrcIndices[0] = 2;
    break;
  case RISCV::VMERGE_VVM:
    // Both the false (2) and true (3) vector operands are candidates.
    SrcIndices.assign({2, 3});
    break;
  case RISCV::VREDSUM_VS:
  case RISCV::VREDMAXU_VS:
  case RISCV::VREDMAX_VS:
  case RISCV::VREDMINU_VS:
  case RISCV::VREDMIN_VS:
  case RISCV::VREDAND_VS:
  case RISCV::VREDOR_VS:
  case RISCV::VREDXOR_VS:
  case RISCV::VWREDSUM_VS:
  case RISCV::VWREDSUMU_VS:
  case RISCV::VFREDUSUM_VS:
  case RISCV::VFREDOSUM_VS:
  case RISCV::VFREDMAX_VS:
  case RISCV::VFREDMIN_VS:
  case RISCV::VFWREDUSUM_VS:
  case RISCV::VFWREDOSUM_VS:
    // Reductions: only the vector source (operand 2) is VL-sensitive.
    SrcIndices[0] = 2;
    break;
  }

  MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  // VLMAX is the largest possible VL; there is nothing to shrink towards.
  if (VL.isImm() && VL.getImm() == RISCV::VLMaxSentinel)
    return false;

  bool Changed = false;
  for (unsigned SrcIdx : SrcIndices) {
    Register SrcReg = MI.getOperand(SrcIdx).getReg();
    // Note: one *use*, not one *user*.
    if (!MRI->hasOneUse(SrcReg))
      continue;

    // Only touch simple single-def, same-block sources that actually carry
    // VL/SEW operands we can rewrite.
    MachineInstr *Src = MRI->getVRegDef(SrcReg);
    if (!Src || Src->hasUnmodeledSideEffects() ||
        Src->getParent() != MI.getParent() || Src->getNumDefs() != 1 ||
        !RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
        !RISCVII::hasSEWOp(Src->getDesc().TSFlags))
      continue;

    // Src's dest needs to have the same EEW as MI's input.
    if (!hasSameEEW(MI, *Src))
      continue;

    // If Src's produced elements (or its FP exception behaviour) depend on
    // VL, shrinking VL would change observable results.
    bool ElementsDependOnVL = RISCVII::elementsDependOnVL(
        TII->get(RISCV::getRVVMCOpcode(Src->getOpcode())).TSFlags);
    if (ElementsDependOnVL || Src->mayRaiseFPException())
      continue;

    // Only rewrite when MI's VL is provably <= Src's VL (and not already
    // identical, in which case there is nothing to do).
    MachineOperand &SrcVL =
        Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
    if (VL.isIdenticalTo(SrcVL) || !RISCV::isVLKnownLE(VL, SrcVL))
      continue;

    // If VL is a register, its def must dominate Src (possibly by sinking
    // Src) before Src can read it.
    if (!ensureDominates(VL, *Src))
      continue;

    if (VL.isImm())
      SrcVL.ChangeToImmediate(VL.getImm());
    else if (VL.isReg())
      SrcVL.ChangeToRegister(VL.getReg(), false);

    Changed = true;
  }

  // TODO: For instructions with a passthru, we could clear the passthru
  // and tail policy since we've just proven the tail is not demanded.
  return Changed;
}
198 
199 /// Check if an operand is an immediate or a materialized ADDI $x0, imm.
200 std::optional<unsigned>
getConstant(const MachineOperand & VL) const201 RISCVVectorPeephole::getConstant(const MachineOperand &VL) const {
202   if (VL.isImm())
203     return VL.getImm();
204 
205   MachineInstr *Def = MRI->getVRegDef(VL.getReg());
206   if (!Def || Def->getOpcode() != RISCV::ADDI ||
207       Def->getOperand(1).getReg() != RISCV::X0)
208     return std::nullopt;
209   return Def->getOperand(2).getImm();
210 }
211 
/// Convert AVLs that are known to be VLMAX to the VLMAX sentinel.
bool RISCVVectorPeephole::convertToVLMAX(MachineInstr &MI) const {
  if (!RISCVII::hasVLOp(MI.getDesc().TSFlags) ||
      !RISCVII::hasSEWOp(MI.getDesc().TSFlags))
    return false;

  auto LMUL = RISCVVType::decodeVLMUL(RISCVII::getLMul(MI.getDesc().TSFlags));
  // Fixed-point value, denominator=8. LMUL.second is set for fractional
  // LMULs, so LMULFixed == 8*LMUL in both cases.
  unsigned LMULFixed = LMUL.second ? (8 / LMUL.first) : 8 * LMUL.first;
  unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MI.getDesc())).getImm();
  // A Log2SEW of 0 is an operation on mask registers only
  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
  assert(RISCVVType::isValidSEW(SEW) && "Unexpected SEW");
  assert(8 * LMULFixed / SEW > 0);

  // If the exact VLEN is known then we know VLMAX, check if the AVL == VLMAX.
  // (VLMAX = VLEN*LMUL/SEW; both sides are scaled by 8 to keep the
  // fixed-point LMUL integral.)
  MachineOperand &VL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (auto VLen = ST->getRealVLen(), AVL = getConstant(VL);
      VLen && AVL && (*VLen * LMULFixed) / SEW == *AVL * 8) {
    VL.ChangeToImmediate(RISCV::VLMaxSentinel);
    return true;
  }

  // If an AVL is a VLENB that's possibly scaled to be equal to VLMAX, convert
  // it to the VLMAX sentinel value.
  if (!VL.isReg())
    return false;
  MachineInstr *Def = MRI->getVRegDef(VL.getReg());
  if (!Def)
    return false;

  // Fixed-point value, denominator=8
  uint64_t ScaleFixed = 8;
  // Check if the VLENB was potentially scaled with slli/srli
  if (Def->getOpcode() == RISCV::SLLI) {
    assert(Def->getOperand(2).getImm() < 64);
    ScaleFixed <<= Def->getOperand(2).getImm();
    Def = MRI->getVRegDef(Def->getOperand(1).getReg());
  } else if (Def->getOpcode() == RISCV::SRLI) {
    assert(Def->getOperand(2).getImm() < 64);
    ScaleFixed >>= Def->getOperand(2).getImm();
    Def = MRI->getVRegDef(Def->getOperand(1).getReg());
  }

  if (!Def || Def->getOpcode() != RISCV::PseudoReadVLENB)
    return false;

  // AVL = (VLENB * Scale)
  //
  // VLMAX = (VLENB * 8 * LMUL) / SEW
  //
  // AVL == VLMAX
  // -> VLENB * Scale == (VLENB * 8 * LMUL) / SEW
  // -> Scale == (8 * LMUL) / SEW
  if (ScaleFixed != 8 * LMULFixed / SEW)
    return false;

  VL.ChangeToImmediate(RISCV::VLMaxSentinel);

  return true;
}
273 
isAllOnesMask(const MachineInstr * MaskDef) const274 bool RISCVVectorPeephole::isAllOnesMask(const MachineInstr *MaskDef) const {
275   while (MaskDef->isCopy() && MaskDef->getOperand(1).getReg().isVirtual())
276     MaskDef = MRI->getVRegDef(MaskDef->getOperand(1).getReg());
277 
278   // TODO: Check that the VMSET is the expected bitwidth? The pseudo has
279   // undefined behaviour if it's the wrong bitwidth, so we could choose to
280   // assume that it's all-ones? Same applies to its VL.
281   switch (MaskDef->getOpcode()) {
282   case RISCV::PseudoVMSET_M_B1:
283   case RISCV::PseudoVMSET_M_B2:
284   case RISCV::PseudoVMSET_M_B4:
285   case RISCV::PseudoVMSET_M_B8:
286   case RISCV::PseudoVMSET_M_B16:
287   case RISCV::PseudoVMSET_M_B32:
288   case RISCV::PseudoVMSET_M_B64:
289     return true;
290   default:
291     return false;
292   }
293 }
294 
/// Convert unit strided unmasked loads and stores to whole-register equivalents
/// to avoid the dependency on $vl and $vtype.
///
/// %x = PseudoVLE8_V_M1 %passthru, %ptr, %vlmax, policy
/// PseudoVSE8_V_M1 %v, %ptr, %vlmax
///
/// ->
///
/// %x = VL1RE8_V %ptr
/// VS1R_V %v, %ptr
bool RISCVVectorPeephole::convertToWholeRegister(MachineInstr &MI) const {
// Expands to two cases per (LMUL, SEW) pair: VLE -> VLxREy and VSE -> VSxR.
#define CASE_WHOLE_REGISTER_LMUL_SEW(lmul, sew)                                \
  case RISCV::PseudoVLE##sew##_V_M##lmul:                                      \
    NewOpc = RISCV::VL##lmul##RE##sew##_V;                                     \
    break;                                                                     \
  case RISCV::PseudoVSE##sew##_V_M##lmul:                                      \
    NewOpc = RISCV::VS##lmul##R_V;                                             \
    break;
#define CASE_WHOLE_REGISTER_LMUL(lmul)                                         \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 8)                                        \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 16)                                       \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 32)                                       \
  CASE_WHOLE_REGISTER_LMUL_SEW(lmul, 64)

  unsigned NewOpc;
  switch (MI.getOpcode()) {
    CASE_WHOLE_REGISTER_LMUL(1)
    CASE_WHOLE_REGISTER_LMUL(2)
    CASE_WHOLE_REGISTER_LMUL(4)
    CASE_WHOLE_REGISTER_LMUL(8)
  default:
    return false;
  }

  // Only an AVL of exactly VLMAX touches the whole register.
  MachineOperand &VLOp = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (!VLOp.isImm() || VLOp.getImm() != RISCV::VLMaxSentinel)
    return false;

  // Whole register instructions aren't pseudos so they don't have
  // policy/SEW/AVL ops, and they don't have passthrus.
  // Remove from the highest operand index downwards (policy, SEW, VL, then
  // the passthru) so the remaining op numbers stay valid after each removal.
  if (RISCVII::hasVecPolicyOp(MI.getDesc().TSFlags))
    MI.removeOperand(RISCVII::getVecPolicyOpNum(MI.getDesc()));
  MI.removeOperand(RISCVII::getSEWOpNum(MI.getDesc()));
  MI.removeOperand(RISCVII::getVLOpNum(MI.getDesc()));
  if (RISCVII::isFirstDefTiedToFirstUse(MI.getDesc()))
    MI.removeOperand(1);

  MI.setDesc(TII->get(NewOpc));

  return true;
}
346 
getVMV_V_VOpcodeForVMERGE_VVM(const MachineInstr & MI)347 static unsigned getVMV_V_VOpcodeForVMERGE_VVM(const MachineInstr &MI) {
348 #define CASE_VMERGE_TO_VMV(lmul)                                               \
349   case RISCV::PseudoVMERGE_VVM_##lmul:                                         \
350     return RISCV::PseudoVMV_V_V_##lmul;
351   switch (MI.getOpcode()) {
352   default:
353     return 0;
354     CASE_VMERGE_TO_VMV(MF8)
355     CASE_VMERGE_TO_VMV(MF4)
356     CASE_VMERGE_TO_VMV(MF2)
357     CASE_VMERGE_TO_VMV(M1)
358     CASE_VMERGE_TO_VMV(M2)
359     CASE_VMERGE_TO_VMV(M4)
360     CASE_VMERGE_TO_VMV(M8)
361   }
362 }
363 
/// Convert a PseudoVMERGE_VVM with an all ones mask to a PseudoVMV_V_V.
///
/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %allones, sew, vl
/// ->
/// %x = PseudoVMV_V_V %passthru, %true, vl, sew, tu_mu
bool RISCVVectorPeephole::convertAllOnesVMergeToVMv(MachineInstr &MI) const {
  unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
  if (!NewOpc)
    return false;
  // Operand 4 is the mask; with all ones, every element comes from %true and
  // the merge degenerates to a move.
  if (!isAllOnesMask(MRI->getVRegDef(MI.getOperand(4).getReg())))
    return false;

  MI.setDesc(TII->get(NewOpc));
  MI.removeOperand(2); // False operand
  // After the removal above the mask has shifted from index 4 to index 3.
  MI.removeOperand(3); // Mask operand
  // vmv.v.v carries a policy operand where vmerge does not; tu,mu preserves
  // the undisturbed tail behaviour the vmerge had.
  MI.addOperand(
      MachineOperand::CreateImm(RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));

  // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
  // register class for the destination and passthru operands e.g. VRNoV0 -> VR
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    MRI->recomputeRegClass(MI.getOperand(1).getReg());
  return true;
}
389 
isKnownSameDefs(Register A,Register B) const390 bool RISCVVectorPeephole::isKnownSameDefs(Register A, Register B) const {
391   if (A.isPhysical() || B.isPhysical())
392     return false;
393 
394   auto LookThruVirtRegCopies = [this](Register Reg) {
395     while (MachineInstr *Def = MRI->getUniqueVRegDef(Reg)) {
396       if (!Def->isFullCopy())
397         break;
398       Register Src = Def->getOperand(1).getReg();
399       if (!Src.isVirtual())
400         break;
401       Reg = Src;
402     }
403     return Reg;
404   };
405 
406   return LookThruVirtRegCopies(A) == LookThruVirtRegCopies(B);
407 }
408 
/// If a PseudoVMERGE_VVM's true operand is a masked pseudo and both have the
/// same mask, and the masked pseudo's passthru is the same as the false
/// operand, we can convert the PseudoVMERGE_VVM to a PseudoVMV_V_V.
///
/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
/// %x = PseudoVMERGE_VVM %passthru, %false, %true, %mask, vl2, sew
/// ->
/// %true = PseudoVADD_VV_M1_MASK %false, %x, %y, %mask, vl1, sew, policy
/// %x = PseudoVMV_V_V %passthru, %true, vl2, sew, tu_mu
bool RISCVVectorPeephole::convertSameMaskVMergeToVMv(MachineInstr &MI) {
  unsigned NewOpc = getVMV_V_VOpcodeForVMERGE_VVM(MI);
  if (!NewOpc)
    return false;
  // Operand 3 is the "true" vector; it must be defined in the same block.
  MachineInstr *True = MRI->getVRegDef(MI.getOperand(3).getReg());

  if (!True || True->getParent() != MI.getParent())
    return false;

  auto *TrueMaskedInfo = RISCV::getMaskedPseudoInfo(True->getOpcode());
  if (!TrueMaskedInfo || !hasSameEEW(MI, *True))
    return false;

  // Both instructions must be controlled by the same mask value.
  const MachineOperand &TrueMask =
      True->getOperand(TrueMaskedInfo->MaskOpIdx + True->getNumExplicitDefs());
  const MachineOperand &MIMask = MI.getOperand(4);
  if (!isKnownSameDefs(TrueMask.getReg(), MIMask.getReg()))
    return false;

  // Masked off lanes past TrueVL will come from False, and converting to vmv
  // will lose these lanes unless MIVL <= TrueVL.
  // TODO: We could relax this for False == Passthru and True policy == TU
  const MachineOperand &MIVL = MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  const MachineOperand &TrueVL =
      True->getOperand(RISCVII::getVLOpNum(True->getDesc()));
  if (!RISCV::isVLKnownLE(MIVL, TrueVL))
    return false;

  // True's passthru needs to be equivalent to False
  Register TruePassthruReg = True->getOperand(1).getReg();
  Register FalseReg = MI.getOperand(2).getReg();
  if (TruePassthruReg != FalseReg) {
    // If True's passthru is undef see if we can change it to False
    if (TruePassthruReg != RISCV::NoRegister ||
        !MRI->hasOneUse(MI.getOperand(3).getReg()) ||
        !ensureDominates(MI.getOperand(2), *True))
      return false;
    True->getOperand(1).setReg(MI.getOperand(2).getReg());
    // If True is masked then its passthru needs to be in VRNoV0.
    MRI->constrainRegClass(True->getOperand(1).getReg(),
                           TII->getRegClass(True->getDesc(), 1, TRI,
                                            *True->getParent()->getParent()));
  }

  MI.setDesc(TII->get(NewOpc));
  MI.removeOperand(2); // False operand
  // After the removal above the mask has shifted from index 4 to index 3.
  MI.removeOperand(3); // Mask operand
  // vmv.v.v needs a policy operand; tu,mu matches the merge's behaviour.
  MI.addOperand(
      MachineOperand::CreateImm(RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED));

  // vmv.v.v doesn't have a mask operand, so we may be able to inflate the
  // register class for the destination and passthru operands e.g. VRNoV0 -> VR
  MRI->recomputeRegClass(MI.getOperand(0).getReg());
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    MRI->recomputeRegClass(MI.getOperand(1).getReg());
  return true;
}
475 
/// Convert a masked pseudo whose mask is all-ones into its unmasked
/// equivalent, dropping the mask (and, where applicable, the policy and
/// passthru) operands and relaxing register classes.
bool RISCVVectorPeephole::convertToUnmasked(MachineInstr &MI) const {
  const RISCV::RISCVMaskedPseudoInfo *I =
      RISCV::getMaskedPseudoInfo(MI.getOpcode());
  if (!I)
    return false;

  // With an all-ones mask the masked and unmasked forms are equivalent.
  if (!isAllOnesMask(MRI->getVRegDef(
          MI.getOperand(I->MaskOpIdx + MI.getNumExplicitDefs()).getReg())))
    return false;

  // There are two classes of pseudos in the table - compares and
  // everything else.  See the comment on RISCVMaskedPseudo for details.
  const unsigned Opc = I->UnmaskedPseudo;
  const MCInstrDesc &MCID = TII->get(Opc);
  [[maybe_unused]] const bool HasPolicyOp =
      RISCVII::hasVecPolicyOp(MCID.TSFlags);
  const bool HasPassthru = RISCVII::isFirstDefTiedToFirstUse(MCID);
  const MCInstrDesc &MaskedMCID = TII->get(MI.getOpcode());
  // Sanity-check the structural relationship between the two pseudos.
  assert((RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) ||
          !RISCVII::hasVecPolicyOp(MCID.TSFlags)) &&
         "Unmasked pseudo has policy but masked pseudo doesn't?");
  assert(HasPolicyOp == HasPassthru && "Unexpected pseudo structure");
  assert(!(HasPassthru && !RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) &&
         "Unmasked with passthru but masked with no passthru?");
  (void)HasPolicyOp;

  MI.setDesc(MCID);

  // Drop the policy operand if unmasked doesn't need it.
  if (RISCVII::hasVecPolicyOp(MaskedMCID.TSFlags) &&
      !RISCVII::hasVecPolicyOp(MCID.TSFlags))
    MI.removeOperand(RISCVII::getVecPolicyOpNum(MaskedMCID));

  // TODO: Increment all MaskOpIdxs in tablegen by num of explicit defs?
  unsigned MaskOpIdx = I->MaskOpIdx + MI.getNumExplicitDefs();
  MI.removeOperand(MaskOpIdx);

  // The unmasked pseudo will no longer be constrained to the vrnov0 reg class,
  // so try and relax it to vr.
  MRI->recomputeRegClass(MI.getOperand(0).getReg());

  // If the original masked pseudo had a passthru, relax it or remove it.
  if (RISCVII::isFirstDefTiedToFirstUse(MaskedMCID)) {
    unsigned PassthruOpIdx = MI.getNumExplicitDefs();
    if (HasPassthru) {
      if (MI.getOperand(PassthruOpIdx).getReg() != RISCV::NoRegister)
        MRI->recomputeRegClass(MI.getOperand(PassthruOpIdx).getReg());
    } else
      MI.removeOperand(PassthruOpIdx);
  }

  return true;
}
529 
530 /// Check if it's safe to move From down to To, checking that no physical
531 /// registers are clobbered.
isSafeToMove(const MachineInstr & From,const MachineInstr & To)532 static bool isSafeToMove(const MachineInstr &From, const MachineInstr &To) {
533   assert(From.getParent() == To.getParent());
534   SmallVector<Register> PhysUses, PhysDefs;
535   for (const MachineOperand &MO : From.all_uses())
536     if (MO.getReg().isPhysical())
537       PhysUses.push_back(MO.getReg());
538   for (const MachineOperand &MO : From.all_defs())
539     if (MO.getReg().isPhysical())
540       PhysDefs.push_back(MO.getReg());
541   bool SawStore = false;
542   for (auto II = std::next(From.getIterator()); II != To.getIterator(); II++) {
543     for (Register PhysReg : PhysUses)
544       if (II->definesRegister(PhysReg, nullptr))
545         return false;
546     for (Register PhysReg : PhysDefs)
547       if (II->definesRegister(PhysReg, nullptr) ||
548           II->readsRegister(PhysReg, nullptr))
549         return false;
550     if (II->mayStore()) {
551       SawStore = true;
552       break;
553     }
554   }
555   return From.isSafeToMove(SawStore);
556 }
557 
558 /// Given A and B are in the same MBB, returns true if A comes before B.
dominates(MachineBasicBlock::const_iterator A,MachineBasicBlock::const_iterator B)559 static bool dominates(MachineBasicBlock::const_iterator A,
560                       MachineBasicBlock::const_iterator B) {
561   assert(A->getParent() == B->getParent());
562   const MachineBasicBlock *MBB = A->getParent();
563   auto MBBEnd = MBB->end();
564   if (B == MBBEnd)
565     return true;
566 
567   MachineBasicBlock::const_iterator I = MBB->begin();
568   for (; &*I != A && &*I != B; ++I)
569     ;
570 
571   return &*I == A;
572 }
573 
/// If the register in \p MO doesn't dominate \p Src, try to move \p Src so it
/// does. Returns false if doesn't dominate and we can't move. \p MO must be in
/// the same basic block as \Src.
bool RISCVVectorPeephole::ensureDominates(const MachineOperand &MO,
                                          MachineInstr &Src) const {
  assert(MO.getParent()->getParent() == Src.getParent());
  // Immediates and the NoRegister sentinel trivially "dominate" everything.
  if (!MO.isReg() || MO.getReg() == RISCV::NoRegister)
    return true;

  MachineInstr *Def = MRI->getVRegDef(MO.getReg());
  // Only the same-block out-of-order case is repaired, by sinking Src to just
  // after Def. NOTE(review): a def in a different block is assumed to
  // dominate Src — confirm callers guarantee this.
  if (Def->getParent() == Src.getParent() && !dominates(Def, Src)) {
    if (!isSafeToMove(Src, *Def->getNextNode()))
      return false;
    Src.moveBefore(Def->getNextNode());
  }

  return true;
}
592 
/// If a PseudoVMV_V_V's passthru is undef then we can replace it with its input
bool RISCVVectorPeephole::foldUndefPassthruVMV_V_V(MachineInstr &MI) {
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
    return false;
  // Operand 1 is the passthru; NoRegister means it is undef.
  if (MI.getOperand(1).getReg() != RISCV::NoRegister)
    return false;

  // If the input was a pseudo with a policy operand, we can give it a tail
  // agnostic policy if MI's undef tail subsumes the input's.
  MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
  if (Src && !Src->hasUnmodeledSideEffects() &&
      MRI->hasOneUse(MI.getOperand(2).getReg()) &&
      RISCVII::hasVLOp(Src->getDesc().TSFlags) &&
      RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags) && hasSameEEW(MI, *Src)) {
    const MachineOperand &MIVL = MI.getOperand(3);
    const MachineOperand &SrcVL =
        Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));

    MachineOperand &SrcPolicy =
        Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc()));

    // Only if MI's VL covers (is <=) Src's VL is MI's undef tail a superset
    // of Src's tail, making tail-agnostic safe on Src.
    if (RISCV::isVLKnownLE(MIVL, SrcVL))
      SrcPolicy.setImm(SrcPolicy.getImm() | RISCVVType::TAIL_AGNOSTIC);
  }

  // Replace all uses of MI's def with its input, constraining the input's
  // register class first so existing users stay valid, then delete MI.
  MRI->constrainRegClass(MI.getOperand(2).getReg(),
                         MRI->getRegClass(MI.getOperand(0).getReg()));
  MRI->replaceRegWith(MI.getOperand(0).getReg(), MI.getOperand(2).getReg());
  // Kill flags on the input may no longer be accurate once uses moved.
  MRI->clearKillFlags(MI.getOperand(2).getReg());
  MI.eraseFromParent();
  return true;
}
625 
/// If a PseudoVMV_V_V is the only user of its input, fold its passthru and VL
/// into it.
///
/// %x = PseudoVADD_V_V_M1 %passthru, %a, %b, %vl1, sew, policy
/// %y = PseudoVMV_V_V_M1 %passthru, %x, %vl2, sew, policy
///    (where %vl1 <= %vl2, see related tryToReduceVL)
///
/// ->
///
/// %y = PseudoVADD_V_V_M1 %passthru, %a, %b, vl1, sew, policy
bool RISCVVectorPeephole::foldVMV_V_V(MachineInstr &MI) {
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMV_V_V)
    return false;

  MachineOperand &Passthru = MI.getOperand(1);

  // The fold rewrites Src in place, so MI must be Src's only use.
  if (!MRI->hasOneUse(MI.getOperand(2).getReg()))
    return false;

  // Src must be a simple, same-block pseudo with passthru/VL/policy operands
  // we can rewrite.
  MachineInstr *Src = MRI->getVRegDef(MI.getOperand(2).getReg());
  if (!Src || Src->hasUnmodeledSideEffects() ||
      Src->getParent() != MI.getParent() ||
      !RISCVII::isFirstDefTiedToFirstUse(Src->getDesc()) ||
      !RISCVII::hasVLOp(Src->getDesc().TSFlags) ||
      !RISCVII::hasVecPolicyOp(Src->getDesc().TSFlags))
    return false;

  // Src's dest needs to have the same EEW as MI's input.
  if (!hasSameEEW(MI, *Src))
    return false;

  // Src needs to have the same passthru as VMV_V_V
  MachineOperand &SrcPassthru = Src->getOperand(Src->getNumExplicitDefs());
  if (SrcPassthru.getReg() != RISCV::NoRegister &&
      SrcPassthru.getReg() != Passthru.getReg())
    return false;

  // Src VL will have already been reduced if legal (see tryToReduceVL),
  // so we don't need to handle a smaller source VL here.  However, the
  // user's VL may be larger
  MachineOperand &SrcVL = Src->getOperand(RISCVII::getVLOpNum(Src->getDesc()));
  if (!RISCV::isVLKnownLE(SrcVL, MI.getOperand(3)))
    return false;

  // If the new passthru doesn't dominate Src, try to move Src so it does.
  if (!ensureDominates(Passthru, *Src))
    return false;

  if (SrcPassthru.getReg() != Passthru.getReg()) {
    SrcPassthru.setReg(Passthru.getReg());
    // If Src is masked then its passthru needs to be in VRNoV0.
    if (Passthru.getReg() != RISCV::NoRegister)
      MRI->constrainRegClass(Passthru.getReg(),
                             TII->getRegClass(Src->getDesc(),
                                              SrcPassthru.getOperandNo(), TRI,
                                              *Src->getParent()->getParent()));
  }

  // If MI was tail agnostic and the VL didn't increase, preserve it.
  int64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  if ((MI.getOperand(5).getImm() & RISCVVType::TAIL_AGNOSTIC) &&
      RISCV::isVLKnownLE(MI.getOperand(3), SrcVL))
    Policy |= RISCVVType::TAIL_AGNOSTIC;
  Src->getOperand(RISCVII::getVecPolicyOpNum(Src->getDesc())).setImm(Policy);

  // Redirect all users of MI's def to Src's def (constraining the register
  // class first so the users stay legal), then delete the vmv.v.v.
  MRI->constrainRegClass(Src->getOperand(0).getReg(),
                         MRI->getRegClass(MI.getOperand(0).getReg()));
  MRI->replaceRegWith(MI.getOperand(0).getReg(), Src->getOperand(0).getReg());
  MI.eraseFromParent();

  return true;
}
698 
699 /// Try to fold away VMERGE_VVM instructions into their operands:
700 ///
701 /// %true = PseudoVADD_VV ...
702 /// %x = PseudoVMERGE_VVM_M1 %false, %false, %true, %mask
703 /// ->
704 /// %x = PseudoVADD_VV_M1_MASK %false, ..., %mask
705 ///
706 /// We can only fold if vmerge's passthru operand, vmerge's false operand and
707 /// %true's passthru operand (if it has one) are the same. This is because we
708 /// have to consolidate them into one passthru operand in the result.
709 ///
710 /// If %true is masked, then we can use its mask instead of vmerge's if vmerge's
711 /// mask is all ones.
712 ///
713 /// The resulting VL is the minimum of the two VLs.
714 ///
715 /// The resulting policy is the effective policy the vmerge would have had,
716 /// i.e. whether or not it's passthru operand was implicit-def.
bool RISCVVectorPeephole::foldVMergeToMask(MachineInstr &MI) const {
  // Only handle VMERGE_VVM pseudos; getRVVMCOpcode collapses all LMUL/SEW
  // pseudo variants onto the one MC opcode we compare against.
  if (RISCV::getRVVMCOpcode(MI.getOpcode()) != RISCV::VMERGE_VVM)
    return false;

  // Operand layout of the vmerge pseudo as read below:
  //   0 = dest, 1 = passthru, 2 = false, 3 = true, 4 = mask, then VL/SEW.
  Register PassthruReg = MI.getOperand(1).getReg();
  Register FalseReg = MI.getOperand(2).getReg();
  Register TrueReg = MI.getOperand(3).getReg();
  // We rewrite the defining instruction of %true in place, so it must be
  // virtual and have no other users that would observe the change.
  if (!TrueReg.isVirtual() || !MRI->hasOneUse(TrueReg))
    return false;
  MachineInstr &True = *MRI->getUniqueVRegDef(TrueReg);
  // Same-block restriction keeps the dominance fixup below simple.
  if (True.getParent() != MI.getParent())
    return false;
  const MachineOperand &MaskOp = MI.getOperand(4);
  MachineInstr *Mask = MRI->getUniqueVRegDef(MaskOp.getReg());
  assert(Mask);

  // %true must have a masked form we can switch it to.
  const RISCV::RISCVMaskedPseudoInfo *Info =
      RISCV::lookupMaskedIntrinsicByUnmasked(True.getOpcode());
  if (!Info)
    return false;

  // If the EEW of True is different from vmerge's SEW, then we can't fold.
  if (!hasSameEEW(MI, True))
    return false;

  // We require that either passthru and false are the same, or that passthru
  // is undefined.
  if (PassthruReg && !isKnownSameDefs(PassthruReg, FalseReg))
    return false;

  // If True has a passthru operand then it needs to be the same as vmerge's
  // False, since False will be used for the result's passthru operand.
  Register TruePassthru = True.getOperand(True.getNumExplicitDefs()).getReg();
  if (RISCVII::isFirstDefTiedToFirstUse(True.getDesc()) && TruePassthru &&
      !isKnownSameDefs(TruePassthru, FalseReg))
    return false;

  // Make sure it doesn't raise any observable fp exceptions, since changing the
  // active elements will affect how fflags is set.
  if (True.hasUnmodeledSideEffects() || True.mayRaiseFPException())
    return false;

  const MachineOperand &VMergeVL =
      MI.getOperand(RISCVII::getVLOpNum(MI.getDesc()));
  const MachineOperand &TrueVL =
      True.getOperand(RISCVII::getVLOpNum(True.getDesc()));

  // The folded instruction runs under the smaller of the two VLs. Bail out if
  // neither VL is provably <= the other (e.g. two unrelated registers).
  MachineOperand MinVL = MachineOperand::CreateImm(0);
  if (RISCV::isVLKnownLE(TrueVL, VMergeVL))
    MinVL = TrueVL;
  else if (RISCV::isVLKnownLE(VMergeVL, TrueVL))
    MinVL = VMergeVL;
  else
    return false;

  // Some instructions' results depend on VL or the mask itself (not just
  // element-wise); for those we can't shrink VL or introduce a real mask.
  unsigned RVVTSFlags =
      TII->get(RISCV::getRVVMCOpcode(True.getOpcode())).TSFlags;
  if (RISCVII::elementsDependOnVL(RVVTSFlags) && !TrueVL.isIdenticalTo(MinVL))
    return false;
  if (RISCVII::elementsDependOnMask(RVVTSFlags) && !isAllOnesMask(Mask))
    return false;

  // Use a tumu policy, relaxing it to tail agnostic provided that the passthru
  // operand is undefined.
  //
  // However, if the VL became smaller than what the vmerge had originally, then
  // elements past VL that were previously in the vmerge's body will have moved
  // to the tail. In that case we always need to use tail undisturbed to
  // preserve them.
  uint64_t Policy = RISCVVType::TAIL_UNDISTURBED_MASK_UNDISTURBED;
  if (!PassthruReg && RISCV::isVLKnownLE(VMergeVL, MinVL))
    Policy |= RISCVVType::TAIL_AGNOSTIC;

  assert(RISCVII::hasVecPolicyOp(True.getDesc().TSFlags) &&
         "Foldable unmasked pseudo should have a policy op already");

  // Make sure the mask dominates True, otherwise move down True so it does.
  // VL will always dominate since if it's a register they need to be the same.
  if (!ensureDominates(MaskOp, True))
    return false;

  // Switch True to its masked pseudo.
  True.setDesc(TII->get(Info->MaskedPseudo));

  // Insert the mask operand.
  // TODO: Increment MaskOpIdx by number of explicit defs?
  True.insert(True.operands_begin() + Info->MaskOpIdx +
                  True.getNumExplicitDefs(),
              MachineOperand::CreateReg(MaskOp.getReg(), false));

  // Update the passthru, AVL and policy.
  // The AVL is replaced via remove + insert; note getVLOpNum is recomputed
  // from the (new, masked) descriptor and the operand kind may change
  // between register and immediate.
  True.getOperand(True.getNumExplicitDefs()).setReg(FalseReg);
  True.removeOperand(RISCVII::getVLOpNum(True.getDesc()));
  True.insert(True.operands_begin() + RISCVII::getVLOpNum(True.getDesc()),
              MinVL);
  True.getOperand(RISCVII::getVecPolicyOpNum(True.getDesc())).setImm(Policy);

  // True now produces the merged value; forward it to vmerge's users.
  MRI->replaceRegWith(True.getOperand(0).getReg(), MI.getOperand(0).getReg());
  // Now that True is masked, constrain its operands from vr -> vrnov0.
  for (MachineOperand &MO : True.explicit_operands()) {
    if (!MO.isReg() || !MO.getReg().isVirtual())
      continue;
    MRI->constrainRegClass(
        MO.getReg(), True.getRegClassConstraint(MO.getOperandNo(), TII, TRI));
  }
  MI.eraseFromParent();

  return true;
}
825 
runOnMachineFunction(MachineFunction & MF)826 bool RISCVVectorPeephole::runOnMachineFunction(MachineFunction &MF) {
827   if (skipFunction(MF.getFunction()))
828     return false;
829 
830   // Skip if the vector extension is not enabled.
831   ST = &MF.getSubtarget<RISCVSubtarget>();
832   if (!ST->hasVInstructions())
833     return false;
834 
835   TII = ST->getInstrInfo();
836   MRI = &MF.getRegInfo();
837   TRI = MRI->getTargetRegisterInfo();
838 
839   bool Changed = false;
840 
841   for (MachineBasicBlock &MBB : MF) {
842     for (MachineInstr &MI : make_early_inc_range(MBB))
843       Changed |= foldVMergeToMask(MI);
844 
845     for (MachineInstr &MI : make_early_inc_range(MBB)) {
846       Changed |= convertToVLMAX(MI);
847       Changed |= tryToReduceVL(MI);
848       Changed |= convertToUnmasked(MI);
849       Changed |= convertToWholeRegister(MI);
850       Changed |= convertAllOnesVMergeToVMv(MI);
851       Changed |= convertSameMaskVMergeToVMv(MI);
852       if (foldUndefPassthruVMV_V_V(MI)) {
853         Changed |= true;
854         continue; // MI is erased
855       }
856       Changed |= foldVMV_V_V(MI);
857     }
858   }
859 
860   return Changed;
861 }
862 
createRISCVVectorPeepholePass()863 FunctionPass *llvm::createRISCVVectorPeepholePass() {
864   return new RISCVVectorPeephole();
865 }
866