Lines Matching +full:lower +full:- +full:case

1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===---------------------------------------------------------------------===//
17 // it removes the -w suffix from opw instructions whenever all users are
18 // dependent only on the lower word of the result of the instruction.
28 // only on the lower word of the result of the instruction.
33 //===---------------------------------------------------------------------===//
45 #define DEBUG_TYPE "riscv-opt-w-instrs"
46 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"
48 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
50 "Number of instructions transformed to W-ops");
52 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
55 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
116 // Checks if all users only demand the lower \p OrigBits of the original
137 if (MI->getNumExplicitDefs() != 1) in hasAllNBitUsers()
140 Register DestReg = MI->getOperand(0).getReg(); in hasAllNBitUsers()
148 switch (UserMI->getOpcode()) { in hasAllNBitUsers()
154 case RISCV::ADDIW: in hasAllNBitUsers()
155 case RISCV::ADDW: in hasAllNBitUsers()
156 case RISCV::DIVUW: in hasAllNBitUsers()
157 case RISCV::DIVW: in hasAllNBitUsers()
158 case RISCV::MULW: in hasAllNBitUsers()
159 case RISCV::REMUW: in hasAllNBitUsers()
160 case RISCV::REMW: in hasAllNBitUsers()
161 case RISCV::SLLIW: in hasAllNBitUsers()
162 case RISCV::SLLW: in hasAllNBitUsers()
163 case RISCV::SRAIW: in hasAllNBitUsers()
164 case RISCV::SRAW: in hasAllNBitUsers()
165 case RISCV::SRLIW: in hasAllNBitUsers()
166 case RISCV::SRLW: in hasAllNBitUsers()
167 case RISCV::SUBW: in hasAllNBitUsers()
168 case RISCV::ROLW: in hasAllNBitUsers()
169 case RISCV::RORW: in hasAllNBitUsers()
170 case RISCV::RORIW: in hasAllNBitUsers()
171 case RISCV::CLZW: in hasAllNBitUsers()
172 case RISCV::CTZW: in hasAllNBitUsers()
173 case RISCV::CPOPW: in hasAllNBitUsers()
174 case RISCV::SLLI_UW: in hasAllNBitUsers()
175 case RISCV::FMV_W_X: in hasAllNBitUsers()
176 case RISCV::FCVT_H_W: in hasAllNBitUsers()
177 case RISCV::FCVT_H_WU: in hasAllNBitUsers()
178 case RISCV::FCVT_S_W: in hasAllNBitUsers()
179 case RISCV::FCVT_S_WU: in hasAllNBitUsers()
180 case RISCV::FCVT_D_W: in hasAllNBitUsers()
181 case RISCV::FCVT_D_WU: in hasAllNBitUsers()
185 case RISCV::SEXT_B: in hasAllNBitUsers()
186 case RISCV::PACKH: in hasAllNBitUsers()
190 case RISCV::SEXT_H: in hasAllNBitUsers()
191 case RISCV::FMV_H_X: in hasAllNBitUsers()
192 case RISCV::ZEXT_H_RV32: in hasAllNBitUsers()
193 case RISCV::ZEXT_H_RV64: in hasAllNBitUsers()
194 case RISCV::PACKW: in hasAllNBitUsers()
199 case RISCV::PACK: in hasAllNBitUsers()
204 case RISCV::SRLI: { in hasAllNBitUsers()
206 // any bits that were shifted into [Bits-1:0], then we can consider this in hasAllNBitUsers()
207 // as an N-Bit user. in hasAllNBitUsers()
208 unsigned ShAmt = UserMI->getOperand(2).getImm(); in hasAllNBitUsers()
210 Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt)); in hasAllNBitUsers()
216 // these overwrite higher input bits, otherwise the lower word of output in hasAllNBitUsers()
217 // depends only on the lower word of input. So check their uses read W. in hasAllNBitUsers()
218 case RISCV::SLLI: in hasAllNBitUsers()
219 if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm())) in hasAllNBitUsers()
223 case RISCV::ANDI: { in hasAllNBitUsers()
224 uint64_t Imm = UserMI->getOperand(2).getImm(); in hasAllNBitUsers()
230 case RISCV::ORI: { in hasAllNBitUsers()
231 uint64_t Imm = UserMI->getOperand(2).getImm(); in hasAllNBitUsers()
238 case RISCV::SLL: in hasAllNBitUsers()
239 case RISCV::BSET: in hasAllNBitUsers()
240 case RISCV::BCLR: in hasAllNBitUsers()
241 case RISCV::BINV: in hasAllNBitUsers()
251 case RISCV::SRA: in hasAllNBitUsers()
252 case RISCV::SRL: in hasAllNBitUsers()
253 case RISCV::ROL: in hasAllNBitUsers()
254 case RISCV::ROR: in hasAllNBitUsers()
260 case RISCV::ADD_UW: in hasAllNBitUsers()
261 case RISCV::SH1ADD_UW: in hasAllNBitUsers()
262 case RISCV::SH2ADD_UW: in hasAllNBitUsers()
263 case RISCV::SH3ADD_UW: in hasAllNBitUsers()
270 case RISCV::BEXTI: in hasAllNBitUsers()
271 if (UserMI->getOperand(2).getImm() >= Bits) in hasAllNBitUsers()
275 case RISCV::SB: in hasAllNBitUsers()
280 case RISCV::SH: in hasAllNBitUsers()
285 case RISCV::SW: in hasAllNBitUsers()
291 // For these operations, the lower word of the output depends only on in hasAllNBitUsers()
292 // the lower word of the input. So, we check that all uses only read the lower word. in hasAllNBitUsers()
293 case RISCV::COPY: in hasAllNBitUsers()
294 case RISCV::PHI: in hasAllNBitUsers()
296 case RISCV::ADD: in hasAllNBitUsers()
297 case RISCV::ADDI: in hasAllNBitUsers()
298 case RISCV::AND: in hasAllNBitUsers()
299 case RISCV::MUL: in hasAllNBitUsers()
300 case RISCV::OR: in hasAllNBitUsers()
301 case RISCV::SUB: in hasAllNBitUsers()
302 case RISCV::XOR: in hasAllNBitUsers()
303 case RISCV::XORI: in hasAllNBitUsers()
305 case RISCV::ANDN: in hasAllNBitUsers()
306 case RISCV::BREV8: in hasAllNBitUsers()
307 case RISCV::CLMUL: in hasAllNBitUsers()
308 case RISCV::ORC_B: in hasAllNBitUsers()
309 case RISCV::ORN: in hasAllNBitUsers()
310 case RISCV::SH1ADD: in hasAllNBitUsers()
311 case RISCV::SH2ADD: in hasAllNBitUsers()
312 case RISCV::SH3ADD: in hasAllNBitUsers()
313 case RISCV::XNOR: in hasAllNBitUsers()
314 case RISCV::BSETI: in hasAllNBitUsers()
315 case RISCV::BCLRI: in hasAllNBitUsers()
316 case RISCV::BINVI: in hasAllNBitUsers()
320 case RISCV::PseudoCCMOVGPR: in hasAllNBitUsers()
322 // only the lower word of the result is used, then only the lower word in hasAllNBitUsers()
329 case RISCV::CZERO_EQZ: in hasAllNBitUsers()
330 case RISCV::CZERO_NEZ: in hasAllNBitUsers()
331 case RISCV::VT_MASKC: in hasAllNBitUsers()
332 case RISCV::VT_MASKCN: in hasAllNBitUsers()
361 // shifting right sufficiently makes the value 32-bit sign-extended in isSignExtendingOpW()
362 case RISCV::SRAI: in isSignExtendingOpW()
364 case RISCV::SRLI: in isSignExtendingOpW()
367 case RISCV::ADDI: in isSignExtendingOpW()
370 case RISCV::ANDI: in isSignExtendingOpW()
372 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. in isSignExtendingOpW()
373 case RISCV::ORI: in isSignExtendingOpW()
376 case RISCV::BSETI: in isSignExtendingOpW()
380 case RISCV::COPY: in isSignExtendingOpW()
383 case RISCV::PseudoAtomicLoadNand32: in isSignExtendingOpW()
385 case RISCV::PseudoVMV_X_S: { in isSignExtendingOpW()
423 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr); in isSignExtendedW()
424 assert(OpNo != -1 && "Couldn't find register"); in isSignExtendedW()
431 switch (MI->getOpcode()) { in isSignExtendedW()
435 case RISCV::COPY: { in isSignExtendedW()
436 const MachineFunction *MF = MI->getMF(); in isSignExtendedW()
438 MF->getInfo<RISCVMachineFunctionInfo>(); in isSignExtendedW()
442 if (MI->getParent() == &MF->front()) { in isSignExtendedW()
443 Register VReg = MI->getOperand(0).getReg(); in isSignExtendedW()
444 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) in isSignExtendedW()
448 Register CopySrcReg = MI->getOperand(1).getReg(); in isSignExtendedW()
453 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 in isSignExtendedW()
458 const MachineBasicBlock *MBB = MI->getParent(); in isSignExtendedW()
459 auto II = MI->getIterator(); in isSignExtendedW()
460 if (II == MBB->instr_begin() || in isSignExtendedW()
461 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) in isSignExtendedW()
464 const MachineInstr &CallMI = *(--II); in isSignExtendedW()
473 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); in isSignExtendedW()
477 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); in isSignExtendedW()
478 unsigned BitWidth = IntTy->getBitWidth(); in isSignExtendedW()
491 case RISCV::BCLRI: in isSignExtendedW()
492 case RISCV::BINVI: in isSignExtendedW()
493 case RISCV::BSETI: in isSignExtendedW()
494 if (MI->getOperand(2).getImm() >= 31) in isSignExtendedW()
497 case RISCV::REM: in isSignExtendedW()
498 case RISCV::ANDI: in isSignExtendedW()
499 case RISCV::ORI: in isSignExtendedW()
500 case RISCV::XORI: in isSignExtendedW()
501 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. in isSignExtendedW()
502 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 in isSignExtendedW()
503 // Logical operations use a sign extended 12-bit immediate. in isSignExtendedW()
504 if (!AddRegToWorkList(MI->getOperand(1).getReg())) in isSignExtendedW()
508 case RISCV::PseudoCCADDW: in isSignExtendedW()
509 case RISCV::PseudoCCADDIW: in isSignExtendedW()
510 case RISCV::PseudoCCSUBW: in isSignExtendedW()
511 case RISCV::PseudoCCSLLW: in isSignExtendedW()
512 case RISCV::PseudoCCSRLW: in isSignExtendedW()
513 case RISCV::PseudoCCSRAW: in isSignExtendedW()
514 case RISCV::PseudoCCSLLIW: in isSignExtendedW()
515 case RISCV::PseudoCCSRLIW: in isSignExtendedW()
516 case RISCV::PseudoCCSRAIW: in isSignExtendedW()
519 if (!AddRegToWorkList(MI->getOperand(4).getReg())) in isSignExtendedW()
522 case RISCV::REMU: in isSignExtendedW()
523 case RISCV::AND: in isSignExtendedW()
524 case RISCV::OR: in isSignExtendedW()
525 case RISCV::XOR: in isSignExtendedW()
526 case RISCV::ANDN: in isSignExtendedW()
527 case RISCV::ORN: in isSignExtendedW()
528 case RISCV::XNOR: in isSignExtendedW()
529 case RISCV::MAX: in isSignExtendedW()
530 case RISCV::MAXU: in isSignExtendedW()
531 case RISCV::MIN: in isSignExtendedW()
532 case RISCV::MINU: in isSignExtendedW()
533 case RISCV::PseudoCCMOVGPR: in isSignExtendedW()
534 case RISCV::PseudoCCAND: in isSignExtendedW()
535 case RISCV::PseudoCCOR: in isSignExtendedW()
536 case RISCV::PseudoCCXOR: in isSignExtendedW()
537 case RISCV::PHI: { in isSignExtendedW()
538 // If all incoming values are sign-extended, the output of AND, OR, XOR, in isSignExtendedW()
539 // MIN, MAX, or PHI is also sign-extended. in isSignExtendedW()
546 switch (MI->getOpcode()) { in isSignExtendedW()
547 case RISCV::PHI: in isSignExtendedW()
548 E = MI->getNumOperands(); in isSignExtendedW()
551 case RISCV::PseudoCCMOVGPR: in isSignExtendedW()
555 case RISCV::PseudoCCAND: in isSignExtendedW()
556 case RISCV::PseudoCCOR: in isSignExtendedW()
557 case RISCV::PseudoCCXOR: in isSignExtendedW()
564 if (!MI->getOperand(I).isReg()) in isSignExtendedW()
567 if (!AddRegToWorkList(MI->getOperand(I).getReg())) in isSignExtendedW()
574 case RISCV::CZERO_EQZ: in isSignExtendedW()
575 case RISCV::CZERO_NEZ: in isSignExtendedW()
576 case RISCV::VT_MASKC: in isSignExtendedW()
577 case RISCV::VT_MASKCN: in isSignExtendedW()
580 if (!AddRegToWorkList(MI->getOperand(1).getReg())) in isSignExtendedW()
584 // With these opcodes, we can "fix" them with the W-version in isSignExtendedW()
586 case RISCV::SLLI: in isSignExtendedW()
588 if (MI->getOperand(2).getImm() >= 32) in isSignExtendedW()
591 case RISCV::ADDI: in isSignExtendedW()
592 case RISCV::ADD: in isSignExtendedW()
593 case RISCV::LD: in isSignExtendedW()
594 case RISCV::LWU: in isSignExtendedW()
595 case RISCV::MUL: in isSignExtendedW()
596 case RISCV::SUB: in isSignExtendedW()
612 case RISCV::ADDI: in getWOp()
614 case RISCV::ADD: in getWOp()
616 case RISCV::LD: in getWOp()
617 case RISCV::LWU: in getWOp()
619 case RISCV::MUL: in getWOp()
621 case RISCV::SLLI: in getWOp()
623 case RISCV::SUB: in getWOp()
648 // If all users only use the lower bits, this sext.w is redundant. in removeSExtWInstrs()
649 // Or if all definitions reaching MI sign-extend their output, in removeSExtWInstrs()
662 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); in removeSExtWInstrs()
663 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); in removeSExtWInstrs()
664 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); in removeSExtWInstrs()
665 Fixable->clearFlag(MachineInstr::MIFlag::IsExact); in removeSExtWInstrs()
670 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); in removeSExtWInstrs()
693 case RISCV::ADDW: Opc = RISCV::ADD; break; in stripWSuffixes()
694 case RISCV::ADDIW: Opc = RISCV::ADDI; break; in stripWSuffixes()
695 case RISCV::MULW: Opc = RISCV::MUL; break; in stripWSuffixes()
696 case RISCV::SLLIW: Opc = RISCV::SLLI; break; in stripWSuffixes()
721 case RISCV::ADD: in appendWSuffixes()
724 case RISCV::ADDI: in appendWSuffixes()
727 case RISCV::SUB: in appendWSuffixes()
730 case RISCV::MUL: in appendWSuffixes()
733 case RISCV::SLLI: in appendWSuffixes()
739 case RISCV::LD: in appendWSuffixes()
740 case RISCV::LWU: in appendWSuffixes()