//===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//
// This pass does some optimizations for *W instructions at the MI level.
//
// First, it removes unneeded sext.w instructions, either because the
// sign-extended bits aren't consumed or because the input was already
// sign-extended by an earlier instruction.
//
// Then:
// 1. Unless explicitly disabled or the target prefers instructions with a W
//    suffix, it removes the -w suffix from opw instructions whenever all users
//    depend only on the lower word of the result of the instruction.
//    The cases handled are:
//    * addw because c.add has a larger register encoding than c.addw.
//    * addiw because it helps reduce test differences between RV32 and RV64
//      w/o being a pessimization.
//    * mulw because c.mulw doesn't exist but c.mul does (w/ zcb)
//    * slliw because c.slliw doesn't exist and c.slli does
//
// 2. Or, if explicitly enabled or the target prefers instructions with a W
//    suffix, it adds the W suffix to the instruction whenever all users
//    depend only on the lower word of the result of the instruction.
//    The cases handled are:
//    * add/addi/sub/mul.
//    * slli with imm < 32.
//    * ld/lwu.
//===---------------------------------------------------------------------===//

#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
#include "RISCVSubtarget.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/TargetInstrInfo.h"

using namespace llvm;

#define DEBUG_TYPE "riscv-opt-w-instrs"
#define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions"

STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions");
STATISTIC(NumTransformedToWInstrs,
          "Number of instructions transformed to W-ops");

// Hidden command-line switches that turn off individual transforms of this
// pass (useful for debugging and for isolating test differences).
static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal",
                                         cl::desc("Disable removal of sext.w"),
                                         cl::init(false), cl::Hidden);
static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix",
                                         cl::desc("Disable strip W suffix"),
                                         cl::init(false), cl::Hidden);

namespace {

// Machine-function pass that removes redundant sext.w instructions and then
// either strips or appends W suffixes, depending on the subtarget preference.
class RISCVOptWInstrs : public MachineFunctionPass {
public:
  static char ID;

  RISCVOptWInstrs() : MachineFunctionPass(ID) {}

  bool runOnMachineFunction(MachineFunction &MF) override;
  // Erases sext.w (ADDIW rd, rs, 0) when it is provably redundant.
  bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII,
                         const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  // Rewrites ADDW/ADDIW/MULW/SLLIW to the non-W form when only the low word
  // of the result is used.
  bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                      const RISCVSubtarget &ST, MachineRegisterInfo &MRI);
  // Rewrites ADD/ADDI/SUB/MUL/SLLI/LD/LWU to the W form when only the low
  // word of the result is used.
  bool appendWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII,
                       const RISCVSubtarget &ST, MachineRegisterInfo &MRI);

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only opcodes are rewritten and instructions erased; no blocks or edges
    // are touched.
    AU.setPreservesCFG();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; }
};

} // end anonymous namespace

char RISCVOptWInstrs::ID = 0;
// Register the pass under DEBUG_TYPE; the trailing arguments are the
// INITIALIZE_PASS "CFG only" and "is analysis" flags.
INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false,
                false)

// Creates a new instance of this pass.
FunctionPass *llvm::createRISCVOptWInstrsPass() {
  return new RISCVOptWInstrs();
}

// Returns true if UserOp is an operand of an RVV pseudo (an opcode that
// RISCV::getRVVMCOpcode maps to an MC opcode and that carries a SEW operand)
// for which RISCV::getVectorLowDemandedScalarBits reports that at most \p Bits
// low bits are read. The VL operand is never accepted as such a user.
static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp,
                                        unsigned Bits) {
  const MachineInstr &MI = *UserOp.getParent();
  unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode());

  // Not an RVV pseudo.
  if (!MCOpcode)
    return false;

  const MCInstrDesc &MCID = MI.getDesc();
  const uint64_t TSFlags = MCID.TSFlags;
  if (!RISCVII::hasSEWOp(TSFlags))
    return false;
  assert(RISCVII::hasVLOp(TSFlags));
  const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm();

  // The VL operand is not covered by the demanded-bits query below.
  if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID))
    return false;

  // An empty optional means the number of demanded bits is unknown.
  auto NumDemandedBits =
      RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW);
  return NumDemandedBits && Bits >= *NumDemandedBits;
}

// Checks if all users only demand the lower \p OrigBits of the original
// instruction's result.
//
// Worklist entries are (MI, Bits) pairs meaning "only the low Bits of MI's
// result may be demanded". Users whose low output bits depend only on their
// low input bits (COPY, PHI, ADD, ...) are pushed so that their own users are
// checked transitively, possibly with a smaller Bits (SRLI). Any user opcode
// not recognized below makes the query fail.
// TODO: handle multiple interdependent transformations
static bool hasAllNBitUsers(const MachineInstr &OrigMI,
                            const RISCVSubtarget &ST,
                            const MachineRegisterInfo &MRI, unsigned OrigBits) {

  SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited;
  SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist;

  Worklist.push_back(std::make_pair(&OrigMI, OrigBits));

  while (!Worklist.empty()) {
    auto P = Worklist.pop_back_val();
    const MachineInstr *MI = P.first;
    unsigned Bits = P.second;

    // The same instruction may be reached with different Bits, so the pair
    // (not just the instruction) is the visited key.
    if (!Visited.insert(P).second)
      continue;

    // Only handle instructions with one def.
    if (MI->getNumExplicitDefs() != 1)
      return false;

    // Operand 0 is the single explicit def; physical registers have uses we
    // cannot enumerate here.
    Register DestReg = MI->getOperand(0).getReg();
    if (!DestReg.isVirtual())
      return false;

    for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) {
      const MachineInstr *UserMI = UserOp.getParent();
      unsigned OpIdx = UserOp.getOperandNo();

      switch (UserMI->getOpcode()) {
      default:
        // Unlisted scalar opcodes fail; RVV pseudos may still read only a
        // limited number of low scalar bits.
        if (vectorPseudoHasAllNBitUsers(UserOp, Bits))
          break;
        return false;

      // These read at most bits 31:0 of their GPR operands.
      case RISCV::ADDIW:
      case RISCV::ADDW:
      case RISCV::DIVUW:
      case RISCV::DIVW:
      case RISCV::MULW:
      case RISCV::REMUW:
      case RISCV::REMW:
      case RISCV::SLLIW:
      case RISCV::SLLW:
      case RISCV::SRAIW:
      case RISCV::SRAW:
      case RISCV::SRLIW:
      case RISCV::SRLW:
      case RISCV::SUBW:
      case RISCV::ROLW:
      case RISCV::RORW:
      case RISCV::RORIW:
      case RISCV::CLZW:
      case RISCV::CTZW:
      case RISCV::CPOPW:
      case RISCV::SLLI_UW:
      case RISCV::FMV_W_X:
      case RISCV::FCVT_H_W:
      case RISCV::FCVT_H_WU:
      case RISCV::FCVT_S_W:
      case RISCV::FCVT_S_WU:
      case RISCV::FCVT_D_W:
      case RISCV::FCVT_D_WU:
        if (Bits >= 32)
          break;
        return false;
      // These read at most bits 7:0.
      case RISCV::SEXT_B:
      case RISCV::PACKH:
        if (Bits >= 8)
          break;
        return false;
      // These read at most bits 15:0.
      case RISCV::SEXT_H:
      case RISCV::FMV_H_X:
      case RISCV::ZEXT_H_RV32:
      case RISCV::ZEXT_H_RV64:
      case RISCV::PACKW:
        if (Bits >= 16)
          break;
        return false;

      // PACK reads the low XLEN/2 bits of each source.
      case RISCV::PACK:
        if (Bits >= (ST.getXLen() / 2))
          break;
        return false;

      case RISCV::SRLI: {
        // If we are shifting right by less than Bits, and users don't demand
        // any bits that were shifted into [Bits-1:0], then we can consider this
        // as an N-Bit user.
        unsigned ShAmt = UserMI->getOperand(2).getImm();
        if (Bits > ShAmt) {
          Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt));
          break;
        }
        return false;
      }

      // These can discard the higher input bits. If they do not discard all
      // bits at or above Bits, the low Bits of the output still depend only on
      // the low Bits of the input, so check that their own users read only
      // those bits.
      case RISCV::SLLI:
        // Shifting left by ShAmt only reads the low XLen-ShAmt input bits.
        if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm()))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      case RISCV::ANDI: {
        // A mask with no set bits at or above Bits reads only the low Bits.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width(Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
      case RISCV::ORI: {
        // If the immediate forces every bit at or above Bits to 1, only the
        // low Bits of the input are read.
        uint64_t Imm = UserMI->getOperand(2).getImm();
        if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm))
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }

      case RISCV::SLL:
      case RISCV::BSET:
      case RISCV::BCLR:
      case RISCV::BINV:
        // Operand 2 is the shift amount which uses log2(xlen) bits.
        if (OpIdx == 2) {
          if (Bits >= Log2_32(ST.getXLen()))
            break;
          return false;
        }
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::SRA:
      case RISCV::SRL:
      case RISCV::ROL:
      case RISCV::ROR:
        // Operand 2 is the shift/rotate amount, which uses only log2(xlen) bits
        // (6 on RV64). As the value operand, upper input bits can move into
        // the low word, so that use is rejected.
        if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen()))
          break;
        return false;

      case RISCV::ADD_UW:
      case RISCV::SH1ADD_UW:
      case RISCV::SH2ADD_UW:
      case RISCV::SH3ADD_UW:
        // Operand 1 is implicitly zero extended.
        if (OpIdx == 1 && Bits >= 32)
          break;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      // BEXTI reads the single bit selected by its immediate.
      case RISCV::BEXTI:
        if (UserMI->getOperand(2).getImm() >= Bits)
          return false;
        break;

      case RISCV::SB:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 8)
          break;
        return false;
      case RISCV::SH:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 16)
          break;
        return false;
      case RISCV::SW:
        // The first argument is the value to store.
        if (OpIdx == 0 && Bits >= 32)
          break;
        return false;

      // For these, lower word of output in these operations, depends only on
      // the lower word of input. So, we check all uses only read lower word.
      case RISCV::COPY:
      case RISCV::PHI:

      case RISCV::ADD:
      case RISCV::ADDI:
      case RISCV::AND:
      case RISCV::MUL:
      case RISCV::OR:
      case RISCV::SUB:
      case RISCV::XOR:
      case RISCV::XORI:

      case RISCV::ANDN:
      case RISCV::BREV8:
      case RISCV::CLMUL:
      case RISCV::ORC_B:
      case RISCV::ORN:
      case RISCV::SH1ADD:
      case RISCV::SH2ADD:
      case RISCV::SH3ADD:
      case RISCV::XNOR:
      case RISCV::BSETI:
      case RISCV::BCLRI:
      case RISCV::BINVI:
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      case RISCV::PseudoCCMOVGPR:
        // Either operand 4 or operand 5 is returned by this instruction. If
        // only the lower word of the result is used, then only the lower word
        // of operand 4 and 5 is used.
        if (OpIdx != 4 && OpIdx != 5)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;

      // These return either zero or operand 1; the condition operand is read
      // in full, so only a use as operand 1 qualifies.
      case RISCV::CZERO_EQZ:
      case RISCV::CZERO_NEZ:
      case RISCV::VT_MASKC:
      case RISCV::VT_MASKCN:
        if (OpIdx != 1)
          return false;
        Worklist.push_back(std::make_pair(UserMI, Bits));
        break;
      }
    }
  }

  return true;
}

// Returns true if every (transitive) user of OrigMI's result only demands
// bits 31:0.
static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST,
                         const MachineRegisterInfo &MRI) {
  return hasAllNBitUsers(OrigMI, ST, MRI, 32);
}

// This function returns true if the machine instruction always outputs a value
// where bits 63:32 match bit 31, i.e. the def at operand index OpNo is already
// sign-extended from 32 bits. Opcodes flagged IsSignExtendingOpW in tablegen
// qualify unconditionally; a few others depend on their operands.
351 static bool isSignExtendingOpW(const MachineInstr &MI, 352 const MachineRegisterInfo &MRI, unsigned OpNo) { 353 uint64_t TSFlags = MI.getDesc().TSFlags; 354 355 // Instructions that can be determined from opcode are marked in tablegen. 356 if (TSFlags & RISCVII::IsSignExtendingOpWMask) 357 return true; 358 359 // Special cases that require checking operands. 360 switch (MI.getOpcode()) { 361 // shifting right sufficiently makes the value 32-bit sign-extended 362 case RISCV::SRAI: 363 return MI.getOperand(2).getImm() >= 32; 364 case RISCV::SRLI: 365 return MI.getOperand(2).getImm() > 32; 366 // The LI pattern ADDI rd, X0, imm is sign extended. 367 case RISCV::ADDI: 368 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 369 // An ANDI with an 11 bit immediate will zero bits 63:11. 370 case RISCV::ANDI: 371 return isUInt<11>(MI.getOperand(2).getImm()); 372 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 373 case RISCV::ORI: 374 return !isUInt<11>(MI.getOperand(2).getImm()); 375 // A bseti with X0 is sign extended if the immediate is less than 31. 376 case RISCV::BSETI: 377 return MI.getOperand(2).getImm() < 31 && 378 MI.getOperand(1).getReg() == RISCV::X0; 379 // Copying from X0 produces zero. 380 case RISCV::COPY: 381 return MI.getOperand(1).getReg() == RISCV::X0; 382 // Ignore the scratch register destination. 383 case RISCV::PseudoAtomicLoadNand32: 384 return OpNo == 0; 385 case RISCV::PseudoVMV_X_S: { 386 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. 
387 int64_t Log2SEW = MI.getOperand(2).getImm(); 388 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW"); 389 return Log2SEW <= 5; 390 } 391 } 392 393 return false; 394 } 395 396 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, 397 const MachineRegisterInfo &MRI, 398 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 399 SmallSet<Register, 4> Visited; 400 SmallVector<Register, 4> Worklist; 401 402 auto AddRegToWorkList = [&](Register SrcReg) { 403 if (!SrcReg.isVirtual()) 404 return false; 405 Worklist.push_back(SrcReg); 406 return true; 407 }; 408 409 if (!AddRegToWorkList(SrcReg)) 410 return false; 411 412 while (!Worklist.empty()) { 413 Register Reg = Worklist.pop_back_val(); 414 415 // If we already visited this register, we don't need to check it again. 416 if (!Visited.insert(Reg).second) 417 continue; 418 419 MachineInstr *MI = MRI.getVRegDef(Reg); 420 if (!MI) 421 continue; 422 423 int OpNo = MI->findRegisterDefOperandIdx(Reg, /*TRI=*/nullptr); 424 assert(OpNo != -1 && "Couldn't find register"); 425 426 // If this is a sign extending operation we don't need to look any further. 427 if (isSignExtendingOpW(*MI, MRI, OpNo)) 428 continue; 429 430 // Is this an instruction that propagates sign extend? 431 switch (MI->getOpcode()) { 432 default: 433 // Unknown opcode, give up. 434 return false; 435 case RISCV::COPY: { 436 const MachineFunction *MF = MI->getMF(); 437 const RISCVMachineFunctionInfo *RVFI = 438 MF->getInfo<RISCVMachineFunctionInfo>(); 439 440 // If this is the entry block and the register is livein, see if we know 441 // it is sign extended. 442 if (MI->getParent() == &MF->front()) { 443 Register VReg = MI->getOperand(0).getReg(); 444 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) 445 continue; 446 } 447 448 Register CopySrcReg = MI->getOperand(1).getReg(); 449 if (CopySrcReg == RISCV::X10) { 450 // For a method return value, we check the ZExt/SExt flags in attribute. 
451 // We assume the following code sequence for method call. 452 // PseudoCALL @bar, ... 453 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 454 // %0:gpr = COPY $x10 455 // 456 // We use the PseudoCall to look up the IR function being called to find 457 // its return attributes. 458 const MachineBasicBlock *MBB = MI->getParent(); 459 auto II = MI->getIterator(); 460 if (II == MBB->instr_begin() || 461 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) 462 return false; 463 464 const MachineInstr &CallMI = *(--II); 465 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal()) 466 return false; 467 468 auto *CalleeFn = 469 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal()); 470 if (!CalleeFn) 471 return false; 472 473 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); 474 if (!IntTy) 475 return false; 476 477 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); 478 unsigned BitWidth = IntTy->getBitWidth(); 479 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) || 480 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt))) 481 continue; 482 } 483 484 if (!AddRegToWorkList(CopySrcReg)) 485 return false; 486 487 break; 488 } 489 490 // For these, we just need to check if the 1st operand is sign extended. 491 case RISCV::BCLRI: 492 case RISCV::BINVI: 493 case RISCV::BSETI: 494 if (MI->getOperand(2).getImm() >= 31) 495 return false; 496 [[fallthrough]]; 497 case RISCV::REM: 498 case RISCV::ANDI: 499 case RISCV::ORI: 500 case RISCV::XORI: 501 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 502 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 503 // Logical operations use a sign extended 12-bit immediate. 
504 if (!AddRegToWorkList(MI->getOperand(1).getReg())) 505 return false; 506 507 break; 508 case RISCV::PseudoCCADDW: 509 case RISCV::PseudoCCADDIW: 510 case RISCV::PseudoCCSUBW: 511 case RISCV::PseudoCCSLLW: 512 case RISCV::PseudoCCSRLW: 513 case RISCV::PseudoCCSRAW: 514 case RISCV::PseudoCCSLLIW: 515 case RISCV::PseudoCCSRLIW: 516 case RISCV::PseudoCCSRAIW: 517 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only 518 // need to check if operand 4 is sign extended. 519 if (!AddRegToWorkList(MI->getOperand(4).getReg())) 520 return false; 521 break; 522 case RISCV::REMU: 523 case RISCV::AND: 524 case RISCV::OR: 525 case RISCV::XOR: 526 case RISCV::ANDN: 527 case RISCV::ORN: 528 case RISCV::XNOR: 529 case RISCV::MAX: 530 case RISCV::MAXU: 531 case RISCV::MIN: 532 case RISCV::MINU: 533 case RISCV::PseudoCCMOVGPR: 534 case RISCV::PseudoCCAND: 535 case RISCV::PseudoCCOR: 536 case RISCV::PseudoCCXOR: 537 case RISCV::PHI: { 538 // If all incoming values are sign-extended, the output of AND, OR, XOR, 539 // MIN, MAX, or PHI is also sign-extended. 540 541 // The input registers for PHI are operand 1, 3, ... 542 // The input registers for PseudoCCMOVGPR are 4 and 5. 543 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. 544 // The input registers for others are operand 1 and 2. 
545 unsigned B = 1, E = 3, D = 1; 546 switch (MI->getOpcode()) { 547 case RISCV::PHI: 548 E = MI->getNumOperands(); 549 D = 2; 550 break; 551 case RISCV::PseudoCCMOVGPR: 552 B = 4; 553 E = 6; 554 break; 555 case RISCV::PseudoCCAND: 556 case RISCV::PseudoCCOR: 557 case RISCV::PseudoCCXOR: 558 B = 4; 559 E = 7; 560 break; 561 } 562 563 for (unsigned I = B; I != E; I += D) { 564 if (!MI->getOperand(I).isReg()) 565 return false; 566 567 if (!AddRegToWorkList(MI->getOperand(I).getReg())) 568 return false; 569 } 570 571 break; 572 } 573 574 case RISCV::CZERO_EQZ: 575 case RISCV::CZERO_NEZ: 576 case RISCV::VT_MASKC: 577 case RISCV::VT_MASKCN: 578 // Instructions return zero or operand 1. Result is sign extended if 579 // operand 1 is sign extended. 580 if (!AddRegToWorkList(MI->getOperand(1).getReg())) 581 return false; 582 break; 583 584 // With these opcode, we can "fix" them with the W-version 585 // if we know all users of the result only rely on bits 31:0 586 case RISCV::SLLI: 587 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 588 if (MI->getOperand(2).getImm() >= 32) 589 return false; 590 [[fallthrough]]; 591 case RISCV::ADDI: 592 case RISCV::ADD: 593 case RISCV::LD: 594 case RISCV::LWU: 595 case RISCV::MUL: 596 case RISCV::SUB: 597 if (hasAllWUsers(*MI, ST, MRI)) { 598 FixableDef.insert(MI); 599 break; 600 } 601 return false; 602 } 603 } 604 605 // If we get here, then every node we visited produces a sign extended value 606 // or propagated sign extended values. So the result must be sign extended. 
607 return true; 608 } 609 610 static unsigned getWOp(unsigned Opcode) { 611 switch (Opcode) { 612 case RISCV::ADDI: 613 return RISCV::ADDIW; 614 case RISCV::ADD: 615 return RISCV::ADDW; 616 case RISCV::LD: 617 case RISCV::LWU: 618 return RISCV::LW; 619 case RISCV::MUL: 620 return RISCV::MULW; 621 case RISCV::SLLI: 622 return RISCV::SLLIW; 623 case RISCV::SUB: 624 return RISCV::SUBW; 625 default: 626 llvm_unreachable("Unexpected opcode for replacement with W variant"); 627 } 628 } 629 630 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, 631 const RISCVInstrInfo &TII, 632 const RISCVSubtarget &ST, 633 MachineRegisterInfo &MRI) { 634 if (DisableSExtWRemoval) 635 return false; 636 637 bool MadeChange = false; 638 for (MachineBasicBlock &MBB : MF) { 639 for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) { 640 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 641 if (!RISCV::isSEXT_W(MI)) 642 continue; 643 644 Register SrcReg = MI.getOperand(1).getReg(); 645 646 SmallPtrSet<MachineInstr *, 4> FixableDefs; 647 648 // If all users only use the lower bits, this sext.w is redundant. 649 // Or if all definitions reaching MI sign-extend their output, 650 // then sext.w is redundant. 651 if (!hasAllWUsers(MI, ST, MRI) && 652 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) 653 continue; 654 655 Register DstReg = MI.getOperand(0).getReg(); 656 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 657 continue; 658 659 // Convert Fixable instructions to their W versions. 
660 for (MachineInstr *Fixable : FixableDefs) { 661 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 662 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); 663 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); 664 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); 665 Fixable->clearFlag(MachineInstr::MIFlag::IsExact); 666 LLVM_DEBUG(dbgs() << " with " << *Fixable); 667 ++NumTransformedToWInstrs; 668 } 669 670 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 671 MRI.replaceRegWith(DstReg, SrcReg); 672 MRI.clearKillFlags(SrcReg); 673 MI.eraseFromParent(); 674 ++NumRemovedSExtW; 675 MadeChange = true; 676 } 677 } 678 679 return MadeChange; 680 } 681 682 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, 683 const RISCVInstrInfo &TII, 684 const RISCVSubtarget &ST, 685 MachineRegisterInfo &MRI) { 686 bool MadeChange = false; 687 for (MachineBasicBlock &MBB : MF) { 688 for (MachineInstr &MI : MBB) { 689 unsigned Opc; 690 switch (MI.getOpcode()) { 691 default: 692 continue; 693 case RISCV::ADDW: Opc = RISCV::ADD; break; 694 case RISCV::ADDIW: Opc = RISCV::ADDI; break; 695 case RISCV::MULW: Opc = RISCV::MUL; break; 696 case RISCV::SLLIW: Opc = RISCV::SLLI; break; 697 } 698 699 if (hasAllWUsers(MI, ST, MRI)) { 700 MI.setDesc(TII.get(Opc)); 701 MadeChange = true; 702 } 703 } 704 } 705 706 return MadeChange; 707 } 708 709 bool RISCVOptWInstrs::appendWSuffixes(MachineFunction &MF, 710 const RISCVInstrInfo &TII, 711 const RISCVSubtarget &ST, 712 MachineRegisterInfo &MRI) { 713 bool MadeChange = false; 714 for (MachineBasicBlock &MBB : MF) { 715 for (MachineInstr &MI : MBB) { 716 unsigned WOpc; 717 // TODO: Add more? 
718 switch (MI.getOpcode()) { 719 default: 720 continue; 721 case RISCV::ADD: 722 WOpc = RISCV::ADDW; 723 break; 724 case RISCV::ADDI: 725 WOpc = RISCV::ADDIW; 726 break; 727 case RISCV::SUB: 728 WOpc = RISCV::SUBW; 729 break; 730 case RISCV::MUL: 731 WOpc = RISCV::MULW; 732 break; 733 case RISCV::SLLI: 734 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 735 if (MI.getOperand(2).getImm() >= 32) 736 continue; 737 WOpc = RISCV::SLLIW; 738 break; 739 case RISCV::LD: 740 case RISCV::LWU: 741 WOpc = RISCV::LW; 742 break; 743 } 744 745 if (hasAllWUsers(MI, ST, MRI)) { 746 LLVM_DEBUG(dbgs() << "Replacing " << MI); 747 MI.setDesc(TII.get(WOpc)); 748 MI.clearFlag(MachineInstr::MIFlag::NoSWrap); 749 MI.clearFlag(MachineInstr::MIFlag::NoUWrap); 750 MI.clearFlag(MachineInstr::MIFlag::IsExact); 751 LLVM_DEBUG(dbgs() << " with " << MI); 752 ++NumTransformedToWInstrs; 753 MadeChange = true; 754 } 755 } 756 } 757 758 return MadeChange; 759 } 760 761 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { 762 if (skipFunction(MF.getFunction())) 763 return false; 764 765 MachineRegisterInfo &MRI = MF.getRegInfo(); 766 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 767 const RISCVInstrInfo &TII = *ST.getInstrInfo(); 768 769 if (!ST.is64Bit()) 770 return false; 771 772 bool MadeChange = false; 773 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); 774 775 if (!(DisableStripWSuffix || ST.preferWInst())) 776 MadeChange |= stripWSuffixes(MF, TII, ST, MRI); 777 778 if (ST.preferWInst()) 779 MadeChange |= appendWSuffixes(MF, TII, ST, MRI); 780 781 return MadeChange; 782 } 783