1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass does some optimizations for *W instructions at the MI level. 10 // 11 // First it removes unneeded sext.w instructions. Either because the sign 12 // extended bits aren't consumed or because the input was already sign extended 13 // by an earlier instruction. 14 // 15 // Then it removes the -w suffix from opw instructions whenever all users are 16 // dependent only on the lower word of the result of the instruction. 17 // The cases handled are: 18 // * addw because c.add has a larger register encoding than c.addw. 19 // * addiw because it helps reduce test differences between RV32 and RV64 20 // w/o being a pessimization. 21 // * mulw because c.mulw doesn't exist but c.mul does (w/ zcb) 22 // * slliw because c.slliw doesn't exist and c.slli does 23 // 24 //===---------------------------------------------------------------------===// 25 26 #include "RISCV.h" 27 #include "RISCVMachineFunctionInfo.h" 28 #include "RISCVSubtarget.h" 29 #include "llvm/ADT/SmallSet.h" 30 #include "llvm/ADT/Statistic.h" 31 #include "llvm/CodeGen/MachineFunctionPass.h" 32 #include "llvm/CodeGen/TargetInstrInfo.h" 33 34 using namespace llvm; 35 36 #define DEBUG_TYPE "riscv-opt-w-instrs" 37 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions" 38 39 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 40 STATISTIC(NumTransformedToWInstrs, 41 "Number of instructions transformed to W-ops"); 42 43 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 44 cl::desc("Disable removal of sext.w"), 45 cl::init(false), cl::Hidden); 46 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix", 47 cl::desc("Disable strip W suffix"), 48 cl::init(false), cl::Hidden); 49 50 namespace { 51 52 class RISCVOptWInstrs : public MachineFunctionPass { 53 public: 54 static char ID; 55 56 RISCVOptWInstrs() : MachineFunctionPass(ID) {} 57 58 bool runOnMachineFunction(MachineFunction &MF) override; 59 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, 60 const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 61 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, 62 const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 63 64 void getAnalysisUsage(AnalysisUsage &AU) const override { 65 AU.setPreservesCFG(); 66 MachineFunctionPass::getAnalysisUsage(AU); 67 } 68 69 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; } 70 }; 71 72 } // end anonymous namespace 73 74 char RISCVOptWInstrs::ID = 0; 75 INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false, 76 false) 77 78 FunctionPass *llvm::createRISCVOptWInstrsPass() { 79 return new RISCVOptWInstrs(); 80 } 81 82 static bool vectorPseudoHasAllNBitUsers(const MachineOperand &UserOp, 83 unsigned Bits) { 84 const MachineInstr &MI = *UserOp.getParent(); 85 unsigned MCOpcode = RISCV::getRVVMCOpcode(MI.getOpcode()); 86 87 if (!MCOpcode) 88 return false; 89 90 const MCInstrDesc &MCID = MI.getDesc(); 91 const uint64_t TSFlags = MCID.TSFlags; 92 if (!RISCVII::hasSEWOp(TSFlags)) 93 return false; 94 assert(RISCVII::hasVLOp(TSFlags)); 95 const unsigned Log2SEW = MI.getOperand(RISCVII::getSEWOpNum(MCID)).getImm(); 96 97 if (UserOp.getOperandNo() == RISCVII::getVLOpNum(MCID)) 98 return false; 99 100 auto NumDemandedBits = 101 RISCV::getVectorLowDemandedScalarBits(MCOpcode, Log2SEW); 102 return NumDemandedBits && Bits >= *NumDemandedBits; 103 } 104 105 // Checks if all users only demand the lower \p OrigBits of the original 106 // instruction's result. 107 // TODO: handle multiple interdependent transformations 108 static bool hasAllNBitUsers(const MachineInstr &OrigMI, 109 const RISCVSubtarget &ST, 110 const MachineRegisterInfo &MRI, unsigned OrigBits) { 111 112 SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited; 113 SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist; 114 115 Worklist.push_back(std::make_pair(&OrigMI, OrigBits)); 116 117 while (!Worklist.empty()) { 118 auto P = Worklist.pop_back_val(); 119 const MachineInstr *MI = P.first; 120 unsigned Bits = P.second; 121 122 if (!Visited.insert(P).second) 123 continue; 124 125 // Only handle instructions with one def. 126 if (MI->getNumExplicitDefs() != 1) 127 return false; 128 129 Register DestReg = MI->getOperand(0).getReg(); 130 if (!DestReg.isVirtual()) 131 return false; 132 133 for (auto &UserOp : MRI.use_nodbg_operands(DestReg)) { 134 const MachineInstr *UserMI = UserOp.getParent(); 135 unsigned OpIdx = UserOp.getOperandNo(); 136 137 switch (UserMI->getOpcode()) { 138 default: 139 if (vectorPseudoHasAllNBitUsers(UserOp, Bits)) 140 break; 141 return false; 142 143 case RISCV::ADDIW: 144 case RISCV::ADDW: 145 case RISCV::DIVUW: 146 case RISCV::DIVW: 147 case RISCV::MULW: 148 case RISCV::REMUW: 149 case RISCV::REMW: 150 case RISCV::SLLIW: 151 case RISCV::SLLW: 152 case RISCV::SRAIW: 153 case RISCV::SRAW: 154 case RISCV::SRLIW: 155 case RISCV::SRLW: 156 case RISCV::SUBW: 157 case RISCV::ROLW: 158 case RISCV::RORW: 159 case RISCV::RORIW: 160 case RISCV::CLZW: 161 case RISCV::CTZW: 162 case RISCV::CPOPW: 163 case RISCV::SLLI_UW: 164 case RISCV::FMV_W_X: 165 case RISCV::FCVT_H_W: 166 case RISCV::FCVT_H_WU: 167 case RISCV::FCVT_S_W: 168 case RISCV::FCVT_S_WU: 169 case RISCV::FCVT_D_W: 170 case RISCV::FCVT_D_WU: 171 if (Bits >= 32) 172 break; 173 return false; 174 case RISCV::SEXT_B: 175 case RISCV::PACKH: 176 if (Bits >= 8) 177 break; 178 return false; 179 case RISCV::SEXT_H: 180 case RISCV::FMV_H_X: 181 case RISCV::ZEXT_H_RV32: 182 case RISCV::ZEXT_H_RV64: 183 case RISCV::PACKW: 184 if (Bits >= 16) 185 break; 186 return false; 187 188 case RISCV::PACK: 189 if (Bits >= (ST.getXLen() / 2)) 190 break; 191 return false; 192 193 case RISCV::SRLI: { 194 // If we are shifting right by less than Bits, and users don't demand 195 // any bits that were shifted into [Bits-1:0], then we can consider this 196 // as an N-Bit user. 197 unsigned ShAmt = UserMI->getOperand(2).getImm(); 198 if (Bits > ShAmt) { 199 Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt)); 200 break; 201 } 202 return false; 203 } 204 205 // these overwrite higher input bits, otherwise the lower word of output 206 // depends only on the lower word of input. So check their uses read W. 207 case RISCV::SLLI: 208 if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm())) 209 break; 210 Worklist.push_back(std::make_pair(UserMI, Bits)); 211 break; 212 case RISCV::ANDI: { 213 uint64_t Imm = UserMI->getOperand(2).getImm(); 214 if (Bits >= (unsigned)llvm::bit_width(Imm)) 215 break; 216 Worklist.push_back(std::make_pair(UserMI, Bits)); 217 break; 218 } 219 case RISCV::ORI: { 220 uint64_t Imm = UserMI->getOperand(2).getImm(); 221 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm)) 222 break; 223 Worklist.push_back(std::make_pair(UserMI, Bits)); 224 break; 225 } 226 227 case RISCV::SLL: 228 case RISCV::BSET: 229 case RISCV::BCLR: 230 case RISCV::BINV: 231 // Operand 2 is the shift amount which uses log2(xlen) bits. 232 if (OpIdx == 2) { 233 if (Bits >= Log2_32(ST.getXLen())) 234 break; 235 return false; 236 } 237 Worklist.push_back(std::make_pair(UserMI, Bits)); 238 break; 239 240 case RISCV::SRA: 241 case RISCV::SRL: 242 case RISCV::ROL: 243 case RISCV::ROR: 244 // Operand 2 is the shift amount which uses 6 bits. 245 if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen())) 246 break; 247 return false; 248 249 case RISCV::ADD_UW: 250 case RISCV::SH1ADD_UW: 251 case RISCV::SH2ADD_UW: 252 case RISCV::SH3ADD_UW: 253 // Operand 1 is implicitly zero extended. 254 if (OpIdx == 1 && Bits >= 32) 255 break; 256 Worklist.push_back(std::make_pair(UserMI, Bits)); 257 break; 258 259 case RISCV::BEXTI: 260 if (UserMI->getOperand(2).getImm() >= Bits) 261 return false; 262 break; 263 264 case RISCV::SB: 265 // The first argument is the value to store. 266 if (OpIdx == 0 && Bits >= 8) 267 break; 268 return false; 269 case RISCV::SH: 270 // The first argument is the value to store. 271 if (OpIdx == 0 && Bits >= 16) 272 break; 273 return false; 274 case RISCV::SW: 275 // The first argument is the value to store. 276 if (OpIdx == 0 && Bits >= 32) 277 break; 278 return false; 279 280 // For these, lower word of output in these operations, depends only on 281 // the lower word of input. So, we check all uses only read lower word. 282 case RISCV::COPY: 283 case RISCV::PHI: 284 285 case RISCV::ADD: 286 case RISCV::ADDI: 287 case RISCV::AND: 288 case RISCV::MUL: 289 case RISCV::OR: 290 case RISCV::SUB: 291 case RISCV::XOR: 292 case RISCV::XORI: 293 294 case RISCV::ANDN: 295 case RISCV::BREV8: 296 case RISCV::CLMUL: 297 case RISCV::ORC_B: 298 case RISCV::ORN: 299 case RISCV::SH1ADD: 300 case RISCV::SH2ADD: 301 case RISCV::SH3ADD: 302 case RISCV::XNOR: 303 case RISCV::BSETI: 304 case RISCV::BCLRI: 305 case RISCV::BINVI: 306 Worklist.push_back(std::make_pair(UserMI, Bits)); 307 break; 308 309 case RISCV::PseudoCCMOVGPR: 310 // Either operand 4 or operand 5 is returned by this instruction. If 311 // only the lower word of the result is used, then only the lower word 312 // of operand 4 and 5 is used. 313 if (OpIdx != 4 && OpIdx != 5) 314 return false; 315 Worklist.push_back(std::make_pair(UserMI, Bits)); 316 break; 317 318 case RISCV::CZERO_EQZ: 319 case RISCV::CZERO_NEZ: 320 case RISCV::VT_MASKC: 321 case RISCV::VT_MASKCN: 322 if (OpIdx != 1) 323 return false; 324 Worklist.push_back(std::make_pair(UserMI, Bits)); 325 break; 326 } 327 } 328 } 329 330 return true; 331 } 332 333 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, 334 const MachineRegisterInfo &MRI) { 335 return hasAllNBitUsers(OrigMI, ST, MRI, 32); 336 } 337 338 // This function returns true if the machine instruction always outputs a value 339 // where bits 63:32 match bit 31. 340 static bool isSignExtendingOpW(const MachineInstr &MI, 341 const MachineRegisterInfo &MRI) { 342 uint64_t TSFlags = MI.getDesc().TSFlags; 343 344 // Instructions that can be determined from opcode are marked in tablegen. 345 if (TSFlags & RISCVII::IsSignExtendingOpWMask) 346 return true; 347 348 // Special cases that require checking operands. 349 switch (MI.getOpcode()) { 350 // shifting right sufficiently makes the value 32-bit sign-extended 351 case RISCV::SRAI: 352 return MI.getOperand(2).getImm() >= 32; 353 case RISCV::SRLI: 354 return MI.getOperand(2).getImm() > 32; 355 // The LI pattern ADDI rd, X0, imm is sign extended. 356 case RISCV::ADDI: 357 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 358 // An ANDI with an 11 bit immediate will zero bits 63:11. 359 case RISCV::ANDI: 360 return isUInt<11>(MI.getOperand(2).getImm()); 361 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 362 case RISCV::ORI: 363 return !isUInt<11>(MI.getOperand(2).getImm()); 364 // A bseti with X0 is sign extended if the immediate is less than 31. 365 case RISCV::BSETI: 366 return MI.getOperand(2).getImm() < 31 && 367 MI.getOperand(1).getReg() == RISCV::X0; 368 // Copying from X0 produces zero. 369 case RISCV::COPY: 370 return MI.getOperand(1).getReg() == RISCV::X0; 371 case RISCV::PseudoAtomicLoadNand32: 372 return true; 373 case RISCV::PseudoVMV_X_S_MF8: 374 case RISCV::PseudoVMV_X_S_MF4: 375 case RISCV::PseudoVMV_X_S_MF2: 376 case RISCV::PseudoVMV_X_S_M1: 377 case RISCV::PseudoVMV_X_S_M2: 378 case RISCV::PseudoVMV_X_S_M4: 379 case RISCV::PseudoVMV_X_S_M8: { 380 // vmv.x.s has at least 33 sign bits if log2(sew) <= 5. 381 int64_t Log2SEW = MI.getOperand(2).getImm(); 382 assert(Log2SEW >= 3 && Log2SEW <= 6 && "Unexpected Log2SEW"); 383 return Log2SEW <= 5; 384 } 385 } 386 387 return false; 388 } 389 390 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, 391 const MachineRegisterInfo &MRI, 392 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 393 394 SmallPtrSet<const MachineInstr *, 4> Visited; 395 SmallVector<MachineInstr *, 4> Worklist; 396 397 auto AddRegDefToWorkList = [&](Register SrcReg) { 398 if (!SrcReg.isVirtual()) 399 return false; 400 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 401 if (!SrcMI) 402 return false; 403 // Code assumes the register is operand 0. 404 // TODO: Maybe the worklist should store register? 405 if (!SrcMI->getOperand(0).isReg() || 406 SrcMI->getOperand(0).getReg() != SrcReg) 407 return false; 408 // Add SrcMI to the worklist. 409 Worklist.push_back(SrcMI); 410 return true; 411 }; 412 413 if (!AddRegDefToWorkList(SrcReg)) 414 return false; 415 416 while (!Worklist.empty()) { 417 MachineInstr *MI = Worklist.pop_back_val(); 418 419 // If we already visited this instruction, we don't need to check it again. 420 if (!Visited.insert(MI).second) 421 continue; 422 423 // If this is a sign extending operation we don't need to look any further. 424 if (isSignExtendingOpW(*MI, MRI)) 425 continue; 426 427 // Is this an instruction that propagates sign extend? 428 switch (MI->getOpcode()) { 429 default: 430 // Unknown opcode, give up. 431 return false; 432 case RISCV::COPY: { 433 const MachineFunction *MF = MI->getMF(); 434 const RISCVMachineFunctionInfo *RVFI = 435 MF->getInfo<RISCVMachineFunctionInfo>(); 436 437 // If this is the entry block and the register is livein, see if we know 438 // it is sign extended. 439 if (MI->getParent() == &MF->front()) { 440 Register VReg = MI->getOperand(0).getReg(); 441 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) 442 continue; 443 } 444 445 Register CopySrcReg = MI->getOperand(1).getReg(); 446 if (CopySrcReg == RISCV::X10) { 447 // For a method return value, we check the ZExt/SExt flags in attribute. 448 // We assume the following code sequence for method call. 449 // PseudoCALL @bar, ... 450 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 451 // %0:gpr = COPY $x10 452 // 453 // We use the PseudoCall to look up the IR function being called to find 454 // its return attributes. 455 const MachineBasicBlock *MBB = MI->getParent(); 456 auto II = MI->getIterator(); 457 if (II == MBB->instr_begin() || 458 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) 459 return false; 460 461 const MachineInstr &CallMI = *(--II); 462 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal()) 463 return false; 464 465 auto *CalleeFn = 466 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal()); 467 if (!CalleeFn) 468 return false; 469 470 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); 471 if (!IntTy) 472 return false; 473 474 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); 475 unsigned BitWidth = IntTy->getBitWidth(); 476 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) || 477 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt))) 478 continue; 479 } 480 481 if (!AddRegDefToWorkList(CopySrcReg)) 482 return false; 483 484 break; 485 } 486 487 // For these, we just need to check if the 1st operand is sign extended. 488 case RISCV::BCLRI: 489 case RISCV::BINVI: 490 case RISCV::BSETI: 491 if (MI->getOperand(2).getImm() >= 31) 492 return false; 493 [[fallthrough]]; 494 case RISCV::REM: 495 case RISCV::ANDI: 496 case RISCV::ORI: 497 case RISCV::XORI: 498 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 499 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 500 // Logical operations use a sign extended 12-bit immediate. 501 if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 502 return false; 503 504 break; 505 case RISCV::PseudoCCADDW: 506 case RISCV::PseudoCCADDIW: 507 case RISCV::PseudoCCSUBW: 508 case RISCV::PseudoCCSLLW: 509 case RISCV::PseudoCCSRLW: 510 case RISCV::PseudoCCSRAW: 511 case RISCV::PseudoCCSLLIW: 512 case RISCV::PseudoCCSRLIW: 513 case RISCV::PseudoCCSRAIW: 514 // Returns operand 4 or an ADDW/SUBW/etc. of operands 5 and 6. We only 515 // need to check if operand 4 is sign extended. 516 if (!AddRegDefToWorkList(MI->getOperand(4).getReg())) 517 return false; 518 break; 519 case RISCV::REMU: 520 case RISCV::AND: 521 case RISCV::OR: 522 case RISCV::XOR: 523 case RISCV::ANDN: 524 case RISCV::ORN: 525 case RISCV::XNOR: 526 case RISCV::MAX: 527 case RISCV::MAXU: 528 case RISCV::MIN: 529 case RISCV::MINU: 530 case RISCV::PseudoCCMOVGPR: 531 case RISCV::PseudoCCAND: 532 case RISCV::PseudoCCOR: 533 case RISCV::PseudoCCXOR: 534 case RISCV::PHI: { 535 // If all incoming values are sign-extended, the output of AND, OR, XOR, 536 // MIN, MAX, or PHI is also sign-extended. 537 538 // The input registers for PHI are operand 1, 3, ... 539 // The input registers for PseudoCCMOVGPR are 4 and 5. 540 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. 541 // The input registers for others are operand 1 and 2. 542 unsigned B = 1, E = 3, D = 1; 543 switch (MI->getOpcode()) { 544 case RISCV::PHI: 545 E = MI->getNumOperands(); 546 D = 2; 547 break; 548 case RISCV::PseudoCCMOVGPR: 549 B = 4; 550 E = 6; 551 break; 552 case RISCV::PseudoCCAND: 553 case RISCV::PseudoCCOR: 554 case RISCV::PseudoCCXOR: 555 B = 4; 556 E = 7; 557 break; 558 } 559 560 for (unsigned I = B; I != E; I += D) { 561 if (!MI->getOperand(I).isReg()) 562 return false; 563 564 if (!AddRegDefToWorkList(MI->getOperand(I).getReg())) 565 return false; 566 } 567 568 break; 569 } 570 571 case RISCV::CZERO_EQZ: 572 case RISCV::CZERO_NEZ: 573 case RISCV::VT_MASKC: 574 case RISCV::VT_MASKCN: 575 // Instructions return zero or operand 1. Result is sign extended if 576 // operand 1 is sign extended. 577 if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 578 return false; 579 break; 580 581 // With these opcode, we can "fix" them with the W-version 582 // if we know all users of the result only rely on bits 31:0 583 case RISCV::SLLI: 584 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 585 if (MI->getOperand(2).getImm() >= 32) 586 return false; 587 [[fallthrough]]; 588 case RISCV::ADDI: 589 case RISCV::ADD: 590 case RISCV::LD: 591 case RISCV::LWU: 592 case RISCV::MUL: 593 case RISCV::SUB: 594 if (hasAllWUsers(*MI, ST, MRI)) { 595 FixableDef.insert(MI); 596 break; 597 } 598 return false; 599 } 600 } 601 602 // If we get here, then every node we visited produces a sign extended value 603 // or propagated sign extended values. So the result must be sign extended. 604 return true; 605 } 606 607 static unsigned getWOp(unsigned Opcode) { 608 switch (Opcode) { 609 case RISCV::ADDI: 610 return RISCV::ADDIW; 611 case RISCV::ADD: 612 return RISCV::ADDW; 613 case RISCV::LD: 614 case RISCV::LWU: 615 return RISCV::LW; 616 case RISCV::MUL: 617 return RISCV::MULW; 618 case RISCV::SLLI: 619 return RISCV::SLLIW; 620 case RISCV::SUB: 621 return RISCV::SUBW; 622 default: 623 llvm_unreachable("Unexpected opcode for replacement with W variant"); 624 } 625 } 626 627 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, 628 const RISCVInstrInfo &TII, 629 const RISCVSubtarget &ST, 630 MachineRegisterInfo &MRI) { 631 if (DisableSExtWRemoval) 632 return false; 633 634 bool MadeChange = false; 635 for (MachineBasicBlock &MBB : MF) { 636 for (MachineInstr &MI : llvm::make_early_inc_range(MBB)) { 637 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 638 if (!RISCV::isSEXT_W(MI)) 639 continue; 640 641 Register SrcReg = MI.getOperand(1).getReg(); 642 643 SmallPtrSet<MachineInstr *, 4> FixableDefs; 644 645 // If all users only use the lower bits, this sext.w is redundant. 646 // Or if all definitions reaching MI sign-extend their output, 647 // then sext.w is redundant. 648 if (!hasAllWUsers(MI, ST, MRI) && 649 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) 650 continue; 651 652 Register DstReg = MI.getOperand(0).getReg(); 653 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 654 continue; 655 656 // Convert Fixable instructions to their W versions. 657 for (MachineInstr *Fixable : FixableDefs) { 658 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 659 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); 660 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); 661 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); 662 Fixable->clearFlag(MachineInstr::MIFlag::IsExact); 663 LLVM_DEBUG(dbgs() << " with " << *Fixable); 664 ++NumTransformedToWInstrs; 665 } 666 667 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 668 MRI.replaceRegWith(DstReg, SrcReg); 669 MRI.clearKillFlags(SrcReg); 670 MI.eraseFromParent(); 671 ++NumRemovedSExtW; 672 MadeChange = true; 673 } 674 } 675 676 return MadeChange; 677 } 678 679 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, 680 const RISCVInstrInfo &TII, 681 const RISCVSubtarget &ST, 682 MachineRegisterInfo &MRI) { 683 if (DisableStripWSuffix) 684 return false; 685 686 bool MadeChange = false; 687 for (MachineBasicBlock &MBB : MF) { 688 for (MachineInstr &MI : MBB) { 689 unsigned Opc; 690 switch (MI.getOpcode()) { 691 default: 692 continue; 693 case RISCV::ADDW: Opc = RISCV::ADD; break; 694 case RISCV::ADDIW: Opc = RISCV::ADDI; break; 695 case RISCV::MULW: Opc = RISCV::MUL; break; 696 case RISCV::SLLIW: Opc = RISCV::SLLI; break; 697 } 698 699 if (hasAllWUsers(MI, ST, MRI)) { 700 MI.setDesc(TII.get(Opc)); 701 MadeChange = true; 702 } 703 } 704 } 705 706 return MadeChange; 707 } 708 709 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { 710 if (skipFunction(MF.getFunction())) 711 return false; 712 713 MachineRegisterInfo &MRI = MF.getRegInfo(); 714 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 715 const RISCVInstrInfo &TII = *ST.getInstrInfo(); 716 717 if (!ST.is64Bit()) 718 return false; 719 720 bool MadeChange = false; 721 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); 722 MadeChange |= stripWSuffixes(MF, TII, ST, MRI); 723 724 return MadeChange; 725 } 726