1 //===- RISCVOptWInstrs.cpp - MI W instruction optimizations ---------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===---------------------------------------------------------------------===// 8 // 9 // This pass does some optimizations for *W instructions at the MI level. 10 // 11 // First it removes unneeded sext.w instructions. Either because the sign 12 // extended bits aren't consumed or because the input was already sign extended 13 // by an earlier instruction. 14 // 15 // Then it removes the -w suffix from each addiw and slliw instructions 16 // whenever all users are dependent only on the lower word of the result of the 17 // instruction. We do this only for addiw, slliw, and mulw because the -w forms 18 // are less compressible. 19 // 20 //===---------------------------------------------------------------------===// 21 22 #include "RISCV.h" 23 #include "RISCVMachineFunctionInfo.h" 24 #include "RISCVSubtarget.h" 25 #include "llvm/ADT/Statistic.h" 26 #include "llvm/CodeGen/MachineFunctionPass.h" 27 #include "llvm/CodeGen/TargetInstrInfo.h" 28 29 using namespace llvm; 30 31 #define DEBUG_TYPE "riscv-opt-w-instrs" 32 #define RISCV_OPT_W_INSTRS_NAME "RISC-V Optimize W Instructions" 33 34 STATISTIC(NumRemovedSExtW, "Number of removed sign-extensions"); 35 STATISTIC(NumTransformedToWInstrs, 36 "Number of instructions transformed to W-ops"); 37 38 static cl::opt<bool> DisableSExtWRemoval("riscv-disable-sextw-removal", 39 cl::desc("Disable removal of sext.w"), 40 cl::init(false), cl::Hidden); 41 static cl::opt<bool> DisableStripWSuffix("riscv-disable-strip-w-suffix", 42 cl::desc("Disable strip W suffix"), 43 cl::init(false), cl::Hidden); 44 45 namespace { 46 47 class RISCVOptWInstrs : public MachineFunctionPass { 48 public: 49 static char ID; 50 51 RISCVOptWInstrs() : MachineFunctionPass(ID) { 52 initializeRISCVOptWInstrsPass(*PassRegistry::getPassRegistry()); 53 } 54 55 bool runOnMachineFunction(MachineFunction &MF) override; 56 bool removeSExtWInstrs(MachineFunction &MF, const RISCVInstrInfo &TII, 57 const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 58 bool stripWSuffixes(MachineFunction &MF, const RISCVInstrInfo &TII, 59 const RISCVSubtarget &ST, MachineRegisterInfo &MRI); 60 61 void getAnalysisUsage(AnalysisUsage &AU) const override { 62 AU.setPreservesCFG(); 63 MachineFunctionPass::getAnalysisUsage(AU); 64 } 65 66 StringRef getPassName() const override { return RISCV_OPT_W_INSTRS_NAME; } 67 }; 68 69 } // end anonymous namespace 70 71 char RISCVOptWInstrs::ID = 0; 72 INITIALIZE_PASS(RISCVOptWInstrs, DEBUG_TYPE, RISCV_OPT_W_INSTRS_NAME, false, 73 false) 74 75 FunctionPass *llvm::createRISCVOptWInstrsPass() { 76 return new RISCVOptWInstrs(); 77 } 78 79 // Checks if all users only demand the lower \p OrigBits of the original 80 // instruction's result. 81 // TODO: handle multiple interdependent transformations 82 static bool hasAllNBitUsers(const MachineInstr &OrigMI, 83 const RISCVSubtarget &ST, 84 const MachineRegisterInfo &MRI, unsigned OrigBits) { 85 86 SmallSet<std::pair<const MachineInstr *, unsigned>, 4> Visited; 87 SmallVector<std::pair<const MachineInstr *, unsigned>, 4> Worklist; 88 89 Worklist.push_back(std::make_pair(&OrigMI, OrigBits)); 90 91 while (!Worklist.empty()) { 92 auto P = Worklist.pop_back_val(); 93 const MachineInstr *MI = P.first; 94 unsigned Bits = P.second; 95 96 if (!Visited.insert(P).second) 97 continue; 98 99 // Only handle instructions with one def. 100 if (MI->getNumExplicitDefs() != 1) 101 return false; 102 103 for (auto &UserOp : MRI.use_operands(MI->getOperand(0).getReg())) { 104 const MachineInstr *UserMI = UserOp.getParent(); 105 unsigned OpIdx = UserOp.getOperandNo(); 106 107 switch (UserMI->getOpcode()) { 108 default: 109 return false; 110 111 case RISCV::ADDIW: 112 case RISCV::ADDW: 113 case RISCV::DIVUW: 114 case RISCV::DIVW: 115 case RISCV::MULW: 116 case RISCV::REMUW: 117 case RISCV::REMW: 118 case RISCV::SLLIW: 119 case RISCV::SLLW: 120 case RISCV::SRAIW: 121 case RISCV::SRAW: 122 case RISCV::SRLIW: 123 case RISCV::SRLW: 124 case RISCV::SUBW: 125 case RISCV::ROLW: 126 case RISCV::RORW: 127 case RISCV::RORIW: 128 case RISCV::CLZW: 129 case RISCV::CTZW: 130 case RISCV::CPOPW: 131 case RISCV::SLLI_UW: 132 case RISCV::FMV_W_X: 133 case RISCV::FCVT_H_W: 134 case RISCV::FCVT_H_WU: 135 case RISCV::FCVT_S_W: 136 case RISCV::FCVT_S_WU: 137 case RISCV::FCVT_D_W: 138 case RISCV::FCVT_D_WU: 139 if (Bits >= 32) 140 break; 141 return false; 142 case RISCV::SEXT_B: 143 case RISCV::PACKH: 144 if (Bits >= 8) 145 break; 146 return false; 147 case RISCV::SEXT_H: 148 case RISCV::FMV_H_X: 149 case RISCV::ZEXT_H_RV32: 150 case RISCV::ZEXT_H_RV64: 151 case RISCV::PACKW: 152 if (Bits >= 16) 153 break; 154 return false; 155 156 case RISCV::PACK: 157 if (Bits >= (ST.getXLen() / 2)) 158 break; 159 return false; 160 161 case RISCV::SRLI: { 162 // If we are shifting right by less than Bits, and users don't demand 163 // any bits that were shifted into [Bits-1:0], then we can consider this 164 // as an N-Bit user. 165 unsigned ShAmt = UserMI->getOperand(2).getImm(); 166 if (Bits > ShAmt) { 167 Worklist.push_back(std::make_pair(UserMI, Bits - ShAmt)); 168 break; 169 } 170 return false; 171 } 172 173 // these overwrite higher input bits, otherwise the lower word of output 174 // depends only on the lower word of input. So check their uses read W. 175 case RISCV::SLLI: 176 if (Bits >= (ST.getXLen() - UserMI->getOperand(2).getImm())) 177 break; 178 Worklist.push_back(std::make_pair(UserMI, Bits)); 179 break; 180 case RISCV::ANDI: { 181 uint64_t Imm = UserMI->getOperand(2).getImm(); 182 if (Bits >= (unsigned)llvm::bit_width(Imm)) 183 break; 184 Worklist.push_back(std::make_pair(UserMI, Bits)); 185 break; 186 } 187 case RISCV::ORI: { 188 uint64_t Imm = UserMI->getOperand(2).getImm(); 189 if (Bits >= (unsigned)llvm::bit_width<uint64_t>(~Imm)) 190 break; 191 Worklist.push_back(std::make_pair(UserMI, Bits)); 192 break; 193 } 194 195 case RISCV::SLL: 196 case RISCV::BSET: 197 case RISCV::BCLR: 198 case RISCV::BINV: 199 // Operand 2 is the shift amount which uses log2(xlen) bits. 200 if (OpIdx == 2) { 201 if (Bits >= Log2_32(ST.getXLen())) 202 break; 203 return false; 204 } 205 Worklist.push_back(std::make_pair(UserMI, Bits)); 206 break; 207 208 case RISCV::SRA: 209 case RISCV::SRL: 210 case RISCV::ROL: 211 case RISCV::ROR: 212 // Operand 2 is the shift amount which uses 6 bits. 213 if (OpIdx == 2 && Bits >= Log2_32(ST.getXLen())) 214 break; 215 return false; 216 217 case RISCV::ADD_UW: 218 case RISCV::SH1ADD_UW: 219 case RISCV::SH2ADD_UW: 220 case RISCV::SH3ADD_UW: 221 // Operand 1 is implicitly zero extended. 222 if (OpIdx == 1 && Bits >= 32) 223 break; 224 Worklist.push_back(std::make_pair(UserMI, Bits)); 225 break; 226 227 case RISCV::BEXTI: 228 if (UserMI->getOperand(2).getImm() >= Bits) 229 return false; 230 break; 231 232 case RISCV::SB: 233 // The first argument is the value to store. 234 if (OpIdx == 0 && Bits >= 8) 235 break; 236 return false; 237 case RISCV::SH: 238 // The first argument is the value to store. 239 if (OpIdx == 0 && Bits >= 16) 240 break; 241 return false; 242 case RISCV::SW: 243 // The first argument is the value to store. 244 if (OpIdx == 0 && Bits >= 32) 245 break; 246 return false; 247 248 // For these, lower word of output in these operations, depends only on 249 // the lower word of input. So, we check all uses only read lower word. 250 case RISCV::COPY: 251 case RISCV::PHI: 252 253 case RISCV::ADD: 254 case RISCV::ADDI: 255 case RISCV::AND: 256 case RISCV::MUL: 257 case RISCV::OR: 258 case RISCV::SUB: 259 case RISCV::XOR: 260 case RISCV::XORI: 261 262 case RISCV::ANDN: 263 case RISCV::BREV8: 264 case RISCV::CLMUL: 265 case RISCV::ORC_B: 266 case RISCV::ORN: 267 case RISCV::SH1ADD: 268 case RISCV::SH2ADD: 269 case RISCV::SH3ADD: 270 case RISCV::XNOR: 271 case RISCV::BSETI: 272 case RISCV::BCLRI: 273 case RISCV::BINVI: 274 Worklist.push_back(std::make_pair(UserMI, Bits)); 275 break; 276 277 case RISCV::PseudoCCMOVGPR: 278 // Either operand 4 or operand 5 is returned by this instruction. If 279 // only the lower word of the result is used, then only the lower word 280 // of operand 4 and 5 is used. 281 if (OpIdx != 4 && OpIdx != 5) 282 return false; 283 Worklist.push_back(std::make_pair(UserMI, Bits)); 284 break; 285 286 case RISCV::VT_MASKC: 287 case RISCV::VT_MASKCN: 288 if (OpIdx != 1) 289 return false; 290 Worklist.push_back(std::make_pair(UserMI, Bits)); 291 break; 292 } 293 } 294 } 295 296 return true; 297 } 298 299 static bool hasAllWUsers(const MachineInstr &OrigMI, const RISCVSubtarget &ST, 300 const MachineRegisterInfo &MRI) { 301 return hasAllNBitUsers(OrigMI, ST, MRI, 32); 302 } 303 304 // This function returns true if the machine instruction always outputs a value 305 // where bits 63:32 match bit 31. 306 static bool isSignExtendingOpW(const MachineInstr &MI, 307 const MachineRegisterInfo &MRI) { 308 uint64_t TSFlags = MI.getDesc().TSFlags; 309 310 // Instructions that can be determined from opcode are marked in tablegen. 311 if (TSFlags & RISCVII::IsSignExtendingOpWMask) 312 return true; 313 314 // Special cases that require checking operands. 315 switch (MI.getOpcode()) { 316 // shifting right sufficiently makes the value 32-bit sign-extended 317 case RISCV::SRAI: 318 return MI.getOperand(2).getImm() >= 32; 319 case RISCV::SRLI: 320 return MI.getOperand(2).getImm() > 32; 321 // The LI pattern ADDI rd, X0, imm is sign extended. 322 case RISCV::ADDI: 323 return MI.getOperand(1).isReg() && MI.getOperand(1).getReg() == RISCV::X0; 324 // An ANDI with an 11 bit immediate will zero bits 63:11. 325 case RISCV::ANDI: 326 return isUInt<11>(MI.getOperand(2).getImm()); 327 // An ORI with an >11 bit immediate (negative 12-bit) will set bits 63:11. 328 case RISCV::ORI: 329 return !isUInt<11>(MI.getOperand(2).getImm()); 330 // Copying from X0 produces zero. 331 case RISCV::COPY: 332 return MI.getOperand(1).getReg() == RISCV::X0; 333 } 334 335 return false; 336 } 337 338 static bool isSignExtendedW(Register SrcReg, const RISCVSubtarget &ST, 339 const MachineRegisterInfo &MRI, 340 SmallPtrSetImpl<MachineInstr *> &FixableDef) { 341 342 SmallPtrSet<const MachineInstr *, 4> Visited; 343 SmallVector<MachineInstr *, 4> Worklist; 344 345 auto AddRegDefToWorkList = [&](Register SrcReg) { 346 if (!SrcReg.isVirtual()) 347 return false; 348 MachineInstr *SrcMI = MRI.getVRegDef(SrcReg); 349 if (!SrcMI) 350 return false; 351 // Add SrcMI to the worklist. 352 Worklist.push_back(SrcMI); 353 return true; 354 }; 355 356 if (!AddRegDefToWorkList(SrcReg)) 357 return false; 358 359 while (!Worklist.empty()) { 360 MachineInstr *MI = Worklist.pop_back_val(); 361 362 // If we already visited this instruction, we don't need to check it again. 363 if (!Visited.insert(MI).second) 364 continue; 365 366 // If this is a sign extending operation we don't need to look any further. 367 if (isSignExtendingOpW(*MI, MRI)) 368 continue; 369 370 // Is this an instruction that propagates sign extend? 371 switch (MI->getOpcode()) { 372 default: 373 // Unknown opcode, give up. 374 return false; 375 case RISCV::COPY: { 376 const MachineFunction *MF = MI->getMF(); 377 const RISCVMachineFunctionInfo *RVFI = 378 MF->getInfo<RISCVMachineFunctionInfo>(); 379 380 // If this is the entry block and the register is livein, see if we know 381 // it is sign extended. 382 if (MI->getParent() == &MF->front()) { 383 Register VReg = MI->getOperand(0).getReg(); 384 if (MF->getRegInfo().isLiveIn(VReg) && RVFI->isSExt32Register(VReg)) 385 continue; 386 } 387 388 Register CopySrcReg = MI->getOperand(1).getReg(); 389 if (CopySrcReg == RISCV::X10) { 390 // For a method return value, we check the ZExt/SExt flags in attribute. 391 // We assume the following code sequence for method call. 392 // PseudoCALL @bar, ... 393 // ADJCALLSTACKUP 0, 0, implicit-def dead $x2, implicit $x2 394 // %0:gpr = COPY $x10 395 // 396 // We use the PseudoCall to look up the IR function being called to find 397 // its return attributes. 398 const MachineBasicBlock *MBB = MI->getParent(); 399 auto II = MI->getIterator(); 400 if (II == MBB->instr_begin() || 401 (--II)->getOpcode() != RISCV::ADJCALLSTACKUP) 402 return false; 403 404 const MachineInstr &CallMI = *(--II); 405 if (!CallMI.isCall() || !CallMI.getOperand(0).isGlobal()) 406 return false; 407 408 auto *CalleeFn = 409 dyn_cast_if_present<Function>(CallMI.getOperand(0).getGlobal()); 410 if (!CalleeFn) 411 return false; 412 413 auto *IntTy = dyn_cast<IntegerType>(CalleeFn->getReturnType()); 414 if (!IntTy) 415 return false; 416 417 const AttributeSet &Attrs = CalleeFn->getAttributes().getRetAttrs(); 418 unsigned BitWidth = IntTy->getBitWidth(); 419 if ((BitWidth <= 32 && Attrs.hasAttribute(Attribute::SExt)) || 420 (BitWidth < 32 && Attrs.hasAttribute(Attribute::ZExt))) 421 continue; 422 } 423 424 if (!AddRegDefToWorkList(CopySrcReg)) 425 return false; 426 427 break; 428 } 429 430 // For these, we just need to check if the 1st operand is sign extended. 431 case RISCV::BCLRI: 432 case RISCV::BINVI: 433 case RISCV::BSETI: 434 if (MI->getOperand(2).getImm() >= 31) 435 return false; 436 [[fallthrough]]; 437 case RISCV::REM: 438 case RISCV::ANDI: 439 case RISCV::ORI: 440 case RISCV::XORI: 441 // |Remainder| is always <= |Dividend|. If D is 32-bit, then so is R. 442 // DIV doesn't work because of the edge case 0xf..f 8000 0000 / (long)-1 443 // Logical operations use a sign extended 12-bit immediate. 444 if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 445 return false; 446 447 break; 448 case RISCV::PseudoCCADDW: 449 case RISCV::PseudoCCSUBW: 450 // Returns operand 4 or an ADDW/SUBW of operands 5 and 6. We only need to 451 // check if operand 4 is sign extended. 452 if (!AddRegDefToWorkList(MI->getOperand(4).getReg())) 453 return false; 454 break; 455 case RISCV::REMU: 456 case RISCV::AND: 457 case RISCV::OR: 458 case RISCV::XOR: 459 case RISCV::ANDN: 460 case RISCV::ORN: 461 case RISCV::XNOR: 462 case RISCV::MAX: 463 case RISCV::MAXU: 464 case RISCV::MIN: 465 case RISCV::MINU: 466 case RISCV::PseudoCCMOVGPR: 467 case RISCV::PseudoCCAND: 468 case RISCV::PseudoCCOR: 469 case RISCV::PseudoCCXOR: 470 case RISCV::PHI: { 471 // If all incoming values are sign-extended, the output of AND, OR, XOR, 472 // MIN, MAX, or PHI is also sign-extended. 473 474 // The input registers for PHI are operand 1, 3, ... 475 // The input registers for PseudoCCMOVGPR are 4 and 5. 476 // The input registers for PseudoCCAND/OR/XOR are 4, 5, and 6. 477 // The input registers for others are operand 1 and 2. 478 unsigned B = 1, E = 3, D = 1; 479 switch (MI->getOpcode()) { 480 case RISCV::PHI: 481 E = MI->getNumOperands(); 482 D = 2; 483 break; 484 case RISCV::PseudoCCMOVGPR: 485 B = 4; 486 E = 6; 487 break; 488 case RISCV::PseudoCCAND: 489 case RISCV::PseudoCCOR: 490 case RISCV::PseudoCCXOR: 491 B = 4; 492 E = 7; 493 break; 494 } 495 496 for (unsigned I = B; I != E; I += D) { 497 if (!MI->getOperand(I).isReg()) 498 return false; 499 500 if (!AddRegDefToWorkList(MI->getOperand(I).getReg())) 501 return false; 502 } 503 504 break; 505 } 506 507 case RISCV::VT_MASKC: 508 case RISCV::VT_MASKCN: 509 // Instructions return zero or operand 1. Result is sign extended if 510 // operand 1 is sign extended. 511 if (!AddRegDefToWorkList(MI->getOperand(1).getReg())) 512 return false; 513 break; 514 515 // With these opcode, we can "fix" them with the W-version 516 // if we know all users of the result only rely on bits 31:0 517 case RISCV::SLLI: 518 // SLLIW reads the lowest 5 bits, while SLLI reads lowest 6 bits 519 if (MI->getOperand(2).getImm() >= 32) 520 return false; 521 [[fallthrough]]; 522 case RISCV::ADDI: 523 case RISCV::ADD: 524 case RISCV::LD: 525 case RISCV::LWU: 526 case RISCV::MUL: 527 case RISCV::SUB: 528 if (hasAllWUsers(*MI, ST, MRI)) { 529 FixableDef.insert(MI); 530 break; 531 } 532 return false; 533 } 534 } 535 536 // If we get here, then every node we visited produces a sign extended value 537 // or propagated sign extended values. So the result must be sign extended. 538 return true; 539 } 540 541 static unsigned getWOp(unsigned Opcode) { 542 switch (Opcode) { 543 case RISCV::ADDI: 544 return RISCV::ADDIW; 545 case RISCV::ADD: 546 return RISCV::ADDW; 547 case RISCV::LD: 548 case RISCV::LWU: 549 return RISCV::LW; 550 case RISCV::MUL: 551 return RISCV::MULW; 552 case RISCV::SLLI: 553 return RISCV::SLLIW; 554 case RISCV::SUB: 555 return RISCV::SUBW; 556 default: 557 llvm_unreachable("Unexpected opcode for replacement with W variant"); 558 } 559 } 560 561 bool RISCVOptWInstrs::removeSExtWInstrs(MachineFunction &MF, 562 const RISCVInstrInfo &TII, 563 const RISCVSubtarget &ST, 564 MachineRegisterInfo &MRI) { 565 if (DisableSExtWRemoval) 566 return false; 567 568 bool MadeChange = false; 569 for (MachineBasicBlock &MBB : MF) { 570 for (auto I = MBB.begin(), IE = MBB.end(); I != IE;) { 571 MachineInstr *MI = &*I++; 572 573 // We're looking for the sext.w pattern ADDIW rd, rs1, 0. 574 if (!RISCV::isSEXT_W(*MI)) 575 continue; 576 577 Register SrcReg = MI->getOperand(1).getReg(); 578 579 SmallPtrSet<MachineInstr *, 4> FixableDefs; 580 581 // If all users only use the lower bits, this sext.w is redundant. 582 // Or if all definitions reaching MI sign-extend their output, 583 // then sext.w is redundant. 584 if (!hasAllWUsers(*MI, ST, MRI) && 585 !isSignExtendedW(SrcReg, ST, MRI, FixableDefs)) 586 continue; 587 588 Register DstReg = MI->getOperand(0).getReg(); 589 if (!MRI.constrainRegClass(SrcReg, MRI.getRegClass(DstReg))) 590 continue; 591 592 // Convert Fixable instructions to their W versions. 593 for (MachineInstr *Fixable : FixableDefs) { 594 LLVM_DEBUG(dbgs() << "Replacing " << *Fixable); 595 Fixable->setDesc(TII.get(getWOp(Fixable->getOpcode()))); 596 Fixable->clearFlag(MachineInstr::MIFlag::NoSWrap); 597 Fixable->clearFlag(MachineInstr::MIFlag::NoUWrap); 598 Fixable->clearFlag(MachineInstr::MIFlag::IsExact); 599 LLVM_DEBUG(dbgs() << " with " << *Fixable); 600 ++NumTransformedToWInstrs; 601 } 602 603 LLVM_DEBUG(dbgs() << "Removing redundant sign-extension\n"); 604 MRI.replaceRegWith(DstReg, SrcReg); 605 MRI.clearKillFlags(SrcReg); 606 MI->eraseFromParent(); 607 ++NumRemovedSExtW; 608 MadeChange = true; 609 } 610 } 611 612 return MadeChange; 613 } 614 615 bool RISCVOptWInstrs::stripWSuffixes(MachineFunction &MF, 616 const RISCVInstrInfo &TII, 617 const RISCVSubtarget &ST, 618 MachineRegisterInfo &MRI) { 619 if (DisableStripWSuffix) 620 return false; 621 622 bool MadeChange = false; 623 for (MachineBasicBlock &MBB : MF) { 624 for (auto I = MBB.begin(), IE = MBB.end(); I != IE; ++I) { 625 MachineInstr &MI = *I; 626 627 unsigned Opc; 628 switch (MI.getOpcode()) { 629 default: 630 continue; 631 case RISCV::ADDW: Opc = RISCV::ADD; break; 632 case RISCV::MULW: Opc = RISCV::MUL; break; 633 case RISCV::SLLIW: Opc = RISCV::SLLI; break; 634 } 635 636 if (hasAllWUsers(MI, ST, MRI)) { 637 MI.setDesc(TII.get(Opc)); 638 MadeChange = true; 639 } 640 } 641 } 642 643 return MadeChange; 644 } 645 646 bool RISCVOptWInstrs::runOnMachineFunction(MachineFunction &MF) { 647 if (skipFunction(MF.getFunction())) 648 return false; 649 650 MachineRegisterInfo &MRI = MF.getRegInfo(); 651 const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>(); 652 const RISCVInstrInfo &TII = *ST.getInstrInfo(); 653 654 if (!ST.is64Bit()) 655 return false; 656 657 bool MadeChange = false; 658 MadeChange |= removeSExtWInstrs(MF, TII, ST, MRI); 659 MadeChange |= stripWSuffixes(MF, TII, ST, MRI); 660 661 return MadeChange; 662 } 663