xref: /freebsd/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVOptWInstrs.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass does some optimizations for *W instructions at the MI level.
10 //
11 // First it removes unneeded sext.w instructions. Either because the sign
12 // extended bits aren't consumed or because the input was already sign extended
13 // by an earlier instruction.
14 //
15 // Then:
16 // 1. Unless explicitly disabled or the target prefers instructions with W suffix,
17 //    it removes the -w suffix from opw instructions whenever all users are
18 //    dependent only on the lower word of the result of the instruction.
19 //    The cases handled are:
20 //    * addw because c.add has a larger register encoding than c.addw.
21 //    * addiw because it helps reduce test differences between RV32 and RV64
22 //      w/o being a pessimization.
23 //    * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24 //    * slliw because c.slliw doesn't exist and c.slli does
25 //
26 // 2. Or if explicitly enabled or the target prefers instructions with W suffix,
27 //    it adds the W suffix to the instruction whenever all users are dependent
28 //    only on the lower word of the result of the instruction.
29 //    The cases handled are:
30 //    * add/addi/sub/mul.
31 //    * slli with imm < 32.
32 //    * ld/lwu.
33 //===---------------------------------------------------------------------===//
34 
35 #include "RISCV.h"
36 #include "RISCVMachineFunctionInfo.h"
37 #include "RISCVSubtarget.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/TargetInstrInfo.h"
42 
43 using namespace llvm;
44 
// Identifier used for this pass's debug output and statistics; it is also the
// second argument handed to INITIALIZE_PASS further down in this file.
#define DEBUG_TYPE "riscv-opt-w-instrs"
// Human readable pass name, returned by RISCVOptWInstrs::getPassName().
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

// Incremented once for every sext.w erased by removeSExtWInstrs().
STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
// Incremented once for every instruction rewritten to its W form, either while
// removing a sext.w or in appendWSuffixes().
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");
51 
52 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53                                          cl::desc("Disable removal of sext.w"),
54                                          cl::init(false), cl::Hidden);
55 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56                                          cl::desc("Disable strip W suffix"),
57                                          cl::init(false), cl::Hidden);
58 
59 namespace {
60 
61 class RISCVOptWInstrs : public MachineFunctionPass {
62 public:
63   static char ID;
64 
RISCVOptWInstrs()65   RISCVOptWInstrs() : MachineFunctionPass(ID) {}
66 
67   bool runOnMachineFunction(MachineFunction &MF) override;
68   bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
69                          const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
70   bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
71                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
72   bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
73                        const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
74 
getAnalysisUsage(AnalysisUsage & AU) const75   void getAnalysisUsage(AnalysisUsage &AU) const override {
76     AU.setPreservesCFG();
77     MachineFunctionPass::getAnalysisUsage(AU);
78   }
79 
getPassName() const80   StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
81 };
82 
83 } // end anonymous namespace
84 
// Definition of the static pass identifier declared in RISCVOptWInstrs; the
// constructor passes it to the MachineFunctionPass base class.
char RISCVOptWInstrs::ID = 0;
// Pass registration, using DEBUG_TYPE and RISCV_OPT_W_INSTRS_NAME as the
// pass's two string identifiers.
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)
88 
89 FunctionPass *llvm::createRISCVOptWInstrsPass() {
90   return new RISCVOptWInstrs();
91 }
92 
vectorPseudoHasAllNBitUsers(const MachineOperand & UserOp,unsigned Bits)93 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
94                                         unsigned Bits) {
95   const MachineInstr &MI = *UserOp.getParent();
96   unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
97 
98   if (!MCOpcode)
99     return false;
100 
101   const MCInstrDesc &MCID = MI.getDesc();
102   const uint64_t TSFlags = MCID.TSFlags;
103   if (!RISCVII::hasSEWOp(TSFlags))
104     return false;
105   assert(RISCVII::hasVLOp(TSFlags));
106   const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
107 
108   if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
109     return false;
110 
111   auto NumDemandedBits =
112       RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
113   return NumDemandedBits && Bits >= *NumDemandedBits;
114 }
115 
// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
//
// The search is a worklist walk over (instruction, bit count) pairs starting
// at (\p OrigMI, \p OrigBits). For each pair, every non-debug use of the
// instruction's single virtual destination register is classified by opcode:
//  * the user reads a fixed number of low bits -> accept if Bits is at least
//    that number, otherwise fail;
//  * the user forwards its input bits to its own result -> queue the user so
//    that its users are examined in turn (possibly with an adjusted count);
//  * anything else -> fail, unless vectorPseudoHasAllNBitUsers accepts it.
// Returns true only if the worklist drains without a failure.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  // Keyed on the (instruction, bits) pair, so the same instruction may be
  // examined again when reached with a different bit count.
  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.emplace_back(&OrigMI, OrigBits);

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    if (!Visited.insert(P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    // Uses are enumerated through MRI, which is only done for virtual
    // registers here.
    Register DestReg = MI->getOperand(0).getReg();
    if (!DestReg.isVirtual())
      return false;

    for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // Users accepted whenever at least 32 bits are available.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_W_INX:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_H_WU_INX:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_W_INX:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_S_WU_INX:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_W_INX:
      case RISCV::FCVT_D_WU:
      case RISCV::FCVT_D_WU_INX:
        if (Bits >= 32)
          break;
        return false;

      // Users accepted whenever at least 8 bits are available.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // Users accepted whenever at least 16 bits are available.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // Accepted whenever at least half of XLEN bits are available.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits > ShAmt) {
          Worklist.emplace_back(UserMI, Bits - ShAmt);
          break;
        }
        return false;
      }

      // these overwrite higher input bits, otherwise the lower word of output
      // depends only on the lower word of input. So check their uses read W.
      case RISCV::SLLI: {
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        // Accepted outright once Bits reaches XLEN - ShAmt; otherwise the
        // shift's own users are checked against Bits + ShAmt.
        if (Bits >= (ST.getXLen() - ShAmt))
          break;
        Worklist.emplace_back(UserMI, Bits + ShAmt);
        break;
      }
      case RISCV::SLLIW: {
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        // Same as SLLI, but measured against a 32-bit width.
        if (Bits >= 32 - ShAmt)
          break;
        Worklist.emplace_back(UserMI, Bits + ShAmt);
        break;
      }

      case RISCV::ANDI: {
        uint64_t Imm = UserMI->getOperand(2).getImm();
        // Accepted outright when the mask's set bits all lie within the low
        // Bits bits; otherwise the ANDI's own users are examined.
        if (Bits >= (unsigned)llvm::bit_width(Imm))
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }
      case RISCV::ORI: {
        uint64_t Imm = UserMI->getOperand(2).getImm();
        // Accepted outright when the immediate's clear bits all lie within
        // the low Bits bits; otherwise the ORI's own users are examined.
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(ST.getXLen()))
            break;
          return false;
        }
        // Used as the value operand: the result's users decide.
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift amount which uses log2(xlen) bits. A use in
        // any other operand position is rejected.
        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::BEXTI:
        // The immediate must be smaller than Bits for the use to be accepted.
        if (UserMI->getOperand(2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these, lower word of output in these operations, depends only on
      // the lower word of input. So, we check all uses only read lower word.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::CLMUL:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::BREV8:
      case RISCV::ORC_B:
        // BREV8 and ORC_B work on bytes. Round Bits down to the nearest byte.
        Worklist.emplace_back(UserMI, alignDown(Bits, 8));
        break;

      case RISCV::PseudoCCMOVGPR:
      case RISCV::PseudoCCMOVGPRNoX0:
        // Either operand 4 or operand 5 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 4 and 5 is used.
        if (OpIdx != 4 && OpIdx != 5)
          return false;
        Worklist.emplace_back(UserMI, Bits);
        break;

      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        // Only a use as operand 1 is followed through to the result's users;
        // a use in any other operand position is rejected.
        if (OpIdx != 1)
          return false;
        Worklist.emplace_back(UserMI, Bits);
        break;
      }
    }
  }

  return true;
}
364 
hasAllWUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI)365 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
366                          const MachineRegisterInfo &MRI) {
367   return hasAllNBitUsers(OrigMI, ST, MRI, 32);
368 }
369 
370 // This function returns true if the machine instruction always outputs a value
371 // where bits 63:32 match bit 31.
isSignExtendingOpW(const MachineInstr & MI,unsigned OpNo)372 static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
373   uint64_t TSFlags = MI.getDesc().TSFlags;
374 
375   // Instructions that can be determined from opcode are marked in tablegen.
376   if (TSFlags & RISCVII::IsSignExtendingOpWMask)
377     return true;
378 
379   // Special cases that require checking operands.
380   switch (MI.getOpcode()) {
381   // shifting right sufficiently makes the value 32-bit sign-extended
382   case RISCV::SRAI:
383     return MI.getOperand(2).getImm() >= 32;
384   case RISCV::SRLI:
385     return MI.getOperand(2).getImm() > 32;
386   // The LI pattern ADDI rd, X0, imm is sign extended.
387   case RISCV::ADDI:
388     return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
389   // An ANDI with an 11 bit immediate will zero bits 63:11.
390   case RISCV::ANDI:
391     return isUInt<11>(MI.getOperand(2).getImm());
392   // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
393   case RISCV::ORI:
394     return !isUInt<11>(MI.getOperand(2).getImm());
395   // A bseti with X0 is sign extended if the immediate is less than 31.
396   case RISCV::BSETI:
397     return MI.getOperand(2).getImm() < 31 &&
398            MI.getOperand(1).getReg() == RISCV::X0;
399   // Copying from X0 produces zero.
400   case RISCV::COPY:
401     return MI.getOperand(1).getReg() == RISCV::X0;
402   // Ignore the scratch register destination.
403   case RISCV::PseudoAtomicLoadNand32:
404     return OpNo == 0;
405   case RISCV::PseudoVMV_X_S: {
406     // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
407     int64_t Log2SEW = MI.getOperand(2).getImm();
408     assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
409     return Log2SEW <= 5;
410   }
411   }
412 
413   return false;
414 }
415 
// Returns true if the value in \p SrcReg is known to be sign extended from 32
// bits, i.e. every definition reaching it either produces a sign-extended
// value itself or merely propagates values that are in turn proven
// sign-extended.
//
// The proof is a worklist walk backwards over virtual register definitions.
// Instructions that are not sign-extending as written but could be replaced by
// their W form (ADDI/ADD/LD/LWU/MUL/SUB and SLLI with shift < 32, provided all
// of their users only read the low 32 bits) are accepted and recorded in
// \p FixableDef; the caller is expected to perform that replacement.
// Note that \p FixableDef may have been added to even when false is returned.
static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI,
                            SmallPtrSetImpl<MachineInstr *> &FixableDef) {
  SmallSet<Register, 4> Visited;
  SmallVector<Register, 4> Worklist;

  // Queues a register for inspection. Only virtual registers can be traced to
  // a definition, so anything else makes the caller give up.
  auto AddRegToWorkList = [&](Register SrcReg) {
    if (!SrcReg.isVirtual())
      return false;
    Worklist.push_back(SrcReg);
    return true;
  };

  if (!AddRegToWorkList(SrcReg))
    return false;

  while (!Worklist.empty()) {
    Register Reg = Worklist.pop_back_val();

    // If we already visited this register, we don't need to check it again.
    if (!Visited.insert(Reg).second)
      continue;

    // A register without a defining instruction is skipped rather than
    // treated as a failure.
    MachineInstr *MI = MRI.getVRegDef(Reg);
    if (!MI)
      continue;

    int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
    assert(OpNo != -1 && "Couldn't find register");

    // If this is a sign extending operation we don't need to look any further.
    if (isSignExtendingOpW(*MI, OpNo))
      continue;

    // Is this an instruction that propagates sign extend?
    switch (MI->getOpcode()) {
    default:
      // Unknown opcode, give up.
      return false;
    case RISCV::COPY: {
      const MachineFunction *MF = MI->getMF();
      const RISCVMachineFunctionInfo *RVFI =
          MF->getInfo<RISCVMachineFunctionInfo>();

      // If this is the entry block and the register is livein, see if we know
      // it is sign extended.
      if (MI->getParent() == &MF->front()) {
        Register VReg = MI->getOperand(0).getReg();
        if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
          continue;
      }

      Register CopySrcReg = MI->getOperand(1).getReg();
      if (CopySrcReg == RISCV::X10) {
        // For a method return value, we check the ZExt/SExt flags in attribute.
        // We assume the following code sequence for method call.
        // PseudoCALL @bar, ...
        // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
        // %0:gpr = COPY $x10
        //
        // We use the PseudoCall to look up the IR function being called to find
        // its return attributes.
        const MachineBasicBlock *MBB = MI->getParent();
        auto II = MI->getIterator();
        if (II == MBB->instr_begin() ||
            (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
          return false;

        // The instruction before ADJCALLSTACKUP must be a call whose first
        // operand is a global.
        const MachineInstr &CallMI = *(--II);
        if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
          return false;

        auto *CalleeFn =
            dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
        if (!CalleeFn)
          return false;

        auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
        if (!IntTy)
          return false;

        // A signext return of at most 32 bits, or a zeroext return of fewer
        // than 32 bits, is accepted as sign extended. Otherwise execution
        // falls out of this block and AddRegToWorkList rejects X10, which is
        // not a virtual register.
        const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
        unsigned BitWidth = IntTy->getBitWidth();
        if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
            (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
          continue;
      }

      if (!AddRegToWorkList(CopySrcReg))
        return false;

      break;
    }

    // For these, we just need to check if the 1st operand is sign extended.
    case RISCV::BCLRI:
    case RISCV::BINVI:
    case RISCV::BSETI:
      // Single-bit operations on bit 31 or above are rejected.
      if (MI->getOperand(2).getImm() >= 31)
        return false;
      [[fallthrough]];
    case RISCV::REM:
    case RISCV::ANDI:
    case RISCV::ORI:
    case RISCV::XORI:
      // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
      // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
      // Logical operations use a sign extended 12-bit immediate.
      if (!AddRegToWorkList(MI->getOperand(1).getReg()))
        return false;

      break;
    case RISCV::PseudoCCADDW:
    case RISCV::PseudoCCADDIW:
    case RISCV::PseudoCCSUBW:
    case RISCV::PseudoCCSLLW:
    case RISCV::PseudoCCSRLW:
    case RISCV::PseudoCCSRAW:
    case RISCV::PseudoCCSLLIW:
    case RISCV::PseudoCCSRLIW:
    case RISCV::PseudoCCSRAIW:
      // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
      // need to check if operand 4 is sign extended.
      if (!AddRegToWorkList(MI->getOperand(4).getReg()))
        return false;
      break;
    case RISCV::REMU:
    case RISCV::AND:
    case RISCV::OR:
    case RISCV::XOR:
    case RISCV::ANDN:
    case RISCV::ORN:
    case RISCV::XNOR:
    case RISCV::MAX:
    case RISCV::MAXU:
    case RISCV::MIN:
    case RISCV::MINU:
    case RISCV::PseudoCCMOVGPR:
    case RISCV::PseudoCCMOVGPRNoX0:
    case RISCV::PseudoCCAND:
    case RISCV::PseudoCCOR:
    case RISCV::PseudoCCXOR:
    case RISCV::PHI: {
      // If all incoming values are sign-extended, the output of AND, OR, XOR,
      // MIN, MAX, or PHI is also sign-extended.

      // The input registers for PHI are operand 1, 3, ...
      // The input registers for PseudoCCMOVGPR(NoX0) are 4 and 5.
      // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
      // The input registers for others are operand 1 and 2.
      // B = first input operand index, E = one past the last, D = stride.
      unsigned B = 1, E = 3, D = 1;
      switch (MI->getOpcode()) {
      case RISCV::PHI:
        E = MI->getNumOperands();
        D = 2;
        break;
      case RISCV::PseudoCCMOVGPR:
      case RISCV::PseudoCCMOVGPRNoX0:
        B = 4;
        E = 6;
        break;
      case RISCV::PseudoCCAND:
      case RISCV::PseudoCCOR:
      case RISCV::PseudoCCXOR:
        B = 4;
        E = 7;
        break;
      }

      // Every input must be a (virtual) register that is itself proven
      // sign-extended.
      for (unsigned I = B; I != E; I += D) {
        if (!MI->getOperand(I).isReg())
          return false;

        if (!AddRegToWorkList(MI->getOperand(I).getReg()))
          return false;
      }

      break;
    }

    case RISCV::CZERO_EQZ:
    case RISCV::CZERO_NEZ:
    case RISCV::VT_MASKC:
    case RISCV::VT_MASKCN:
      // Instructions return zero or operand 1. Result is sign extended if
      // operand 1 is sign extended.
      if (!AddRegToWorkList(MI->getOperand(1).getReg()))
        return false;
      break;

    case RISCV::ADDI: {
      // LUI + ADDI pair: compute the combined constant (the LUI immediate
      // shifted left by 12 and sign extended from 32 bits, plus the ADDI
      // immediate) and accept if it fits in a signed 32-bit value.
      if (MI->getOperand(1).isReg() && MI->getOperand(1).getReg().isVirtual()) {
        if (MachineInstr *SrcMI = MRI.getVRegDef(MI->getOperand(1).getReg())) {
          if (SrcMI->getOpcode() == RISCV::LUI &&
              SrcMI->getOperand(1).isImm()) {
            uint64_t Imm = SrcMI->getOperand(1).getImm();
            Imm = SignExtend64<32>(Imm << 12);
            Imm += (uint64_t)MI->getOperand(2).getImm();
            if (isInt<32>(Imm))
              continue;
          }
        }
      }

      // Otherwise the ADDI can be turned into ADDIW if nobody looks at the
      // upper bits of its result.
      if (hasAllWUsers(*MI, ST, MRI)) {
        FixableDef.insert(MI);
        break;
      }
      return false;
    }

    // With these opcode, we can "fix" them with the W-version
    // if we know all users of the result only rely on bits 31:0
    case RISCV::SLLI:
      // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
      if (MI->getOperand(2).getImm() >= 32)
        return false;
      [[fallthrough]];
    case RISCV::ADD:
    case RISCV::LD:
    case RISCV::LWU:
    case RISCV::MUL:
    case RISCV::SUB:
      if (hasAllWUsers(*MI, ST, MRI)) {
        FixableDef.insert(MI);
        break;
      }
      return false;
    }
  }

  // If we get here, then every node we visited produces a sign extended value
  // or propagated sign extended values. So the result must be sign extended.
  return true;
}
651 
getWOp(unsigned Opcode)652 static unsigned getWOp(unsigned Opcode) {
653   switch (Opcode) {
654   case RISCV::ADDI:
655     return RISCV::ADDIW;
656   case RISCV::ADD:
657     return RISCV::ADDW;
658   case RISCV::LD:
659   case RISCV::LWU:
660     return RISCV::LW;
661   case RISCV::MUL:
662     return RISCV::MULW;
663   case RISCV::SLLI:
664     return RISCV::SLLIW;
665   case RISCV::SUB:
666     return RISCV::SUBW;
667   default:
668     llvm_unreachable("Unexpected opcode for replacement with W variant");
669   }
670 }
671 
removeSExtWInstrs(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)672 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
673                                         const RISCVInstrInfo &TII,
674                                         const RISCVSubtarget &ST,
675                                         MachineRegisterInfo &MRI) {
676   if (DisableSExtWRemoval)
677     return false;
678 
679   bool MadeChange = false;
680   for (MachineBasicBlock &MBB : MF) {
681     for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
682       // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
683       if (!RISCVInstrInfo::isSEXT_W(MI))
684         continue;
685 
686       Register SrcReg = MI.getOperand(1).getReg();
687 
688       SmallPtrSet<MachineInstr *, 4> FixableDefs;
689 
690       // If all users only use the lower bits, this sext.w is redundant.
691       // Or if all definitions reaching MI sign-extend their output,
692       // then sext.w is redundant.
693       if (!hasAllWUsers(MI, ST, MRI) &&
694           !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
695         continue;
696 
697       Register DstReg = MI.getOperand(0).getReg();
698       if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
699         continue;
700 
701       // Convert Fixable instructions to their W versions.
702       for (MachineInstr *Fixable : FixableDefs) {
703         LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
704         Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
705         Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
706         Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
707         Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
708         LLVM_DEBUG(dbgs() << "     with " << *Fixable);
709         ++NumTransformedToWInstrs;
710       }
711 
712       LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
713       MRI.replaceRegWith(DstReg, SrcReg);
714       MRI.clearKillFlags(SrcReg);
715       MI.eraseFromParent();
716       ++NumRemovedSExtW;
717       MadeChange = true;
718     }
719   }
720 
721   return MadeChange;
722 }
723 
stripWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)724 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
725                                      const RISCVInstrInfo &TII,
726                                      const RISCVSubtarget &ST,
727                                      MachineRegisterInfo &MRI) {
728   bool MadeChange = false;
729   for (MachineBasicBlock &MBB : MF) {
730     for (MachineInstr &MI : MBB) {
731       unsigned Opc;
732       switch (MI.getOpcode()) {
733       default:
734         continue;
735       case RISCV::ADDW:  Opc = RISCV::ADD;  break;
736       case RISCV::ADDIW: Opc = RISCV::ADDI; break;
737       case RISCV::MULW:  Opc = RISCV::MUL;  break;
738       case RISCV::SLLIW: Opc = RISCV::SLLI; break;
739       }
740 
741       if (hasAllWUsers(MI, ST, MRI)) {
742         MI.setDesc(TII.get(Opc));
743         MadeChange = true;
744       }
745     }
746   }
747 
748   return MadeChange;
749 }
750 
appendWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)751 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF,
752                                       const RISCVInstrInfo &TII,
753                                       const RISCVSubtarget &ST,
754                                       MachineRegisterInfo &MRI) {
755   bool MadeChange = false;
756   for (MachineBasicBlock &MBB : MF) {
757     for (MachineInstr &MI : MBB) {
758       unsigned WOpc;
759       // TODO: Add more?
760       switch (MI.getOpcode()) {
761       default:
762         continue;
763       case RISCV::ADD:
764         WOpc = RISCV::ADDW;
765         break;
766       case RISCV::ADDI:
767         WOpc = RISCV::ADDIW;
768         break;
769       case RISCV::SUB:
770         WOpc = RISCV::SUBW;
771         break;
772       case RISCV::MUL:
773         WOpc = RISCV::MULW;
774         break;
775       case RISCV::SLLI:
776         // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
777         if (MI.getOperand(2).getImm() >= 32)
778           continue;
779         WOpc = RISCV::SLLIW;
780         break;
781       case RISCV::LD:
782       case RISCV::LWU:
783         WOpc = RISCV::LW;
784         break;
785       }
786 
787       if (hasAllWUsers(MI, ST, MRI)) {
788         LLVM_DEBUG(dbgs() << "Replacing " << MI);
789         MI.setDesc(TII.get(WOpc));
790         MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
791         MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
792         MI.clearFlag(MachineInstr::MIFlag::IsExact);
793         LLVM_DEBUG(dbgs() << "     with " << MI);
794         ++NumTransformedToWInstrs;
795         MadeChange = true;
796       }
797     }
798   }
799 
800   return MadeChange;
801 }
802 
runOnMachineFunction(MachineFunction & MF)803 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
804   if (skipFunction(MF.getFunction()))
805     return false;
806 
807   MachineRegisterInfo &MRI = MF.getRegInfo();
808   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
809   const RISCVInstrInfo &TII = *ST.getInstrInfo();
810 
811   if (!ST.is64Bit())
812     return false;
813 
814   bool MadeChange = false;
815   MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
816 
817   if (!(DisableStripWSuffix || ST.preferWInst()))
818     MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
819 
820   if (ST.preferWInst())
821     MadeChange |= appendWSuffixes(MF, TII, ST, MRI);
822 
823   return MadeChange;
824 }
825