1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===---------------------------------------------------------------------===//
8 //
9 // This pass does some optimizations for *W instructions at the MI level.
10 //
11 // First it removes unneeded sext.w instructions. Either because the sign
12 // extended bits aren't consumed or because the input was already sign extended
13 // by an earlier instruction.
14 //
15 // Then:
// 1. Unless explicitly disabled or the target prefers instructions with the W
//    suffix, it removes the -w suffix from opw instructions whenever all users
//    are dependent only on the lower word of the result of the instruction.
19 // The cases handled are:
20 // * addw because c.add has a larger register encoding than c.addw.
21 // * addiw because it helps reduce test differences between RV32 and RV64
22 // w/o being a pessimization.
23 // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
24 // * slliw because c.slliw doesn't exist and c.slli does
25 //
// 2. Or, if explicitly enabled or the target prefers instructions with the W
//    suffix, it adds the W suffix to the instruction whenever all users are
//    dependent only on the lower word of the result of the instruction.
29 // The cases handled are:
30 // * add/addi/sub/mul.
31 // * slli with imm < 32.
32 // * ld/lwu.
33 //===---------------------------------------------------------------------===//
34
35 #include "RISCV.h"
36 #include "RISCVMachineFunctionInfo.h"
37 #include "RISCVSubtarget.h"
38 #include "llvm/ADT/SmallSet.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/CodeGen/MachineFunctionPass.h"
41 #include "llvm/CodeGen/TargetInstrInfo.h"
42
43 using namespace llvm;
44
45 #define DEBUG_TYPE "riscv-opt-w-instrs"
46 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
47
48 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
49 STATISTIC(NumTransformedToWInstrs,
50 "Number of instructions transformed to W-ops");
51
52 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
53 cl::desc("Disable removal of sext.w"),
54 cl::init(false), cl::Hidden);
55 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
56 cl::desc("Disable strip W suffix"),
57 cl::init(false), cl::Hidden);
58
59 namespace {
60
61 class RISCVOptWInstrs : public MachineFunctionPass {
62 public:
63 static char ID;
64
RISCVOptWInstrs()65 RISCVOptWInstrs() : MachineFunctionPass(ID) {}
66
67 bool runOnMachineFunction(MachineFunction &MF) override;
68 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
69 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
70 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
71 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
72 bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
73 const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
74
getAnalysisUsage(AnalysisUsage & AU) const75 void getAnalysisUsage(AnalysisUsage &AU) const override {
76 AU.setPreservesCFG();
77 MachineFunctionPass::getAnalysisUsage(AU);
78 }
79
getPassName() const80 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
81 };
82
83 } // end anonymous namespace
84
85 char RISCVOptWInstrs::ID = 0;
INITIALIZE_PASS(RISCVOptWInstrs,DEBUG_TYPE,RISCV_OPT_W_INSTRS_NAME,false,false)86 INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
87 false)
88
89 FunctionPass *llvm::createRISCVOptWInstrsPass() {
90 return new RISCVOptWInstrs();
91 }
92
vectorPseudoHasAllNBitUsers(const MachineOperand & UserOp,unsigned Bits)93 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
94 unsigned Bits) {
95 const MachineInstr &MI = *UserOp.getParent();
96 unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());
97
98 if (!MCOpcode)
99 return false;
100
101 const MCInstrDesc &MCID = MI.getDesc();
102 const uint64_t TSFlags = MCID.TSFlags;
103 if (!RISCVII::hasSEWOp(TSFlags))
104 return false;
105 assert(RISCVII::hasVLOp(TSFlags));
106 const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();
107
108 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
109 return false;
110
111 auto NumDemandedBits =
112 RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
113 return NumDemandedBits && Bits >= *NumDemandedBits;
114 }
115
116 // Checks if all users only demand the lower \p OrigBits of the original
117 // instruction's result.
118 // TODO: handle multiple interdependent transformations
hasAllNBitUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,unsigned OrigBits)119 static bool hasAllNBitUsers(const MachineInstr &OrigMI,
120 const RISCVSubtarget &ST,
121 const MachineRegisterInfo &MRI, unsigned OrigBits) {
122
123 SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
124 SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;
125
126 Worklist.emplace_back(&OrigMI, OrigBits);
127
128 while (!Worklist.empty()) {
129 auto P = Worklist.pop_back_val();
130 const MachineInstr *MI = P.first;
131 unsigned Bits = P.second;
132
133 if (!Visited.insert(P).second)
134 continue;
135
136 // Only handle instructions with one def.
137 if (MI->getNumExplicitDefs() != 1)
138 return false;
139
140 Register DestReg = MI->getOperand(0).getReg();
141 if (!DestReg.isVirtual())
142 return false;
143
144 for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
145 const MachineInstr *UserMI = UserOp.getParent();
146 unsigned OpIdx = UserOp.getOperandNo();
147
148 switch (UserMI->getOpcode()) {
149 default:
150 if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
151 break;
152 return false;
153
154 case RISCV::ADDIW:
155 case RISCV::ADDW:
156 case RISCV::DIVUW:
157 case RISCV::DIVW:
158 case RISCV::MULW:
159 case RISCV::REMUW:
160 case RISCV::REMW:
161 case RISCV::SLLW:
162 case RISCV::SRAIW:
163 case RISCV::SRAW:
164 case RISCV::SRLIW:
165 case RISCV::SRLW:
166 case RISCV::SUBW:
167 case RISCV::ROLW:
168 case RISCV::RORW:
169 case RISCV::RORIW:
170 case RISCV::CLZW:
171 case RISCV::CTZW:
172 case RISCV::CPOPW:
173 case RISCV::SLLI_UW:
174 case RISCV::FMV_W_X:
175 case RISCV::FCVT_H_W:
176 case RISCV::FCVT_H_W_INX:
177 case RISCV::FCVT_H_WU:
178 case RISCV::FCVT_H_WU_INX:
179 case RISCV::FCVT_S_W:
180 case RISCV::FCVT_S_W_INX:
181 case RISCV::FCVT_S_WU:
182 case RISCV::FCVT_S_WU_INX:
183 case RISCV::FCVT_D_W:
184 case RISCV::FCVT_D_W_INX:
185 case RISCV::FCVT_D_WU:
186 case RISCV::FCVT_D_WU_INX:
187 if (Bits >= 32)
188 break;
189 return false;
190
191 case RISCV::SEXT_B:
192 case RISCV::PACKH:
193 if (Bits >= 8)
194 break;
195 return false;
196 case RISCV::SEXT_H:
197 case RISCV::FMV_H_X:
198 case RISCV::ZEXT_H_RV32:
199 case RISCV::ZEXT_H_RV64:
200 case RISCV::PACKW:
201 if (Bits >= 16)
202 break;
203 return false;
204
205 case RISCV::PACK:
206 if (Bits >= (ST.getXLen() / 2))
207 break;
208 return false;
209
210 case RISCV::SRLI: {
211 // If we are shifting right by less than Bits, and users don't demand
212 // any bits that were shifted into [Bits-1:0], then we can consider this
213 // as an N-Bit user.
214 unsigned ShAmt = UserMI->getOperand(2).getImm();
215 if (Bits > ShAmt) {
216 Worklist.emplace_back(UserMI, Bits - ShAmt);
217 break;
218 }
219 return false;
220 }
221
222 // these overwrite higher input bits, otherwise the lower word of output
223 // depends only on the lower word of input. So check their uses read W.
224 case RISCV::SLLI: {
225 unsigned ShAmt = UserMI->getOperand(2).getImm();
226 if (Bits >= (ST.getXLen() - ShAmt))
227 break;
228 Worklist.emplace_back(UserMI, Bits + ShAmt);
229 break;
230 }
231 case RISCV::SLLIW: {
232 unsigned ShAmt = UserMI->getOperand(2).getImm();
233 if (Bits >= 32 - ShAmt)
234 break;
235 Worklist.emplace_back(UserMI, Bits + ShAmt);
236 break;
237 }
238
239 case RISCV::ANDI: {
240 uint64_t Imm = UserMI->getOperand(2).getImm();
241 if (Bits >= (unsigned)llvm::bit_width(Imm))
242 break;
243 Worklist.emplace_back(UserMI, Bits);
244 break;
245 }
246 case RISCV::ORI: {
247 uint64_t Imm = UserMI->getOperand(2).getImm();
248 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
249 break;
250 Worklist.emplace_back(UserMI, Bits);
251 break;
252 }
253
254 case RISCV::SLL:
255 case RISCV::BSET:
256 case RISCV::BCLR:
257 case RISCV::BINV:
258 // Operand 2 is the shift amount which uses log2(xlen) bits.
259 if (OpIdx == 2) {
260 if (Bits >= Log2_32(ST.getXLen()))
261 break;
262 return false;
263 }
264 Worklist.emplace_back(UserMI, Bits);
265 break;
266
267 case RISCV::SRA:
268 case RISCV::SRL:
269 case RISCV::ROL:
270 case RISCV::ROR:
271 // Operand 2 is the shift amount which uses 6 bits.
272 if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
273 break;
274 return false;
275
276 case RISCV::ADD_UW:
277 case RISCV::SH1ADD_UW:
278 case RISCV::SH2ADD_UW:
279 case RISCV::SH3ADD_UW:
280 // Operand 1 is implicitly zero extended.
281 if (OpIdx == 1 && Bits >= 32)
282 break;
283 Worklist.emplace_back(UserMI, Bits);
284 break;
285
286 case RISCV::BEXTI:
287 if (UserMI->getOperand(2).getImm() >= Bits)
288 return false;
289 break;
290
291 case RISCV::SB:
292 // The first argument is the value to store.
293 if (OpIdx == 0 && Bits >= 8)
294 break;
295 return false;
296 case RISCV::SH:
297 // The first argument is the value to store.
298 if (OpIdx == 0 && Bits >= 16)
299 break;
300 return false;
301 case RISCV::SW:
302 // The first argument is the value to store.
303 if (OpIdx == 0 && Bits >= 32)
304 break;
305 return false;
306
307 // For these, lower word of output in these operations, depends only on
308 // the lower word of input. So, we check all uses only read lower word.
309 case RISCV::COPY:
310 case RISCV::PHI:
311
312 case RISCV::ADD:
313 case RISCV::ADDI:
314 case RISCV::AND:
315 case RISCV::MUL:
316 case RISCV::OR:
317 case RISCV::SUB:
318 case RISCV::XOR:
319 case RISCV::XORI:
320
321 case RISCV::ANDN:
322 case RISCV::CLMUL:
323 case RISCV::ORN:
324 case RISCV::SH1ADD:
325 case RISCV::SH2ADD:
326 case RISCV::SH3ADD:
327 case RISCV::XNOR:
328 case RISCV::BSETI:
329 case RISCV::BCLRI:
330 case RISCV::BINVI:
331 Worklist.emplace_back(UserMI, Bits);
332 break;
333
334 case RISCV::BREV8:
335 case RISCV::ORC_B:
336 // BREV8 and ORC_B work on bytes. Round Bits down to the nearest byte.
337 Worklist.emplace_back(UserMI, alignDown(Bits, 8));
338 break;
339
340 case RISCV::PseudoCCMOVGPR:
341 case RISCV::PseudoCCMOVGPRNoX0:
342 // Either operand 4 or operand 5 is returned by this instruction. If
343 // only the lower word of the result is used, then only the lower word
344 // of operand 4 and 5 is used.
345 if (OpIdx != 4 && OpIdx != 5)
346 return false;
347 Worklist.emplace_back(UserMI, Bits);
348 break;
349
350 case RISCV::CZERO_EQZ:
351 case RISCV::CZERO_NEZ:
352 case RISCV::VT_MASKC:
353 case RISCV::VT_MASKCN:
354 if (OpIdx != 1)
355 return false;
356 Worklist.emplace_back(UserMI, Bits);
357 break;
358 }
359 }
360 }
361
362 return true;
363 }
364
hasAllWUsers(const MachineInstr & OrigMI,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI)365 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
366 const MachineRegisterInfo &MRI) {
367 return hasAllNBitUsers(OrigMI, ST, MRI, 32);
368 }
369
370 // This function returns true if the machine instruction always outputs a value
371 // where bits 63:32 match bit 31.
isSignExtendingOpW(const MachineInstr & MI,unsigned OpNo)372 static bool isSignExtendingOpW(const MachineInstr &MI, unsigned OpNo) {
373 uint64_t TSFlags = MI.getDesc().TSFlags;
374
375 // Instructions that can be determined from opcode are marked in tablegen.
376 if (TSFlags & RISCVII::IsSignExtendingOpWMask)
377 return true;
378
379 // Special cases that require checking operands.
380 switch (MI.getOpcode()) {
381 // shifting right sufficiently makes the value 32-bit sign-extended
382 case RISCV::SRAI:
383 return MI.getOperand(2).getImm() >= 32;
384 case RISCV::SRLI:
385 return MI.getOperand(2).getImm() > 32;
386 // The LI pattern ADDI rd, X0, imm is sign extended.
387 case RISCV::ADDI:
388 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0;
389 // An ANDI with an 11 bit immediate will zero bits 63:11.
390 case RISCV::ANDI:
391 return isUInt<11>(MI.getOperand(2).getImm());
392 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11.
393 case RISCV::ORI:
394 return !isUInt<11>(MI.getOperand(2).getImm());
395 // A bseti with X0 is sign extended if the immediate is less than 31.
396 case RISCV::BSETI:
397 return MI.getOperand(2).getImm() < 31 &&
398 MI.getOperand(1).getReg() == RISCV::X0;
399 // Copying from X0 produces zero.
400 case RISCV::COPY:
401 return MI.getOperand(1).getReg() == RISCV::X0;
402 // Ignore the scratch register destination.
403 case RISCV::PseudoAtomicLoadNand32:
404 return OpNo == 0;
405 case RISCV::PseudoVMV_X_S: {
406 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5.
407 int64_t Log2SEW = MI.getOperand(2).getImm();
408 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW");
409 return Log2SEW <= 5;
410 }
411 }
412
413 return false;
414 }
415
isSignExtendedW(Register SrcReg,const RISCVSubtarget & ST,const MachineRegisterInfo & MRI,SmallPtrSetImpl<MachineInstr * > & FixableDef)416 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST,
417 const MachineRegisterInfo &MRI,
418 SmallPtrSetImpl<MachineInstr *> &FixableDef) {
419 SmallSet<Register, 4> Visited;
420 SmallVector<Register, 4> Worklist;
421
422 auto AddRegToWorkList = [&](Register SrcReg) {
423 if (!SrcReg.isVirtual())
424 return false;
425 Worklist.push_back(SrcReg);
426 return true;
427 };
428
429 if (!AddRegToWorkList(SrcReg))
430 return false;
431
432 while (!Worklist.empty()) {
433 Register Reg = Worklist.pop_back_val();
434
435 // If we already visited this register, we don't need to check it again.
436 if (!Visited.insert(Reg).second)
437 continue;
438
439 MachineInstr *MI = MRI.getVRegDef(Reg);
440 if (!MI)
441 continue;
442
443 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr);
444 assert(OpNo != -1 && "Couldn't find register");
445
446 // If this is a sign extending operation we don't need to look any further.
447 if (isSignExtendingOpW(*MI, OpNo))
448 continue;
449
450 // Is this an instruction that propagates sign extend?
451 switch (MI->getOpcode()) {
452 default:
453 // Unknown opcode, give up.
454 return false;
455 case RISCV::COPY: {
456 const MachineFunction *MF = MI->getMF();
457 const RISCVMachineFunctionInfo *RVFI =
458 MF->getInfo<RISCVMachineFunctionInfo>();
459
460 // If this is the entry block and the register is livein, see if we know
461 // it is sign extended.
462 if (MI->getParent() == &MF->front()) {
463 Register VReg = MI->getOperand(0).getReg();
464 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg))
465 continue;
466 }
467
468 Register CopySrcReg = MI->getOperand(1).getReg();
469 if (CopySrcReg == RISCV::X10) {
470 // For a method return value, we check the ZExt/SExt flags in attribute.
471 // We assume the following code sequence for method call.
472 // PseudoCALL @bar, ...
473 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2
474 // %0:gpr = COPY $x10
475 //
476 // We use the PseudoCall to look up the IR function being called to find
477 // its return attributes.
478 const MachineBasicBlock *MBB = MI->getParent();
479 auto II = MI->getIterator();
480 if (II == MBB->instr_begin() ||
481 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP)
482 return false;
483
484 const MachineInstr &CallMI = *(--II);
485 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal())
486 return false;
487
488 auto *CalleeFn =
489 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal());
490 if (!CalleeFn)
491 return false;
492
493 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType());
494 if (!IntTy)
495 return false;
496
497 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs();
498 unsigned BitWidth = IntTy->getBitWidth();
499 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) ||
500 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt)))
501 continue;
502 }
503
504 if (!AddRegToWorkList(CopySrcReg))
505 return false;
506
507 break;
508 }
509
510 // For these, we just need to check if the 1st operand is sign extended.
511 case RISCV::BCLRI:
512 case RISCV::BINVI:
513 case RISCV::BSETI:
514 if (MI->getOperand(2).getImm() >= 31)
515 return false;
516 [[fallthrough]];
517 case RISCV::REM:
518 case RISCV::ANDI:
519 case RISCV::ORI:
520 case RISCV::XORI:
521 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R.
522 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1
523 // Logical operations use a sign extended 12-bit immediate.
524 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
525 return false;
526
527 break;
528 case RISCV::PseudoCCADDW:
529 case RISCV::PseudoCCADDIW:
530 case RISCV::PseudoCCSUBW:
531 case RISCV::PseudoCCSLLW:
532 case RISCV::PseudoCCSRLW:
533 case RISCV::PseudoCCSRAW:
534 case RISCV::PseudoCCSLLIW:
535 case RISCV::PseudoCCSRLIW:
536 case RISCV::PseudoCCSRAIW:
537 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only
538 // need to check if operand 4 is sign extended.
539 if (!AddRegToWorkList(MI->getOperand(4).getReg()))
540 return false;
541 break;
542 case RISCV::REMU:
543 case RISCV::AND:
544 case RISCV::OR:
545 case RISCV::XOR:
546 case RISCV::ANDN:
547 case RISCV::ORN:
548 case RISCV::XNOR:
549 case RISCV::MAX:
550 case RISCV::MAXU:
551 case RISCV::MIN:
552 case RISCV::MINU:
553 case RISCV::PseudoCCMOVGPR:
554 case RISCV::PseudoCCMOVGPRNoX0:
555 case RISCV::PseudoCCAND:
556 case RISCV::PseudoCCOR:
557 case RISCV::PseudoCCXOR:
558 case RISCV::PHI: {
559 // If all incoming values are sign-extended, the output of AND, OR, XOR,
560 // MIN, MAX, or PHI is also sign-extended.
561
562 // The input registers for PHI are operand 1, 3, ...
563 // The input registers for PseudoCCMOVGPR(NoX0) are 4 and 5.
564 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6.
565 // The input registers for others are operand 1 and 2.
566 unsigned B = 1, E = 3, D = 1;
567 switch (MI->getOpcode()) {
568 case RISCV::PHI:
569 E = MI->getNumOperands();
570 D = 2;
571 break;
572 case RISCV::PseudoCCMOVGPR:
573 case RISCV::PseudoCCMOVGPRNoX0:
574 B = 4;
575 E = 6;
576 break;
577 case RISCV::PseudoCCAND:
578 case RISCV::PseudoCCOR:
579 case RISCV::PseudoCCXOR:
580 B = 4;
581 E = 7;
582 break;
583 }
584
585 for (unsigned I = B; I != E; I += D) {
586 if (!MI->getOperand(I).isReg())
587 return false;
588
589 if (!AddRegToWorkList(MI->getOperand(I).getReg()))
590 return false;
591 }
592
593 break;
594 }
595
596 case RISCV::CZERO_EQZ:
597 case RISCV::CZERO_NEZ:
598 case RISCV::VT_MASKC:
599 case RISCV::VT_MASKCN:
600 // Instructions return zero or operand 1. Result is sign extended if
601 // operand 1 is sign extended.
602 if (!AddRegToWorkList(MI->getOperand(1).getReg()))
603 return false;
604 break;
605
606 case RISCV::ADDI: {
607 if (MI->getOperand(1).isReg() && MI->getOperand(1).getReg().isVirtual()) {
608 if (MachineInstr *SrcMI = MRI.getVRegDef(MI->getOperand(1).getReg())) {
609 if (SrcMI->getOpcode() == RISCV::LUI &&
610 SrcMI->getOperand(1).isImm()) {
611 uint64_t Imm = SrcMI->getOperand(1).getImm();
612 Imm = SignExtend64<32>(Imm << 12);
613 Imm += (uint64_t)MI->getOperand(2).getImm();
614 if (isInt<32>(Imm))
615 continue;
616 }
617 }
618 }
619
620 if (hasAllWUsers(*MI, ST, MRI)) {
621 FixableDef.insert(MI);
622 break;
623 }
624 return false;
625 }
626
627 // With these opcode, we can "fix" them with the W-version
628 // if we know all users of the result only rely on bits 31:0
629 case RISCV::SLLI:
630 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
631 if (MI->getOperand(2).getImm() >= 32)
632 return false;
633 [[fallthrough]];
634 case RISCV::ADD:
635 case RISCV::LD:
636 case RISCV::LWU:
637 case RISCV::MUL:
638 case RISCV::SUB:
639 if (hasAllWUsers(*MI, ST, MRI)) {
640 FixableDef.insert(MI);
641 break;
642 }
643 return false;
644 }
645 }
646
647 // If we get here, then every node we visited produces a sign extended value
648 // or propagated sign extended values. So the result must be sign extended.
649 return true;
650 }
651
getWOp(unsigned Opcode)652 static unsigned getWOp(unsigned Opcode) {
653 switch (Opcode) {
654 case RISCV::ADDI:
655 return RISCV::ADDIW;
656 case RISCV::ADD:
657 return RISCV::ADDW;
658 case RISCV::LD:
659 case RISCV::LWU:
660 return RISCV::LW;
661 case RISCV::MUL:
662 return RISCV::MULW;
663 case RISCV::SLLI:
664 return RISCV::SLLIW;
665 case RISCV::SUB:
666 return RISCV::SUBW;
667 default:
668 llvm_unreachable("Unexpected opcode for replacement with W variant");
669 }
670 }
671
removeSExtWInstrs(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)672 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF,
673 const RISCVInstrInfo &TII,
674 const RISCVSubtarget &ST,
675 MachineRegisterInfo &MRI) {
676 if (DisableSExtWRemoval)
677 return false;
678
679 bool MadeChange = false;
680 for (MachineBasicBlock &MBB : MF) {
681 for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) {
682 // We're looking for the sext.w pattern ADDIW rd, rs1, 0.
683 if (!RISCVInstrInfo::isSEXT_W(MI))
684 continue;
685
686 Register SrcReg = MI.getOperand(1).getReg();
687
688 SmallPtrSet<MachineInstr *, 4> FixableDefs;
689
690 // If all users only use the lower bits, this sext.w is redundant.
691 // Or if all definitions reaching MI sign-extend their output,
692 // then sext.w is redundant.
693 if (!hasAllWUsers(MI, ST, MRI) &&
694 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs))
695 continue;
696
697 Register DstReg = MI.getOperand(0).getReg();
698 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg)))
699 continue;
700
701 // Convert Fixable instructions to their W versions.
702 for (MachineInstr *Fixable : FixableDefs) {
703 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable);
704 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode())));
705 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap);
706 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap);
707 Fixable->clearFlag(MachineInstr::MIFlag::IsExact);
708 LLVM_DEBUG(dbgs() << " with " << *Fixable);
709 ++NumTransformedToWInstrs;
710 }
711
712 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n");
713 MRI.replaceRegWith(DstReg, SrcReg);
714 MRI.clearKillFlags(SrcReg);
715 MI.eraseFromParent();
716 ++NumRemovedSExtW;
717 MadeChange = true;
718 }
719 }
720
721 return MadeChange;
722 }
723
stripWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)724 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF,
725 const RISCVInstrInfo &TII,
726 const RISCVSubtarget &ST,
727 MachineRegisterInfo &MRI) {
728 bool MadeChange = false;
729 for (MachineBasicBlock &MBB : MF) {
730 for (MachineInstr &MI : MBB) {
731 unsigned Opc;
732 switch (MI.getOpcode()) {
733 default:
734 continue;
735 case RISCV::ADDW: Opc = RISCV::ADD; break;
736 case RISCV::ADDIW: Opc = RISCV::ADDI; break;
737 case RISCV::MULW: Opc = RISCV::MUL; break;
738 case RISCV::SLLIW: Opc = RISCV::SLLI; break;
739 }
740
741 if (hasAllWUsers(MI, ST, MRI)) {
742 MI.setDesc(TII.get(Opc));
743 MadeChange = true;
744 }
745 }
746 }
747
748 return MadeChange;
749 }
750
appendWSuffixes(MachineFunction & MF,const RISCVInstrInfo & TII,const RISCVSubtarget & ST,MachineRegisterInfo & MRI)751 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF,
752 const RISCVInstrInfo &TII,
753 const RISCVSubtarget &ST,
754 MachineRegisterInfo &MRI) {
755 bool MadeChange = false;
756 for (MachineBasicBlock &MBB : MF) {
757 for (MachineInstr &MI : MBB) {
758 unsigned WOpc;
759 // TODO: Add more?
760 switch (MI.getOpcode()) {
761 default:
762 continue;
763 case RISCV::ADD:
764 WOpc = RISCV::ADDW;
765 break;
766 case RISCV::ADDI:
767 WOpc = RISCV::ADDIW;
768 break;
769 case RISCV::SUB:
770 WOpc = RISCV::SUBW;
771 break;
772 case RISCV::MUL:
773 WOpc = RISCV::MULW;
774 break;
775 case RISCV::SLLI:
776 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits
777 if (MI.getOperand(2).getImm() >= 32)
778 continue;
779 WOpc = RISCV::SLLIW;
780 break;
781 case RISCV::LD:
782 case RISCV::LWU:
783 WOpc = RISCV::LW;
784 break;
785 }
786
787 if (hasAllWUsers(MI, ST, MRI)) {
788 LLVM_DEBUG(dbgs() << "Replacing " << MI);
789 MI.setDesc(TII.get(WOpc));
790 MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
791 MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
792 MI.clearFlag(MachineInstr::MIFlag::IsExact);
793 LLVM_DEBUG(dbgs() << " with " << MI);
794 ++NumTransformedToWInstrs;
795 MadeChange = true;
796 }
797 }
798 }
799
800 return MadeChange;
801 }
802
runOnMachineFunction(MachineFunction & MF)803 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) {
804 if (skipFunction(MF.getFunction()))
805 return false;
806
807 MachineRegisterInfo &MRI = MF.getRegInfo();
808 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
809 const RISCVInstrInfo &TII = *ST.getInstrInfo();
810
811 if (!ST.is64Bit())
812 return false;
813
814 bool MadeChange = false;
815 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI);
816
817 if (!(DisableStripWSuffix || ST.preferWInst()))
818 MadeChange |= stripWSuffixes(MF, TII, ST, MRI);
819
820 if (ST.preferWInst())
821 MadeChange |= appendWSuffixes(MF, TII, ST, MRI);
822
823 return MadeChange;
824 }
825