1 //===-- lib/CodeGen/GlobalISel/GICombinerHelper.cpp -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h" 9 #include "llvm/ADT/APFloat.h" 10 #include "llvm/ADT/STLExtras.h" 11 #include "llvm/ADT/SetVector.h" 12 #include "llvm/ADT/SmallBitVector.h" 13 #include "llvm/Analysis/CmpInstAnalysis.h" 14 #include "llvm/CodeGen/GlobalISel/GISelChangeObserver.h" 15 #include "llvm/CodeGen/GlobalISel/GISelValueTracking.h" 16 #include "llvm/CodeGen/GlobalISel/GenericMachineInstrs.h" 17 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h" 18 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h" 19 #include "llvm/CodeGen/GlobalISel/MIPatternMatch.h" 20 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h" 21 #include "llvm/CodeGen/GlobalISel/Utils.h" 22 #include "llvm/CodeGen/LowLevelTypeUtils.h" 23 #include "llvm/CodeGen/MachineBasicBlock.h" 24 #include "llvm/CodeGen/MachineDominators.h" 25 #include "llvm/CodeGen/MachineInstr.h" 26 #include "llvm/CodeGen/MachineMemOperand.h" 27 #include "llvm/CodeGen/MachineRegisterInfo.h" 28 #include "llvm/CodeGen/Register.h" 29 #include "llvm/CodeGen/RegisterBankInfo.h" 30 #include "llvm/CodeGen/TargetInstrInfo.h" 31 #include "llvm/CodeGen/TargetLowering.h" 32 #include "llvm/CodeGen/TargetOpcodes.h" 33 #include "llvm/IR/ConstantRange.h" 34 #include "llvm/IR/DataLayout.h" 35 #include "llvm/IR/InstrTypes.h" 36 #include "llvm/Support/Casting.h" 37 #include "llvm/Support/DivisionByConstantInfo.h" 38 #include "llvm/Support/ErrorHandling.h" 39 #include "llvm/Support/MathExtras.h" 40 #include "llvm/Target/TargetMachine.h" 41 #include <cmath> 42 #include <optional> 43 #include <tuple> 44 45 #define DEBUG_TYPE "gi-combiner" 46 47 using namespace llvm; 48 using namespace MIPatternMatch; 49 50 // Option to allow testing of the combiner while no targets know about indexed 51 // addressing. 52 static cl::opt<bool> 53 ForceLegalIndexing("force-legal-indexing", cl::Hidden, cl::init(false), 54 cl::desc("Force all indexed operations to be " 55 "legal for the GlobalISel combiner")); 56 57 CombinerHelper::CombinerHelper(GISelChangeObserver &Observer, 58 MachineIRBuilder &B, bool IsPreLegalize, 59 GISelValueTracking *VT, 60 MachineDominatorTree *MDT, 61 const LegalizerInfo *LI) 62 : Builder(B), MRI(Builder.getMF().getRegInfo()), Observer(Observer), VT(VT), 63 MDT(MDT), IsPreLegalize(IsPreLegalize), LI(LI), 64 RBI(Builder.getMF().getSubtarget().getRegBankInfo()), 65 TRI(Builder.getMF().getSubtarget().getRegisterInfo()) { 66 (void)this->VT; 67 } 68 69 const TargetLowering &CombinerHelper::getTargetLowering() const { 70 return *Builder.getMF().getSubtarget().getTargetLowering(); 71 } 72 73 const MachineFunction &CombinerHelper::getMachineFunction() const { 74 return Builder.getMF(); 75 } 76 77 const DataLayout &CombinerHelper::getDataLayout() const { 78 return getMachineFunction().getDataLayout(); 79 } 80 81 LLVMContext &CombinerHelper::getContext() const { return Builder.getContext(); } 82 83 /// \returns The little endian in-memory byte position of byte \p I in a 84 /// \p ByteWidth bytes wide type. 85 /// 86 /// E.g. 
Given a 4-byte type x, x[0] -> byte 0 87 static unsigned littleEndianByteAt(const unsigned ByteWidth, const unsigned I) { 88 assert(I < ByteWidth && "I must be in [0, ByteWidth)"); 89 return I; 90 } 91 92 /// Determines the LogBase2 value for a non-null input value using the 93 /// transform: LogBase2(V) = (EltBits - 1) - ctlz(V). 94 static Register buildLogBase2(Register V, MachineIRBuilder &MIB) { 95 auto &MRI = *MIB.getMRI(); 96 LLT Ty = MRI.getType(V); 97 auto Ctlz = MIB.buildCTLZ(Ty, V); 98 auto Base = MIB.buildConstant(Ty, Ty.getScalarSizeInBits() - 1); 99 return MIB.buildSub(Ty, Base, Ctlz).getReg(0); 100 } 101 102 /// \returns The big endian in-memory byte position of byte \p I in a 103 /// \p ByteWidth bytes wide type. 104 /// 105 /// E.g. Given a 4-byte type x, x[0] -> byte 3 106 static unsigned bigEndianByteAt(const unsigned ByteWidth, const unsigned I) { 107 assert(I < ByteWidth && "I must be in [0, ByteWidth)"); 108 return ByteWidth - I - 1; 109 } 110 111 /// Given a map from byte offsets in memory to indices in a load/store, 112 /// determine if that map corresponds to a little or big endian byte pattern. 113 /// 114 /// \param MemOffset2Idx maps memory offsets to address offsets. 115 /// \param LowestIdx is the lowest index in \p MemOffset2Idx. 116 /// 117 /// \returns true if the map corresponds to a big endian byte pattern, false if 118 /// it corresponds to a little endian byte pattern, and std::nullopt otherwise. 119 /// 120 /// E.g. given a 32-bit type x, and x[AddrOffset], the in-memory byte patterns 121 /// are as follows: 122 /// 123 /// AddrOffset Little endian Big endian 124 /// 0 0 3 125 /// 1 1 2 126 /// 2 2 1 127 /// 3 3 0 128 static std::optional<bool> 129 isBigEndian(const SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, 130 int64_t LowestIdx) { 131 // Need at least two byte positions to decide on endianness. 132 unsigned Width = MemOffset2Idx.size(); 133 if (Width < 2) 134 return std::nullopt; 135 bool BigEndian = true, LittleEndian = true; 136 for (unsigned MemOffset = 0; MemOffset < Width; ++ MemOffset) { 137 auto MemOffsetAndIdx = MemOffset2Idx.find(MemOffset); 138 if (MemOffsetAndIdx == MemOffset2Idx.end()) 139 return std::nullopt; 140 const int64_t Idx = MemOffsetAndIdx->second - LowestIdx; 141 assert(Idx >= 0 && "Expected non-negative byte offset?"); 142 LittleEndian &= Idx == littleEndianByteAt(Width, MemOffset); 143 BigEndian &= Idx == bigEndianByteAt(Width, MemOffset); 144 if (!BigEndian && !LittleEndian) 145 return std::nullopt; 146 } 147 148 assert((BigEndian != LittleEndian) && 149 "Pattern cannot be both big and little endian!"); 150 return BigEndian; 151 } 152 153 bool CombinerHelper::isPreLegalize() const { return IsPreLegalize; } 154 155 bool CombinerHelper::isLegal(const LegalityQuery &Query) const { 156 assert(LI && "Must have LegalizerInfo to query isLegal!"); 157 return LI->getAction(Query).Action == LegalizeActions::Legal; 158 } 159 160 bool CombinerHelper::isLegalOrBeforeLegalizer( 161 const LegalityQuery &Query) const { 162 return isPreLegalize() || isLegal(Query); 163 } 164 165 bool CombinerHelper::isLegalOrHasWidenScalar(const LegalityQuery &Query) const { 166 return isLegal(Query) || 167 LI->getAction(Query).Action == LegalizeActions::WidenScalar; 168 } 169 170 bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const { 171 if (!Ty.isVector()) 172 return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}}); 173 // Vector constants are represented as a G_BUILD_VECTOR of scalar G_CONSTANTs. 
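  // Illustrative sketch (hypothetical MIR, not from the original source): a
  // <2 x s32> splat of 1 would be
  //   %c:_(s32) = G_CONSTANT i32 1
  //   %v:_(<2 x s32>) = G_BUILD_VECTOR %c:_(s32), %c:_(s32)
  // so both G_BUILD_VECTOR and the scalar G_CONSTANT must be legal (or we
  // must still be before the legalizer).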
174 if (isPreLegalize()) 175 return true; 176 LLT EltTy = Ty.getElementType(); 177 return isLegal({TargetOpcode::G_BUILD_VECTOR, {Ty, EltTy}}) && 178 isLegal({TargetOpcode::G_CONSTANT, {EltTy}}); 179 } 180 181 void CombinerHelper::replaceRegWith(MachineRegisterInfo &MRI, Register FromReg, 182 Register ToReg) const { 183 Observer.changingAllUsesOfReg(MRI, FromReg); 184 185 if (MRI.constrainRegAttrs(ToReg, FromReg)) 186 MRI.replaceRegWith(FromReg, ToReg); 187 else 188 Builder.buildCopy(FromReg, ToReg); 189 190 Observer.finishedChangingAllUsesOfReg(); 191 } 192 193 void CombinerHelper::replaceRegOpWith(MachineRegisterInfo &MRI, 194 MachineOperand &FromRegOp, 195 Register ToReg) const { 196 assert(FromRegOp.getParent() && "Expected an operand in an MI"); 197 Observer.changingInstr(*FromRegOp.getParent()); 198 199 FromRegOp.setReg(ToReg); 200 201 Observer.changedInstr(*FromRegOp.getParent()); 202 } 203 204 void CombinerHelper::replaceOpcodeWith(MachineInstr &FromMI, 205 unsigned ToOpcode) const { 206 Observer.changingInstr(FromMI); 207 208 FromMI.setDesc(Builder.getTII().get(ToOpcode)); 209 210 Observer.changedInstr(FromMI); 211 } 212 213 const RegisterBank *CombinerHelper::getRegBank(Register Reg) const { 214 return RBI->getRegBank(Reg, MRI, *TRI); 215 } 216 217 void CombinerHelper::setRegBank(Register Reg, 218 const RegisterBank *RegBank) const { 219 if (RegBank) 220 MRI.setRegBank(Reg, *RegBank); 221 } 222 223 bool CombinerHelper::tryCombineCopy(MachineInstr &MI) const { 224 if (matchCombineCopy(MI)) { 225 applyCombineCopy(MI); 226 return true; 227 } 228 return false; 229 } 230 bool CombinerHelper::matchCombineCopy(MachineInstr &MI) const { 231 if (MI.getOpcode() != TargetOpcode::COPY) 232 return false; 233 Register DstReg = MI.getOperand(0).getReg(); 234 Register SrcReg = MI.getOperand(1).getReg(); 235 return canReplaceReg(DstReg, SrcReg, MRI); 236 } 237 void CombinerHelper::applyCombineCopy(MachineInstr &MI) const { 238 Register DstReg = MI.getOperand(0).getReg(); 239 Register SrcReg = MI.getOperand(1).getReg(); 240 replaceRegWith(MRI, DstReg, SrcReg); 241 MI.eraseFromParent(); 242 } 243 244 bool CombinerHelper::matchFreezeOfSingleMaybePoisonOperand( 245 MachineInstr &MI, BuildFnTy &MatchInfo) const { 246 // Ported from InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating. 247 Register DstOp = MI.getOperand(0).getReg(); 248 Register OrigOp = MI.getOperand(1).getReg(); 249 250 if (!MRI.hasOneNonDBGUse(OrigOp)) 251 return false; 252 253 MachineInstr *OrigDef = MRI.getUniqueVRegDef(OrigOp); 254 // Even if only a single operand of the PHI is not guaranteed non-poison, 255 // moving freeze() backwards across a PHI can cause optimization issues for 256 // other users of that operand. 257 // 258 // Moving freeze() from one of the output registers of a G_UNMERGE_VALUES to 259 // the source register is unprofitable because it makes the freeze() more 260 // strict than is necessary (it would affect the whole register instead of 261 // just the subreg being frozen). 
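  // The transform performed below, as an illustrative sketch (hypothetical
  // MIR, not from the original source): for
  //   %res:_(s32) = G_ADD %x:_(s32), %y:_(s32)
  //   %dst:_(s32) = G_FREEZE %res
  // with %y known to be neither undef nor poison, the freeze is pushed onto
  // the single maybe-poison operand instead:
  //   %fx:_(s32) = G_FREEZE %x
  //   %res:_(s32) = G_ADD %fx, %y   (poison-generating flags dropped)
  // and uses of %dst are rewritten to use %res.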
262 if (OrigDef->isPHI() || isa<GUnmerge>(OrigDef)) 263 return false; 264 265 if (canCreateUndefOrPoison(OrigOp, MRI, 266 /*ConsiderFlagsAndMetadata=*/false)) 267 return false; 268 269 std::optional<MachineOperand> MaybePoisonOperand; 270 for (MachineOperand &Operand : OrigDef->uses()) { 271 if (!Operand.isReg()) 272 return false; 273 274 if (isGuaranteedNotToBeUndefOrPoison(Operand.getReg(), MRI)) 275 continue; 276 277 if (!MaybePoisonOperand) 278 MaybePoisonOperand = Operand; 279 else { 280 // We have more than one maybe-poison operand. Moving the freeze is 281 // unsafe. 282 return false; 283 } 284 } 285 286 // Eliminate freeze if all operands are guaranteed non-poison. 287 if (!MaybePoisonOperand) { 288 MatchInfo = [=](MachineIRBuilder &B) { 289 Observer.changingInstr(*OrigDef); 290 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags(); 291 Observer.changedInstr(*OrigDef); 292 B.buildCopy(DstOp, OrigOp); 293 }; 294 return true; 295 } 296 297 Register MaybePoisonOperandReg = MaybePoisonOperand->getReg(); 298 LLT MaybePoisonOperandRegTy = MRI.getType(MaybePoisonOperandReg); 299 300 MatchInfo = [=](MachineIRBuilder &B) mutable { 301 Observer.changingInstr(*OrigDef); 302 cast<GenericMachineInstr>(OrigDef)->dropPoisonGeneratingFlags(); 303 Observer.changedInstr(*OrigDef); 304 B.setInsertPt(*OrigDef->getParent(), OrigDef->getIterator()); 305 auto Freeze = B.buildFreeze(MaybePoisonOperandRegTy, MaybePoisonOperandReg); 306 replaceRegOpWith( 307 MRI, *OrigDef->findRegisterUseOperand(MaybePoisonOperandReg, TRI), 308 Freeze.getReg(0)); 309 replaceRegWith(MRI, DstOp, OrigOp); 310 }; 311 return true; 312 } 313 314 bool CombinerHelper::matchCombineConcatVectors( 315 MachineInstr &MI, SmallVector<Register> &Ops) const { 316 assert(MI.getOpcode() == TargetOpcode::G_CONCAT_VECTORS && 317 "Invalid instruction"); 318 bool IsUndef = true; 319 MachineInstr *Undef = nullptr; 320 321 // Walk over all the operands of concat vectors and check if they are 322 // build_vector themselves or undef. 323 // Then collect their operands in Ops. 324 for (const MachineOperand &MO : MI.uses()) { 325 Register Reg = MO.getReg(); 326 MachineInstr *Def = MRI.getVRegDef(Reg); 327 assert(Def && "Operand not defined"); 328 if (!MRI.hasOneNonDBGUse(Reg)) 329 return false; 330 switch (Def->getOpcode()) { 331 case TargetOpcode::G_BUILD_VECTOR: 332 IsUndef = false; 333 // Remember the operands of the build_vector to fold 334 // them into the yet-to-build flattened concat vectors. 335 for (const MachineOperand &BuildVecMO : Def->uses()) 336 Ops.push_back(BuildVecMO.getReg()); 337 break; 338 case TargetOpcode::G_IMPLICIT_DEF: { 339 LLT OpType = MRI.getType(Reg); 340 // Keep one undef value for all the undef operands. 341 if (!Undef) { 342 Builder.setInsertPt(*MI.getParent(), MI); 343 Undef = Builder.buildUndef(OpType.getScalarType()); 344 } 345 assert(MRI.getType(Undef->getOperand(0).getReg()) == 346 OpType.getScalarType() && 347 "All undefs should have the same type"); 348 // Break the undef vector in as many scalar elements as needed 349 // for the flattening. 
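      // E.g. (illustrative): a <2 x s32> G_IMPLICIT_DEF operand contributes
      // two uses of the single scalar s32 undef created above.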
      for (unsigned EltIdx = 0, EltEnd = OpType.getNumElements();
           EltIdx != EltEnd; ++EltIdx)
        Ops.push_back(Undef->getOperand(0).getReg());
      break;
    }
    default:
      return false;
    }
  }

  // Check if the combine is illegal.
  LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_BUILD_VECTOR, {DstTy, MRI.getType(Ops[0])}})) {
    return false;
  }

  if (IsUndef)
    Ops.clear();

  return true;
}
void CombinerHelper::applyCombineConcatVectors(
    MachineInstr &MI, SmallVector<Register> &Ops) const {
  // We determined that the concat_vectors can be flattened.
  // Generate the flattened build_vector.
  Register DstReg = MI.getOperand(0).getReg();
  Builder.setInsertPt(*MI.getParent(), MI);
  Register NewDstReg = MRI.cloneVirtualRegister(DstReg);

  // Note: IsUndef is sort of redundant. We could have determined it by
  // checking that all Ops are undef. Alternatively, we could have
  // generated a build_vector of undefs and relied on another combine to
  // clean that up. For now, given we already gather this information
  // in matchCombineConcatVectors, just save compile time and issue the
  // right thing.
  if (Ops.empty())
    Builder.buildUndef(NewDstReg);
  else
    Builder.buildBuildVector(NewDstReg, Ops);
  replaceRegWith(MRI, DstReg, NewDstReg);
  MI.eraseFromParent();
}

bool CombinerHelper::matchCombineShuffleToBuildVector(MachineInstr &MI) const {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction");
  auto &Shuffle = cast<GShuffleVector>(MI);

  Register SrcVec1 = Shuffle.getSrc1Reg();
  Register SrcVec2 = Shuffle.getSrc2Reg();

  LLT SrcVec1Type = MRI.getType(SrcVec1);
  LLT SrcVec2Type = MRI.getType(SrcVec2);
  return SrcVec1Type.isVector() && SrcVec2Type.isVector();
}

void CombinerHelper::applyCombineShuffleToBuildVector(MachineInstr &MI) const {
  auto &Shuffle = cast<GShuffleVector>(MI);

  Register SrcVec1 = Shuffle.getSrc1Reg();
  Register SrcVec2 = Shuffle.getSrc2Reg();
  LLT EltTy = MRI.getType(SrcVec1).getElementType();
  int Width = MRI.getType(SrcVec1).getNumElements();

  auto Unmerge1 = Builder.buildUnmerge(EltTy, SrcVec1);
  auto Unmerge2 = Builder.buildUnmerge(EltTy, SrcVec2);

  SmallVector<Register> Extracts;
  // Select only applicable elements from unmerged values.
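  // Illustrative sketch (hypothetical MIR, not from the original source):
  // shuffling two <2 x s32> sources %a and %b with mask <1, 2> becomes
  //   %a0:_(s32), %a1:_(s32) = G_UNMERGE_VALUES %a
  //   %b0:_(s32), %b1:_(s32) = G_UNMERGE_VALUES %b
  //   %dst:_(<2 x s32>) = G_BUILD_VECTOR %a1:_(s32), %b0:_(s32)
  // and a -1 mask entry selects a fresh scalar G_IMPLICIT_DEF instead.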
420 for (int Val : Shuffle.getMask()) { 421 if (Val == -1) 422 Extracts.push_back(Builder.buildUndef(EltTy).getReg(0)); 423 else if (Val < Width) 424 Extracts.push_back(Unmerge1.getReg(Val)); 425 else 426 Extracts.push_back(Unmerge2.getReg(Val - Width)); 427 } 428 assert(Extracts.size() > 0 && "Expected at least one element in the shuffle"); 429 if (Extracts.size() == 1) 430 Builder.buildCopy(MI.getOperand(0).getReg(), Extracts[0]); 431 else 432 Builder.buildBuildVector(MI.getOperand(0).getReg(), Extracts); 433 MI.eraseFromParent(); 434 } 435 436 bool CombinerHelper::matchCombineShuffleConcat( 437 MachineInstr &MI, SmallVector<Register> &Ops) const { 438 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 439 auto ConcatMI1 = 440 dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(1).getReg())); 441 auto ConcatMI2 = 442 dyn_cast<GConcatVectors>(MRI.getVRegDef(MI.getOperand(2).getReg())); 443 if (!ConcatMI1 || !ConcatMI2) 444 return false; 445 446 // Check that the sources of the Concat instructions have the same type 447 if (MRI.getType(ConcatMI1->getSourceReg(0)) != 448 MRI.getType(ConcatMI2->getSourceReg(0))) 449 return false; 450 451 LLT ConcatSrcTy = MRI.getType(ConcatMI1->getReg(1)); 452 LLT ShuffleSrcTy1 = MRI.getType(MI.getOperand(1).getReg()); 453 unsigned ConcatSrcNumElt = ConcatSrcTy.getNumElements(); 454 for (unsigned i = 0; i < Mask.size(); i += ConcatSrcNumElt) { 455 // Check if the index takes a whole source register from G_CONCAT_VECTORS 456 // Assumes that all Sources of G_CONCAT_VECTORS are the same type 457 if (Mask[i] == -1) { 458 for (unsigned j = 1; j < ConcatSrcNumElt; j++) { 459 if (i + j >= Mask.size()) 460 return false; 461 if (Mask[i + j] != -1) 462 return false; 463 } 464 if (!isLegalOrBeforeLegalizer( 465 {TargetOpcode::G_IMPLICIT_DEF, {ConcatSrcTy}})) 466 return false; 467 Ops.push_back(0); 468 } else if (Mask[i] % ConcatSrcNumElt == 0) { 469 for (unsigned j = 1; j < ConcatSrcNumElt; j++) { 470 if (i + j >= Mask.size()) 471 return false; 472 if (Mask[i + j] != Mask[i] + static_cast<int>(j)) 473 return false; 474 } 475 // Retrieve the source register from its respective G_CONCAT_VECTORS 476 // instruction 477 if (Mask[i] < ShuffleSrcTy1.getNumElements()) { 478 Ops.push_back(ConcatMI1->getSourceReg(Mask[i] / ConcatSrcNumElt)); 479 } else { 480 Ops.push_back(ConcatMI2->getSourceReg(Mask[i] / ConcatSrcNumElt - 481 ConcatMI1->getNumSources())); 482 } 483 } else { 484 return false; 485 } 486 } 487 488 if (!isLegalOrBeforeLegalizer( 489 {TargetOpcode::G_CONCAT_VECTORS, 490 {MRI.getType(MI.getOperand(0).getReg()), ConcatSrcTy}})) 491 return false; 492 493 return !Ops.empty(); 494 } 495 496 void CombinerHelper::applyCombineShuffleConcat( 497 MachineInstr &MI, SmallVector<Register> &Ops) const { 498 LLT SrcTy; 499 for (Register &Reg : Ops) { 500 if (Reg != 0) 501 SrcTy = MRI.getType(Reg); 502 } 503 assert(SrcTy.isValid() && "Unexpected full undef vector in concat combine"); 504 505 Register UndefReg = 0; 506 507 for (Register &Reg : Ops) { 508 if (Reg == 0) { 509 if (UndefReg == 0) 510 UndefReg = Builder.buildUndef(SrcTy).getReg(0); 511 Reg = UndefReg; 512 } 513 } 514 515 if (Ops.size() > 1) 516 Builder.buildConcatVectors(MI.getOperand(0).getReg(), Ops); 517 else 518 Builder.buildCopy(MI.getOperand(0).getReg(), Ops[0]); 519 MI.eraseFromParent(); 520 } 521 522 bool CombinerHelper::tryCombineShuffleVector(MachineInstr &MI) const { 523 SmallVector<Register, 4> Ops; 524 if (matchCombineShuffleVector(MI, Ops)) { 525 applyCombineShuffleVector(MI, Ops); 526 return true; 
  }
  return false;
}

bool CombinerHelper::matchCombineShuffleVector(
    MachineInstr &MI, SmallVectorImpl<Register> &Ops) const {
  assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR &&
         "Invalid instruction kind");
  LLT DstType = MRI.getType(MI.getOperand(0).getReg());
  Register Src1 = MI.getOperand(1).getReg();
  LLT SrcType = MRI.getType(Src1);
  // As bizarre as it may look, shuffle vector can actually produce a
  // scalar! This is because at the IR level a <1 x ty> shuffle
  // vector is perfectly valid.
  unsigned DstNumElts = DstType.isVector() ? DstType.getNumElements() : 1;
  unsigned SrcNumElts = SrcType.isVector() ? SrcType.getNumElements() : 1;

  // If the resulting vector is smaller than the size of the source
  // vectors being concatenated, we won't be able to replace the
  // shuffle vector with a concat_vectors.
  //
  // Note: We may still be able to produce a concat_vectors fed by
  // extract_vector_elt and so on. It is less clear that would
  // be better though, so don't bother for now.
  //
  // If the destination is a scalar, the size of the sources doesn't
  // matter. We will lower the shuffle to a plain copy. This will
  // work only if the source and destination have the same size. But
  // that's covered by the next condition.
  //
  // TODO: If the sizes of the source and destination don't match,
  // we could still emit an extract vector element in that case.
  if (DstNumElts < 2 * SrcNumElts && DstNumElts != 1)
    return false;

  // Check that the shuffle mask can be broken evenly between the
  // different sources.
  if (DstNumElts % SrcNumElts != 0)
    return false;

  // Mask length is a multiple of the source vector length.
  // Check if the shuffle is some kind of concatenation of the input
  // vectors.
  unsigned NumConcat = DstNumElts / SrcNumElts;
  SmallVector<int, 8> ConcatSrcs(NumConcat, -1);
  ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask();
  for (unsigned i = 0; i != DstNumElts; ++i) {
    int Idx = Mask[i];
    // Undef value.
    if (Idx < 0)
      continue;
    // Ensure the indices in each SrcType sized piece are sequential and that
    // the same source is used for the whole piece.
    if ((Idx % SrcNumElts != (i % SrcNumElts)) ||
        (ConcatSrcs[i / SrcNumElts] >= 0 &&
         ConcatSrcs[i / SrcNumElts] != (int)(Idx / SrcNumElts)))
      return false;
    // Remember which source this index came from.
    ConcatSrcs[i / SrcNumElts] = Idx / SrcNumElts;
  }

  // The shuffle is concatenating multiple vectors together.
  // Collect the different operands for that.
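  // Illustrative sketch (hypothetical MIR, not from the original source):
  // with <2 x s32> sources and mask <0, 1, 2, 3>, the shuffle is a plain
  // concatenation,
  //   %dst:_(<4 x s32>) = G_CONCAT_VECTORS %src1:_(<2 x s32>), %src2:_(<2 x s32>)
  // and ConcatSrcs ends up as {0, 1}; a piece whose mask entries are all
  // undef leaves a -1 behind and is filled with an undef operand below.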
590 Register UndefReg; 591 Register Src2 = MI.getOperand(2).getReg(); 592 for (auto Src : ConcatSrcs) { 593 if (Src < 0) { 594 if (!UndefReg) { 595 Builder.setInsertPt(*MI.getParent(), MI); 596 UndefReg = Builder.buildUndef(SrcType).getReg(0); 597 } 598 Ops.push_back(UndefReg); 599 } else if (Src == 0) 600 Ops.push_back(Src1); 601 else 602 Ops.push_back(Src2); 603 } 604 return true; 605 } 606 607 void CombinerHelper::applyCombineShuffleVector( 608 MachineInstr &MI, const ArrayRef<Register> Ops) const { 609 Register DstReg = MI.getOperand(0).getReg(); 610 Builder.setInsertPt(*MI.getParent(), MI); 611 Register NewDstReg = MRI.cloneVirtualRegister(DstReg); 612 613 if (Ops.size() == 1) 614 Builder.buildCopy(NewDstReg, Ops[0]); 615 else 616 Builder.buildMergeLikeInstr(NewDstReg, Ops); 617 618 replaceRegWith(MRI, DstReg, NewDstReg); 619 MI.eraseFromParent(); 620 } 621 622 bool CombinerHelper::matchShuffleToExtract(MachineInstr &MI) const { 623 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR && 624 "Invalid instruction kind"); 625 626 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 627 return Mask.size() == 1; 628 } 629 630 void CombinerHelper::applyShuffleToExtract(MachineInstr &MI) const { 631 Register DstReg = MI.getOperand(0).getReg(); 632 Builder.setInsertPt(*MI.getParent(), MI); 633 634 int I = MI.getOperand(3).getShuffleMask()[0]; 635 Register Src1 = MI.getOperand(1).getReg(); 636 LLT Src1Ty = MRI.getType(Src1); 637 int Src1NumElts = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1; 638 Register SrcReg; 639 if (I >= Src1NumElts) { 640 SrcReg = MI.getOperand(2).getReg(); 641 I -= Src1NumElts; 642 } else if (I >= 0) 643 SrcReg = Src1; 644 645 if (I < 0) 646 Builder.buildUndef(DstReg); 647 else if (!MRI.getType(SrcReg).isVector()) 648 Builder.buildCopy(DstReg, SrcReg); 649 else 650 Builder.buildExtractVectorElementConstant(DstReg, SrcReg, I); 651 652 MI.eraseFromParent(); 653 } 654 655 namespace { 656 657 /// Select a preference between two uses. CurrentUse is the current preference 658 /// while *ForCandidate is attributes of the candidate under consideration. 659 PreferredTuple ChoosePreferredUse(MachineInstr &LoadMI, 660 PreferredTuple &CurrentUse, 661 const LLT TyForCandidate, 662 unsigned OpcodeForCandidate, 663 MachineInstr *MIForCandidate) { 664 if (!CurrentUse.Ty.isValid()) { 665 if (CurrentUse.ExtendOpcode == OpcodeForCandidate || 666 CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT) 667 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 668 return CurrentUse; 669 } 670 671 // We permit the extend to hoist through basic blocks but this is only 672 // sensible if the target has extending loads. If you end up lowering back 673 // into a load and extend during the legalizer then the end result is 674 // hoisting the extend up to the load. 675 676 // Prefer defined extensions to undefined extensions as these are more 677 // likely to reduce the number of instructions. 678 if (OpcodeForCandidate == TargetOpcode::G_ANYEXT && 679 CurrentUse.ExtendOpcode != TargetOpcode::G_ANYEXT) 680 return CurrentUse; 681 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ANYEXT && 682 OpcodeForCandidate != TargetOpcode::G_ANYEXT) 683 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 684 685 // Prefer sign extensions to zero extensions as sign-extensions tend to be 686 // more expensive. Don't do this if the load is already a zero-extend load 687 // though, otherwise we'll rewrite a zero-extend load into a sign-extend 688 // later. 
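  // E.g. (illustrative): if a plain G_LOAD feeds both a G_ZEXT and a G_SEXT
  // of the same size, the G_SEXT candidate wins here, so the load is later
  // rewritten to a G_SEXTLOAD and the other users are fed through a G_TRUNC
  // back to the originally loaded type.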
689 if (!isa<GZExtLoad>(LoadMI) && CurrentUse.Ty == TyForCandidate) { 690 if (CurrentUse.ExtendOpcode == TargetOpcode::G_SEXT && 691 OpcodeForCandidate == TargetOpcode::G_ZEXT) 692 return CurrentUse; 693 else if (CurrentUse.ExtendOpcode == TargetOpcode::G_ZEXT && 694 OpcodeForCandidate == TargetOpcode::G_SEXT) 695 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 696 } 697 698 // This is potentially target specific. We've chosen the largest type 699 // because G_TRUNC is usually free. One potential catch with this is that 700 // some targets have a reduced number of larger registers than smaller 701 // registers and this choice potentially increases the live-range for the 702 // larger value. 703 if (TyForCandidate.getSizeInBits() > CurrentUse.Ty.getSizeInBits()) { 704 return {TyForCandidate, OpcodeForCandidate, MIForCandidate}; 705 } 706 return CurrentUse; 707 } 708 709 /// Find a suitable place to insert some instructions and insert them. This 710 /// function accounts for special cases like inserting before a PHI node. 711 /// The current strategy for inserting before PHI's is to duplicate the 712 /// instructions for each predecessor. However, while that's ok for G_TRUNC 713 /// on most targets since it generally requires no code, other targets/cases may 714 /// want to try harder to find a dominating block. 715 static void InsertInsnsWithoutSideEffectsBeforeUse( 716 MachineIRBuilder &Builder, MachineInstr &DefMI, MachineOperand &UseMO, 717 std::function<void(MachineBasicBlock *, MachineBasicBlock::iterator, 718 MachineOperand &UseMO)> 719 Inserter) { 720 MachineInstr &UseMI = *UseMO.getParent(); 721 722 MachineBasicBlock *InsertBB = UseMI.getParent(); 723 724 // If the use is a PHI then we want the predecessor block instead. 725 if (UseMI.isPHI()) { 726 MachineOperand *PredBB = std::next(&UseMO); 727 InsertBB = PredBB->getMBB(); 728 } 729 730 // If the block is the same block as the def then we want to insert just after 731 // the def instead of at the start of the block. 732 if (InsertBB == DefMI.getParent()) { 733 MachineBasicBlock::iterator InsertPt = &DefMI; 734 Inserter(InsertBB, std::next(InsertPt), UseMO); 735 return; 736 } 737 738 // Otherwise we want the start of the BB 739 Inserter(InsertBB, InsertBB->getFirstNonPHI(), UseMO); 740 } 741 } // end anonymous namespace 742 743 bool CombinerHelper::tryCombineExtendingLoads(MachineInstr &MI) const { 744 PreferredTuple Preferred; 745 if (matchCombineExtendingLoads(MI, Preferred)) { 746 applyCombineExtendingLoads(MI, Preferred); 747 return true; 748 } 749 return false; 750 } 751 752 static unsigned getExtLoadOpcForExtend(unsigned ExtOpc) { 753 unsigned CandidateLoadOpc; 754 switch (ExtOpc) { 755 case TargetOpcode::G_ANYEXT: 756 CandidateLoadOpc = TargetOpcode::G_LOAD; 757 break; 758 case TargetOpcode::G_SEXT: 759 CandidateLoadOpc = TargetOpcode::G_SEXTLOAD; 760 break; 761 case TargetOpcode::G_ZEXT: 762 CandidateLoadOpc = TargetOpcode::G_ZEXTLOAD; 763 break; 764 default: 765 llvm_unreachable("Unexpected extend opc"); 766 } 767 return CandidateLoadOpc; 768 } 769 770 bool CombinerHelper::matchCombineExtendingLoads( 771 MachineInstr &MI, PreferredTuple &Preferred) const { 772 // We match the loads and follow the uses to the extend instead of matching 773 // the extends and following the def to the load. This is because the load 774 // must remain in the same position for correctness (unless we also add code 775 // to find a safe place to sink it) whereas the extend is freely movable. 
  // It also prevents us from duplicating the load for the volatile case or just
  // for performance.
  GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(&MI);
  if (!LoadMI)
    return false;

  Register LoadReg = LoadMI->getDstReg();

  LLT LoadValueTy = MRI.getType(LoadReg);
  if (!LoadValueTy.isScalar())
    return false;

  // Most architectures are going to legalize sub-byte (< s8) loads into at
  // least a 1-byte load, and the MMOs can only describe memory accesses in
  // multiples of bytes. If we try to perform extload combining on those, we
  // can end up with
  // %a(s8) = extload %ptr (load 1 byte from %ptr)
  // ... which is an illegal extload instruction.
  if (LoadValueTy.getSizeInBits() < 8)
    return false;

  // Non-power-of-2 types will very likely be legalized into multiple
  // loads. Don't bother trying to match them into extending loads.
  if (!llvm::has_single_bit<uint32_t>(LoadValueTy.getSizeInBits()))
    return false;

  // Find the preferred type aside from the any-extends (unless it's the only
  // one) and non-extending ops. We'll emit an extending load to that type and
  // emit a variant of (extend (trunc X)) for the others according to the
  // relative type sizes. At the same time, pick an extend to use based on the
  // extend involved in the chosen type.
  unsigned PreferredOpcode =
      isa<GLoad>(&MI)
          ? TargetOpcode::G_ANYEXT
          : isa<GSExtLoad>(&MI) ? TargetOpcode::G_SEXT : TargetOpcode::G_ZEXT;
  Preferred = {LLT(), PreferredOpcode, nullptr};
  for (auto &UseMI : MRI.use_nodbg_instructions(LoadReg)) {
    if (UseMI.getOpcode() == TargetOpcode::G_SEXT ||
        UseMI.getOpcode() == TargetOpcode::G_ZEXT ||
        (UseMI.getOpcode() == TargetOpcode::G_ANYEXT)) {
      const auto &MMO = LoadMI->getMMO();
      // Don't do anything for atomics.
      if (MMO.isAtomic())
        continue;
      // Check for legality.
      if (!isPreLegalize()) {
        LegalityQuery::MemDesc MMDesc(MMO);
        unsigned CandidateLoadOpc = getExtLoadOpcForExtend(UseMI.getOpcode());
        LLT UseTy = MRI.getType(UseMI.getOperand(0).getReg());
        LLT SrcTy = MRI.getType(LoadMI->getPointerReg());
        if (LI->getAction({CandidateLoadOpc, {UseTy, SrcTy}, {MMDesc}})
                .Action != LegalizeActions::Legal)
          continue;
      }
      Preferred = ChoosePreferredUse(MI, Preferred,
                                     MRI.getType(UseMI.getOperand(0).getReg()),
                                     UseMI.getOpcode(), &UseMI);
    }
  }

  // There were no extends.
  if (!Preferred.MI)
    return false;
  // It should be impossible to choose an extend without selecting a different
  // type since by definition the result of an extend is larger.
  assert(Preferred.Ty != LoadValueTy && "Extending to same type?");

  LLVM_DEBUG(dbgs() << "Preferred use is: " << *Preferred.MI);
  return true;
}

void CombinerHelper::applyCombineExtendingLoads(
    MachineInstr &MI, PreferredTuple &Preferred) const {
  // Rewrite the load to the chosen extending load.
  Register ChosenDstReg = Preferred.MI->getOperand(0).getReg();

  // Inserter to insert a truncate back to the original type at a given point
  // with some basic CSE to limit truncate duplication to one per BB.
  DenseMap<MachineBasicBlock *, MachineInstr *> EmittedInsns;
  auto InsertTruncAt = [&](MachineBasicBlock *InsertIntoBB,
                           MachineBasicBlock::iterator InsertBefore,
                           MachineOperand &UseMO) {
    MachineInstr *PreviouslyEmitted = EmittedInsns.lookup(InsertIntoBB);
    if (PreviouslyEmitted) {
      Observer.changingInstr(*UseMO.getParent());
      UseMO.setReg(PreviouslyEmitted->getOperand(0).getReg());
      Observer.changedInstr(*UseMO.getParent());
      return;
    }

    Builder.setInsertPt(*InsertIntoBB, InsertBefore);
    Register NewDstReg = MRI.cloneVirtualRegister(MI.getOperand(0).getReg());
    MachineInstr *NewMI = Builder.buildTrunc(NewDstReg, ChosenDstReg);
    EmittedInsns[InsertIntoBB] = NewMI;
    replaceRegOpWith(MRI, UseMO, NewDstReg);
  };

  Observer.changingInstr(MI);
  unsigned LoadOpc = getExtLoadOpcForExtend(Preferred.ExtendOpcode);
  MI.setDesc(Builder.getTII().get(LoadOpc));

  // Rewrite all the uses to fix up the types.
  auto &LoadValue = MI.getOperand(0);
  SmallVector<MachineOperand *, 4> Uses(
      llvm::make_pointer_range(MRI.use_operands(LoadValue.getReg())));

  for (auto *UseMO : Uses) {
    MachineInstr *UseMI = UseMO->getParent();

    // If the extend is compatible with the preferred extend then we should fix
    // up the type and extend so that it uses the preferred use.
    if (UseMI->getOpcode() == Preferred.ExtendOpcode ||
        UseMI->getOpcode() == TargetOpcode::G_ANYEXT) {
      Register UseDstReg = UseMI->getOperand(0).getReg();
      MachineOperand &UseSrcMO = UseMI->getOperand(1);
      const LLT UseDstTy = MRI.getType(UseDstReg);
      if (UseDstReg != ChosenDstReg) {
        if (Preferred.Ty == UseDstTy) {
          // If the use has the same type as the preferred use, then merge
          // the vregs and erase the extend. For example:
          // %1:_(s8) = G_LOAD ...
          // %2:_(s32) = G_SEXT %1(s8)
          // %3:_(s32) = G_ANYEXT %1(s8)
          // ... = ... %3(s32)
          // rewrites to:
          // %2:_(s32) = G_SEXTLOAD ...
          // ... = ... %2(s32)
          replaceRegWith(MRI, UseDstReg, ChosenDstReg);
          Observer.erasingInstr(*UseMO->getParent());
          UseMO->getParent()->eraseFromParent();
        } else if (Preferred.Ty.getSizeInBits() < UseDstTy.getSizeInBits()) {
          // If the preferred size is smaller, then keep the extend but extend
          // from the result of the extending load. For example:
          // %1:_(s8) = G_LOAD ...
          // %2:_(s32) = G_SEXT %1(s8)
          // %3:_(s64) = G_ANYEXT %1(s8)
          // ... = ... %3(s64)
          // rewrites to:
          // %2:_(s32) = G_SEXTLOAD ...
          // %3:_(s64) = G_ANYEXT %2:_(s32)
          // ... = ... %3(s64)
          replaceRegOpWith(MRI, UseSrcMO, ChosenDstReg);
        } else {
          // If the preferred size is larger, then insert a truncate. For
          // example:
          // %1:_(s8) = G_LOAD ...
          // %2:_(s64) = G_SEXT %1(s8)
          // %3:_(s32) = G_ZEXT %1(s8)
          // ... = ... %3(s32)
          // rewrites to:
          // %2:_(s64) = G_SEXTLOAD ...
          // %4:_(s8) = G_TRUNC %2:_(s64)
          // %3:_(s32) = G_ZEXT %4:_(s8)
          // ... = ... %3(s32)
          InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO,
                                                 InsertTruncAt);
        }
        continue;
      }
      // The use is (one of) the uses of the preferred use we chose earlier.
      // We're going to update the load to def this value later so just erase
      // the old extend.
      Observer.erasingInstr(*UseMO->getParent());
      UseMO->getParent()->eraseFromParent();
      continue;
    }

    // The use isn't an extend. Truncate back to the type we originally loaded.
    // This is free on many targets.
944 InsertInsnsWithoutSideEffectsBeforeUse(Builder, MI, *UseMO, InsertTruncAt); 945 } 946 947 MI.getOperand(0).setReg(ChosenDstReg); 948 Observer.changedInstr(MI); 949 } 950 951 bool CombinerHelper::matchCombineLoadWithAndMask(MachineInstr &MI, 952 BuildFnTy &MatchInfo) const { 953 assert(MI.getOpcode() == TargetOpcode::G_AND); 954 955 // If we have the following code: 956 // %mask = G_CONSTANT 255 957 // %ld = G_LOAD %ptr, (load s16) 958 // %and = G_AND %ld, %mask 959 // 960 // Try to fold it into 961 // %ld = G_ZEXTLOAD %ptr, (load s8) 962 963 Register Dst = MI.getOperand(0).getReg(); 964 if (MRI.getType(Dst).isVector()) 965 return false; 966 967 auto MaybeMask = 968 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI); 969 if (!MaybeMask) 970 return false; 971 972 APInt MaskVal = MaybeMask->Value; 973 974 if (!MaskVal.isMask()) 975 return false; 976 977 Register SrcReg = MI.getOperand(1).getReg(); 978 // Don't use getOpcodeDef() here since intermediate instructions may have 979 // multiple users. 980 GAnyLoad *LoadMI = dyn_cast<GAnyLoad>(MRI.getVRegDef(SrcReg)); 981 if (!LoadMI || !MRI.hasOneNonDBGUse(LoadMI->getDstReg())) 982 return false; 983 984 Register LoadReg = LoadMI->getDstReg(); 985 LLT RegTy = MRI.getType(LoadReg); 986 Register PtrReg = LoadMI->getPointerReg(); 987 unsigned RegSize = RegTy.getSizeInBits(); 988 LocationSize LoadSizeBits = LoadMI->getMemSizeInBits(); 989 unsigned MaskSizeBits = MaskVal.countr_one(); 990 991 // The mask may not be larger than the in-memory type, as it might cover sign 992 // extended bits 993 if (MaskSizeBits > LoadSizeBits.getValue()) 994 return false; 995 996 // If the mask covers the whole destination register, there's nothing to 997 // extend 998 if (MaskSizeBits >= RegSize) 999 return false; 1000 1001 // Most targets cannot deal with loads of size < 8 and need to re-legalize to 1002 // at least byte loads. Avoid creating such loads here 1003 if (MaskSizeBits < 8 || !isPowerOf2_32(MaskSizeBits)) 1004 return false; 1005 1006 const MachineMemOperand &MMO = LoadMI->getMMO(); 1007 LegalityQuery::MemDesc MemDesc(MMO); 1008 1009 // Don't modify the memory access size if this is atomic/volatile, but we can 1010 // still adjust the opcode to indicate the high bit behavior. 1011 if (LoadMI->isSimple()) 1012 MemDesc.MemoryTy = LLT::scalar(MaskSizeBits); 1013 else if (LoadSizeBits.getValue() > MaskSizeBits || 1014 LoadSizeBits.getValue() == RegSize) 1015 return false; 1016 1017 // TODO: Could check if it's legal with the reduced or original memory size. 
1018 if (!isLegalOrBeforeLegalizer( 1019 {TargetOpcode::G_ZEXTLOAD, {RegTy, MRI.getType(PtrReg)}, {MemDesc}})) 1020 return false; 1021 1022 MatchInfo = [=](MachineIRBuilder &B) { 1023 B.setInstrAndDebugLoc(*LoadMI); 1024 auto &MF = B.getMF(); 1025 auto PtrInfo = MMO.getPointerInfo(); 1026 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, MemDesc.MemoryTy); 1027 B.buildLoadInstr(TargetOpcode::G_ZEXTLOAD, Dst, PtrReg, *NewMMO); 1028 LoadMI->eraseFromParent(); 1029 }; 1030 return true; 1031 } 1032 1033 bool CombinerHelper::isPredecessor(const MachineInstr &DefMI, 1034 const MachineInstr &UseMI) const { 1035 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && 1036 "shouldn't consider debug uses"); 1037 assert(DefMI.getParent() == UseMI.getParent()); 1038 if (&DefMI == &UseMI) 1039 return true; 1040 const MachineBasicBlock &MBB = *DefMI.getParent(); 1041 auto DefOrUse = find_if(MBB, [&DefMI, &UseMI](const MachineInstr &MI) { 1042 return &MI == &DefMI || &MI == &UseMI; 1043 }); 1044 if (DefOrUse == MBB.end()) 1045 llvm_unreachable("Block must contain both DefMI and UseMI!"); 1046 return &*DefOrUse == &DefMI; 1047 } 1048 1049 bool CombinerHelper::dominates(const MachineInstr &DefMI, 1050 const MachineInstr &UseMI) const { 1051 assert(!DefMI.isDebugInstr() && !UseMI.isDebugInstr() && 1052 "shouldn't consider debug uses"); 1053 if (MDT) 1054 return MDT->dominates(&DefMI, &UseMI); 1055 else if (DefMI.getParent() != UseMI.getParent()) 1056 return false; 1057 1058 return isPredecessor(DefMI, UseMI); 1059 } 1060 1061 bool CombinerHelper::matchSextTruncSextLoad(MachineInstr &MI) const { 1062 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1063 Register SrcReg = MI.getOperand(1).getReg(); 1064 Register LoadUser = SrcReg; 1065 1066 if (MRI.getType(SrcReg).isVector()) 1067 return false; 1068 1069 Register TruncSrc; 1070 if (mi_match(SrcReg, MRI, m_GTrunc(m_Reg(TruncSrc)))) 1071 LoadUser = TruncSrc; 1072 1073 uint64_t SizeInBits = MI.getOperand(2).getImm(); 1074 // If the source is a G_SEXTLOAD from the same bit width, then we don't 1075 // need any extend at all, just a truncate. 1076 if (auto *LoadMI = getOpcodeDef<GSExtLoad>(LoadUser, MRI)) { 1077 // If truncating more than the original extended value, abort. 1078 auto LoadSizeBits = LoadMI->getMemSizeInBits(); 1079 if (TruncSrc && 1080 MRI.getType(TruncSrc).getSizeInBits() < LoadSizeBits.getValue()) 1081 return false; 1082 if (LoadSizeBits == SizeInBits) 1083 return true; 1084 } 1085 return false; 1086 } 1087 1088 void CombinerHelper::applySextTruncSextLoad(MachineInstr &MI) const { 1089 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1090 Builder.buildCopy(MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); 1091 MI.eraseFromParent(); 1092 } 1093 1094 bool CombinerHelper::matchSextInRegOfLoad( 1095 MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const { 1096 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 1097 1098 Register DstReg = MI.getOperand(0).getReg(); 1099 LLT RegTy = MRI.getType(DstReg); 1100 1101 // Only supports scalars for now. 1102 if (RegTy.isVector()) 1103 return false; 1104 1105 Register SrcReg = MI.getOperand(1).getReg(); 1106 auto *LoadDef = getOpcodeDef<GLoad>(SrcReg, MRI); 1107 if (!LoadDef || !MRI.hasOneNonDBGUse(SrcReg)) 1108 return false; 1109 1110 uint64_t MemBits = LoadDef->getMemSizeInBits().getValue(); 1111 1112 // If the sign extend extends from a narrower width than the load's width, 1113 // then we can narrow the load width when we combine to a G_SEXTLOAD. 
  // Avoid widening the load at all.
  unsigned NewSizeBits = std::min((uint64_t)MI.getOperand(2).getImm(), MemBits);

  // Don't generate G_SEXTLOADs with a < 1 byte width.
  if (NewSizeBits < 8)
    return false;
  // Don't bother creating a non-power-of-2 sextload; it will likely be broken
  // up anyway for most targets.
  if (!isPowerOf2_32(NewSizeBits))
    return false;

  const MachineMemOperand &MMO = LoadDef->getMMO();
  LegalityQuery::MemDesc MMDesc(MMO);

  // Don't modify the memory access size if this is atomic/volatile, but we can
  // still adjust the opcode to indicate the high bit behavior.
  if (LoadDef->isSimple())
    MMDesc.MemoryTy = LLT::scalar(NewSizeBits);
  else if (MemBits > NewSizeBits || MemBits == RegTy.getSizeInBits())
    return false;

  // TODO: Could check if it's legal with the reduced or original memory size.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SEXTLOAD,
                                 {MRI.getType(LoadDef->getDstReg()),
                                  MRI.getType(LoadDef->getPointerReg())},
                                 {MMDesc}}))
    return false;

  MatchInfo = std::make_tuple(LoadDef->getDstReg(), NewSizeBits);
  return true;
}

void CombinerHelper::applySextInRegOfLoad(
    MachineInstr &MI, std::tuple<Register, unsigned> &MatchInfo) const {
  assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG);
  Register LoadReg;
  unsigned ScalarSizeBits;
  std::tie(LoadReg, ScalarSizeBits) = MatchInfo;
  GLoad *LoadDef = cast<GLoad>(MRI.getVRegDef(LoadReg));

  // If we have the following:
  // %ld = G_LOAD %ptr, (load 2)
  // %ext = G_SEXT_INREG %ld, 8
  // ==>
  // %ld = G_SEXTLOAD %ptr (load 1)

  auto &MMO = LoadDef->getMMO();
  Builder.setInstrAndDebugLoc(*LoadDef);
  auto &MF = Builder.getMF();
  auto PtrInfo = MMO.getPointerInfo();
  auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, ScalarSizeBits / 8);
  Builder.buildLoadInstr(TargetOpcode::G_SEXTLOAD, MI.getOperand(0).getReg(),
                         LoadDef->getPointerReg(), *NewMMO);
  MI.eraseFromParent();

  // Not all loads can be deleted, so make sure the old one is removed.
  LoadDef->eraseFromParent();
}

/// Return true if 'MI' is a load or a store that may fold its address
/// operand into the load / store addressing mode.
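/// E.g. (illustrative): a G_PTR_ADD with a constant offset is queried as a
/// [reg +/- imm] addressing mode, while a variable offset is queried as
/// [reg +/- reg]; whether either form is cheap is left to the target's
/// isLegalAddressingMode().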
1175 static bool canFoldInAddressingMode(GLoadStore *MI, const TargetLowering &TLI, 1176 MachineRegisterInfo &MRI) { 1177 TargetLowering::AddrMode AM; 1178 auto *MF = MI->getMF(); 1179 auto *Addr = getOpcodeDef<GPtrAdd>(MI->getPointerReg(), MRI); 1180 if (!Addr) 1181 return false; 1182 1183 AM.HasBaseReg = true; 1184 if (auto CstOff = getIConstantVRegVal(Addr->getOffsetReg(), MRI)) 1185 AM.BaseOffs = CstOff->getSExtValue(); // [reg +/- imm] 1186 else 1187 AM.Scale = 1; // [reg +/- reg] 1188 1189 return TLI.isLegalAddressingMode( 1190 MF->getDataLayout(), AM, 1191 getTypeForLLT(MI->getMMO().getMemoryType(), 1192 MF->getFunction().getContext()), 1193 MI->getMMO().getAddrSpace()); 1194 } 1195 1196 static unsigned getIndexedOpc(unsigned LdStOpc) { 1197 switch (LdStOpc) { 1198 case TargetOpcode::G_LOAD: 1199 return TargetOpcode::G_INDEXED_LOAD; 1200 case TargetOpcode::G_STORE: 1201 return TargetOpcode::G_INDEXED_STORE; 1202 case TargetOpcode::G_ZEXTLOAD: 1203 return TargetOpcode::G_INDEXED_ZEXTLOAD; 1204 case TargetOpcode::G_SEXTLOAD: 1205 return TargetOpcode::G_INDEXED_SEXTLOAD; 1206 default: 1207 llvm_unreachable("Unexpected opcode"); 1208 } 1209 } 1210 1211 bool CombinerHelper::isIndexedLoadStoreLegal(GLoadStore &LdSt) const { 1212 // Check for legality. 1213 LLT PtrTy = MRI.getType(LdSt.getPointerReg()); 1214 LLT Ty = MRI.getType(LdSt.getReg(0)); 1215 LLT MemTy = LdSt.getMMO().getMemoryType(); 1216 SmallVector<LegalityQuery::MemDesc, 2> MemDescrs( 1217 {{MemTy, MemTy.getSizeInBits().getKnownMinValue(), 1218 AtomicOrdering::NotAtomic}}); 1219 unsigned IndexedOpc = getIndexedOpc(LdSt.getOpcode()); 1220 SmallVector<LLT> OpTys; 1221 if (IndexedOpc == TargetOpcode::G_INDEXED_STORE) 1222 OpTys = {PtrTy, Ty, Ty}; 1223 else 1224 OpTys = {Ty, PtrTy}; // For G_INDEXED_LOAD, G_INDEXED_[SZ]EXTLOAD 1225 1226 LegalityQuery Q(IndexedOpc, OpTys, MemDescrs); 1227 return isLegal(Q); 1228 } 1229 1230 static cl::opt<unsigned> PostIndexUseThreshold( 1231 "post-index-use-threshold", cl::Hidden, cl::init(32), 1232 cl::desc("Number of uses of a base pointer to check before it is no longer " 1233 "considered for post-indexing.")); 1234 1235 bool CombinerHelper::findPostIndexCandidate(GLoadStore &LdSt, Register &Addr, 1236 Register &Base, Register &Offset, 1237 bool &RematOffset) const { 1238 // We're looking for the following pattern, for either load or store: 1239 // %baseptr:_(p0) = ... 1240 // G_STORE %val(s64), %baseptr(p0) 1241 // %offset:_(s64) = G_CONSTANT i64 -256 1242 // %new_addr:_(p0) = G_PTR_ADD %baseptr, %offset(s64) 1243 const auto &TLI = getTargetLowering(); 1244 1245 Register Ptr = LdSt.getPointerReg(); 1246 // If the store is the only use, don't bother. 1247 if (MRI.hasOneNonDBGUse(Ptr)) 1248 return false; 1249 1250 if (!isIndexedLoadStoreLegal(LdSt)) 1251 return false; 1252 1253 if (getOpcodeDef(TargetOpcode::G_FRAME_INDEX, Ptr, MRI)) 1254 return false; 1255 1256 MachineInstr *StoredValDef = getDefIgnoringCopies(LdSt.getReg(0), MRI); 1257 auto *PtrDef = MRI.getVRegDef(Ptr); 1258 1259 unsigned NumUsesChecked = 0; 1260 for (auto &Use : MRI.use_nodbg_instructions(Ptr)) { 1261 if (++NumUsesChecked > PostIndexUseThreshold) 1262 return false; // Try to avoid exploding compile time. 1263 1264 auto *PtrAdd = dyn_cast<GPtrAdd>(&Use); 1265 // The use itself might be dead. This can happen during combines if DCE 1266 // hasn't had a chance to run yet. Don't allow it to form an indexed op. 
    if (!PtrAdd || MRI.use_nodbg_empty(PtrAdd->getReg(0)))
      continue;

    // Check that the user of this isn't the store; otherwise we'd be
    // generating an indexed store defining its own use.
    if (StoredValDef == &Use)
      continue;

    Offset = PtrAdd->getOffsetReg();
    if (!ForceLegalIndexing &&
        !TLI.isIndexingLegal(LdSt, PtrAdd->getBaseReg(), Offset,
                             /*IsPre*/ false, MRI))
      continue;

    // Make sure the offset calculation is before the potentially indexed op.
    MachineInstr *OffsetDef = MRI.getVRegDef(Offset);
    RematOffset = false;
    if (!dominates(*OffsetDef, LdSt)) {
      // If the offset however is just a G_CONSTANT, we can always just
      // rematerialize it where we need it.
      if (OffsetDef->getOpcode() != TargetOpcode::G_CONSTANT)
        continue;
      RematOffset = true;
    }

    for (auto &BasePtrUse : MRI.use_nodbg_instructions(PtrAdd->getBaseReg())) {
      if (&BasePtrUse == PtrDef)
        continue;

      // If the user is a later load/store that can be post-indexed, then don't
      // combine this one.
      auto *BasePtrLdSt = dyn_cast<GLoadStore>(&BasePtrUse);
      if (BasePtrLdSt && BasePtrLdSt != &LdSt &&
          dominates(LdSt, *BasePtrLdSt) &&
          isIndexedLoadStoreLegal(*BasePtrLdSt))
        return false;

      // Now we're looking for the key G_PTR_ADD instruction, which contains
      // the offset add that we want to fold.
      if (auto *BasePtrUseDef = dyn_cast<GPtrAdd>(&BasePtrUse)) {
        Register PtrAddDefReg = BasePtrUseDef->getReg(0);
        for (auto &BaseUseUse : MRI.use_nodbg_instructions(PtrAddDefReg)) {
          // If the use is in a different block, then we may produce worse code
          // due to the extra register pressure.
          if (BaseUseUse.getParent() != LdSt.getParent())
            return false;

          if (auto *UseUseLdSt = dyn_cast<GLoadStore>(&BaseUseUse))
            if (canFoldInAddressingMode(UseUseLdSt, TLI, MRI))
              return false;
        }
        if (!dominates(LdSt, BasePtrUse))
          return false; // All uses must be dominated by the load/store.
      }
    }

    Addr = PtrAdd->getReg(0);
    Base = PtrAdd->getBaseReg();
    return true;
  }

  return false;
}

bool CombinerHelper::findPreIndexCandidate(GLoadStore &LdSt, Register &Addr,
                                           Register &Base,
                                           Register &Offset) const {
  auto &MF = *LdSt.getParent()->getParent();
  const auto &TLI = *MF.getSubtarget().getTargetLowering();

  Addr = LdSt.getPointerReg();
  if (!mi_match(Addr, MRI, m_GPtrAdd(m_Reg(Base), m_Reg(Offset))) ||
      MRI.hasOneNonDBGUse(Addr))
    return false;

  if (!ForceLegalIndexing &&
      !TLI.isIndexingLegal(LdSt, Base, Offset, /*IsPre*/ true, MRI))
    return false;

  if (!isIndexedLoadStoreLegal(LdSt))
    return false;

  MachineInstr *BaseDef = getDefIgnoringCopies(Base, MRI);
  if (BaseDef->getOpcode() == TargetOpcode::G_FRAME_INDEX)
    return false;

  if (auto *St = dyn_cast<GStore>(&LdSt)) {
    // Would require a copy.
    if (Base == St->getValueReg())
      return false;

    // We're expecting one use of Addr in MI, but it could also be the
    // value stored, which isn't actually dominated by the instruction.
    if (St->getValueReg() == Addr)
      return false;
  }

  // Avoid increasing cross-block register pressure.
1365 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) 1366 if (AddrUse.getParent() != LdSt.getParent()) 1367 return false; 1368 1369 // FIXME: check whether all uses of the base pointer are constant PtrAdds. 1370 // That might allow us to end base's liveness here by adjusting the constant. 1371 bool RealUse = false; 1372 for (auto &AddrUse : MRI.use_nodbg_instructions(Addr)) { 1373 if (!dominates(LdSt, AddrUse)) 1374 return false; // All use must be dominated by the load/store. 1375 1376 // If Ptr may be folded in addressing mode of other use, then it's 1377 // not profitable to do this transformation. 1378 if (auto *UseLdSt = dyn_cast<GLoadStore>(&AddrUse)) { 1379 if (!canFoldInAddressingMode(UseLdSt, TLI, MRI)) 1380 RealUse = true; 1381 } else { 1382 RealUse = true; 1383 } 1384 } 1385 return RealUse; 1386 } 1387 1388 bool CombinerHelper::matchCombineExtractedVectorLoad( 1389 MachineInstr &MI, BuildFnTy &MatchInfo) const { 1390 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); 1391 1392 // Check if there is a load that defines the vector being extracted from. 1393 auto *LoadMI = getOpcodeDef<GLoad>(MI.getOperand(1).getReg(), MRI); 1394 if (!LoadMI) 1395 return false; 1396 1397 Register Vector = MI.getOperand(1).getReg(); 1398 LLT VecEltTy = MRI.getType(Vector).getElementType(); 1399 1400 assert(MRI.getType(MI.getOperand(0).getReg()) == VecEltTy); 1401 1402 // Checking whether we should reduce the load width. 1403 if (!MRI.hasOneNonDBGUse(Vector)) 1404 return false; 1405 1406 // Check if the defining load is simple. 1407 if (!LoadMI->isSimple()) 1408 return false; 1409 1410 // If the vector element type is not a multiple of a byte then we are unable 1411 // to correctly compute an address to load only the extracted element as a 1412 // scalar. 1413 if (!VecEltTy.isByteSized()) 1414 return false; 1415 1416 // Check for load fold barriers between the extraction and the load. 1417 if (MI.getParent() != LoadMI->getParent()) 1418 return false; 1419 const unsigned MaxIter = 20; 1420 unsigned Iter = 0; 1421 for (auto II = LoadMI->getIterator(), IE = MI.getIterator(); II != IE; ++II) { 1422 if (II->isLoadFoldBarrier()) 1423 return false; 1424 if (Iter++ == MaxIter) 1425 return false; 1426 } 1427 1428 // Check if the new load that we are going to create is legal 1429 // if we are in the post-legalization phase. 1430 MachineMemOperand MMO = LoadMI->getMMO(); 1431 Align Alignment = MMO.getAlign(); 1432 MachinePointerInfo PtrInfo; 1433 uint64_t Offset; 1434 1435 // Finding the appropriate PtrInfo if offset is a known constant. 1436 // This is required to create the memory operand for the narrowed load. 1437 // This machine memory operand object helps us infer about legality 1438 // before we proceed to combine the instruction. 1439 if (auto CVal = getIConstantVRegVal(Vector, MRI)) { 1440 int Elt = CVal->getZExtValue(); 1441 // FIXME: should be (ABI size)*Elt. 1442 Offset = VecEltTy.getSizeInBits() * Elt / 8; 1443 PtrInfo = MMO.getPointerInfo().getWithOffset(Offset); 1444 } else { 1445 // Discard the pointer info except the address space because the memory 1446 // operand can't represent this new access since the offset is variable. 
1447 Offset = VecEltTy.getSizeInBits() / 8; 1448 PtrInfo = MachinePointerInfo(MMO.getPointerInfo().getAddrSpace()); 1449 } 1450 1451 Alignment = commonAlignment(Alignment, Offset); 1452 1453 Register VecPtr = LoadMI->getPointerReg(); 1454 LLT PtrTy = MRI.getType(VecPtr); 1455 1456 MachineFunction &MF = *MI.getMF(); 1457 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, VecEltTy); 1458 1459 LegalityQuery::MemDesc MMDesc(*NewMMO); 1460 1461 if (!isLegalOrBeforeLegalizer( 1462 {TargetOpcode::G_LOAD, {VecEltTy, PtrTy}, {MMDesc}})) 1463 return false; 1464 1465 // Load must be allowed and fast on the target. 1466 LLVMContext &C = MF.getFunction().getContext(); 1467 auto &DL = MF.getDataLayout(); 1468 unsigned Fast = 0; 1469 if (!getTargetLowering().allowsMemoryAccess(C, DL, VecEltTy, *NewMMO, 1470 &Fast) || 1471 !Fast) 1472 return false; 1473 1474 Register Result = MI.getOperand(0).getReg(); 1475 Register Index = MI.getOperand(2).getReg(); 1476 1477 MatchInfo = [=](MachineIRBuilder &B) { 1478 GISelObserverWrapper DummyObserver; 1479 LegalizerHelper Helper(B.getMF(), DummyObserver, B); 1480 //// Get pointer to the vector element. 1481 Register finalPtr = Helper.getVectorElementPointer( 1482 LoadMI->getPointerReg(), MRI.getType(LoadMI->getOperand(0).getReg()), 1483 Index); 1484 // New G_LOAD instruction. 1485 B.buildLoad(Result, finalPtr, PtrInfo, Alignment); 1486 // Remove original GLOAD instruction. 1487 LoadMI->eraseFromParent(); 1488 }; 1489 1490 return true; 1491 } 1492 1493 bool CombinerHelper::matchCombineIndexedLoadStore( 1494 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { 1495 auto &LdSt = cast<GLoadStore>(MI); 1496 1497 if (LdSt.isAtomic()) 1498 return false; 1499 1500 MatchInfo.IsPre = findPreIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base, 1501 MatchInfo.Offset); 1502 if (!MatchInfo.IsPre && 1503 !findPostIndexCandidate(LdSt, MatchInfo.Addr, MatchInfo.Base, 1504 MatchInfo.Offset, MatchInfo.RematOffset)) 1505 return false; 1506 1507 return true; 1508 } 1509 1510 void CombinerHelper::applyCombineIndexedLoadStore( 1511 MachineInstr &MI, IndexedLoadStoreMatchInfo &MatchInfo) const { 1512 MachineInstr &AddrDef = *MRI.getUniqueVRegDef(MatchInfo.Addr); 1513 unsigned Opcode = MI.getOpcode(); 1514 bool IsStore = Opcode == TargetOpcode::G_STORE; 1515 unsigned NewOpcode = getIndexedOpc(Opcode); 1516 1517 // If the offset constant didn't happen to dominate the load/store, we can 1518 // just clone it as needed. 
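  // Illustrative sketch (hypothetical MIR, not from the original source) of
  // what is built below:
  //   %writeback:_(p0) = G_INDEXED_STORE %val, %base, %offset, 0   (post-indexed store)
  //   %dst:_(s32), %writeback:_(p0) = G_INDEXED_LOAD %base, %offset, 1   (pre-indexed load)
  // where the trailing immediate is 1 for pre-indexing and 0 for
  // post-indexing.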
  if (MatchInfo.RematOffset) {
    auto *OldCst = MRI.getVRegDef(MatchInfo.Offset);
    auto NewCst = Builder.buildConstant(MRI.getType(MatchInfo.Offset),
                                        *OldCst->getOperand(1).getCImm());
    MatchInfo.Offset = NewCst.getReg(0);
  }

  auto MIB = Builder.buildInstr(NewOpcode);
  if (IsStore) {
    MIB.addDef(MatchInfo.Addr);
    MIB.addUse(MI.getOperand(0).getReg());
  } else {
    MIB.addDef(MI.getOperand(0).getReg());
    MIB.addDef(MatchInfo.Addr);
  }

  MIB.addUse(MatchInfo.Base);
  MIB.addUse(MatchInfo.Offset);
  MIB.addImm(MatchInfo.IsPre);
  MIB->cloneMemRefs(*MI.getMF(), MI);
  MI.eraseFromParent();
  AddrDef.eraseFromParent();

  LLVM_DEBUG(dbgs() << " Combined to indexed operation");
}

bool CombinerHelper::matchCombineDivRem(MachineInstr &MI,
                                        MachineInstr *&OtherMI) const {
  unsigned Opcode = MI.getOpcode();
  bool IsDiv, IsSigned;

  switch (Opcode) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case TargetOpcode::G_SDIV:
  case TargetOpcode::G_UDIV: {
    IsDiv = true;
    IsSigned = Opcode == TargetOpcode::G_SDIV;
    break;
  }
  case TargetOpcode::G_SREM:
  case TargetOpcode::G_UREM: {
    IsDiv = false;
    IsSigned = Opcode == TargetOpcode::G_SREM;
    break;
  }
  }

  Register Src1 = MI.getOperand(1).getReg();
  unsigned DivOpcode, RemOpcode, DivremOpcode;
  if (IsSigned) {
    DivOpcode = TargetOpcode::G_SDIV;
    RemOpcode = TargetOpcode::G_SREM;
    DivremOpcode = TargetOpcode::G_SDIVREM;
  } else {
    DivOpcode = TargetOpcode::G_UDIV;
    RemOpcode = TargetOpcode::G_UREM;
    DivremOpcode = TargetOpcode::G_UDIVREM;
  }

  if (!isLegalOrBeforeLegalizer({DivremOpcode, {MRI.getType(Src1)}}))
    return false;

  // Combine:
  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
  // into:
  //   %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_

  // Combine:
  //   %rem:_ = G_[SU]REM %src1:_, %src2:_
  //   %div:_ = G_[SU]DIV %src1:_, %src2:_
  // into:
  //   %div:_, %rem:_ = G_[SU]DIVREM %src1:_, %src2:_

  for (auto &UseMI : MRI.use_nodbg_instructions(Src1)) {
    if (MI.getParent() == UseMI.getParent() &&
        ((IsDiv && UseMI.getOpcode() == RemOpcode) ||
         (!IsDiv && UseMI.getOpcode() == DivOpcode)) &&
        matchEqualDefs(MI.getOperand(2), UseMI.getOperand(2)) &&
        matchEqualDefs(MI.getOperand(1), UseMI.getOperand(1))) {
      OtherMI = &UseMI;
      return true;
    }
  }

  return false;
}

void CombinerHelper::applyCombineDivRem(MachineInstr &MI,
                                        MachineInstr *&OtherMI) const {
  unsigned Opcode = MI.getOpcode();
  assert(OtherMI && "OtherMI shouldn't be empty.");

  Register DestDivReg, DestRemReg;
  if (Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_UDIV) {
    DestDivReg = MI.getOperand(0).getReg();
    DestRemReg = OtherMI->getOperand(0).getReg();
  } else {
    DestDivReg = OtherMI->getOperand(0).getReg();
    DestRemReg = MI.getOperand(0).getReg();
  }

  bool IsSigned =
      Opcode == TargetOpcode::G_SDIV || Opcode == TargetOpcode::G_SREM;

  // Check which instruction is first in the block so we don't break def-use
  // deps by "moving" the instruction incorrectly. Also keep track of which
  // instruction is first so we pick its operands, avoiding use-before-def
  // bugs.
1629 MachineInstr *FirstInst = dominates(MI, *OtherMI) ? &MI : OtherMI; 1630 Builder.setInstrAndDebugLoc(*FirstInst); 1631 1632 Builder.buildInstr(IsSigned ? TargetOpcode::G_SDIVREM 1633 : TargetOpcode::G_UDIVREM, 1634 {DestDivReg, DestRemReg}, 1635 { FirstInst->getOperand(1), FirstInst->getOperand(2) }); 1636 MI.eraseFromParent(); 1637 OtherMI->eraseFromParent(); 1638 } 1639 1640 bool CombinerHelper::matchOptBrCondByInvertingCond( 1641 MachineInstr &MI, MachineInstr *&BrCond) const { 1642 assert(MI.getOpcode() == TargetOpcode::G_BR); 1643 1644 // Try to match the following: 1645 // bb1: 1646 // G_BRCOND %c1, %bb2 1647 // G_BR %bb3 1648 // bb2: 1649 // ... 1650 // bb3: 1651 1652 // The above pattern does not have a fall through to the successor bb2, always 1653 // resulting in a branch no matter which path is taken. Here we try to find 1654 // and replace that pattern with conditional branch to bb3 and otherwise 1655 // fallthrough to bb2. This is generally better for branch predictors. 1656 1657 MachineBasicBlock *MBB = MI.getParent(); 1658 MachineBasicBlock::iterator BrIt(MI); 1659 if (BrIt == MBB->begin()) 1660 return false; 1661 assert(std::next(BrIt) == MBB->end() && "expected G_BR to be a terminator"); 1662 1663 BrCond = &*std::prev(BrIt); 1664 if (BrCond->getOpcode() != TargetOpcode::G_BRCOND) 1665 return false; 1666 1667 // Check that the next block is the conditional branch target. Also make sure 1668 // that it isn't the same as the G_BR's target (otherwise, this will loop.) 1669 MachineBasicBlock *BrCondTarget = BrCond->getOperand(1).getMBB(); 1670 return BrCondTarget != MI.getOperand(0).getMBB() && 1671 MBB->isLayoutSuccessor(BrCondTarget); 1672 } 1673 1674 void CombinerHelper::applyOptBrCondByInvertingCond( 1675 MachineInstr &MI, MachineInstr *&BrCond) const { 1676 MachineBasicBlock *BrTarget = MI.getOperand(0).getMBB(); 1677 Builder.setInstrAndDebugLoc(*BrCond); 1678 LLT Ty = MRI.getType(BrCond->getOperand(0).getReg()); 1679 // FIXME: Does int/fp matter for this? If so, we might need to restrict 1680 // this to i1 only since we might not know for sure what kind of 1681 // compare generated the condition value. 1682 auto True = Builder.buildConstant( 1683 Ty, getICmpTrueVal(getTargetLowering(), false, false)); 1684 auto Xor = Builder.buildXor(Ty, BrCond->getOperand(0), True); 1685 1686 auto *FallthroughBB = BrCond->getOperand(1).getMBB(); 1687 Observer.changingInstr(MI); 1688 MI.getOperand(0).setMBB(FallthroughBB); 1689 Observer.changedInstr(MI); 1690 1691 // Change the conditional branch to use the inverted condition and 1692 // new target block. 
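  // For illustration (virtual-register and block names are invented), the
  // example from the match function ends up as roughly:
  //   bb1:
  //     %inv = G_XOR %c1, true   ; "true" is typically the constant 1
  //     G_BRCOND %inv, %bb3
  //     G_BR %bb2
  //   bb2:
  //     ...
  // so that the unconditional branch to the layout successor bb2 can later be
  // turned into a fallthrough.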
1693 Observer.changingInstr(*BrCond); 1694 BrCond->getOperand(0).setReg(Xor.getReg(0)); 1695 BrCond->getOperand(1).setMBB(BrTarget); 1696 Observer.changedInstr(*BrCond); 1697 } 1698 1699 bool CombinerHelper::tryEmitMemcpyInline(MachineInstr &MI) const { 1700 MachineIRBuilder HelperBuilder(MI); 1701 GISelObserverWrapper DummyObserver; 1702 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); 1703 return Helper.lowerMemcpyInline(MI) == 1704 LegalizerHelper::LegalizeResult::Legalized; 1705 } 1706 1707 bool CombinerHelper::tryCombineMemCpyFamily(MachineInstr &MI, 1708 unsigned MaxLen) const { 1709 MachineIRBuilder HelperBuilder(MI); 1710 GISelObserverWrapper DummyObserver; 1711 LegalizerHelper Helper(HelperBuilder.getMF(), DummyObserver, HelperBuilder); 1712 return Helper.lowerMemCpyFamily(MI, MaxLen) == 1713 LegalizerHelper::LegalizeResult::Legalized; 1714 } 1715 1716 static APFloat constantFoldFpUnary(const MachineInstr &MI, 1717 const MachineRegisterInfo &MRI, 1718 const APFloat &Val) { 1719 APFloat Result(Val); 1720 switch (MI.getOpcode()) { 1721 default: 1722 llvm_unreachable("Unexpected opcode!"); 1723 case TargetOpcode::G_FNEG: { 1724 Result.changeSign(); 1725 return Result; 1726 } 1727 case TargetOpcode::G_FABS: { 1728 Result.clearSign(); 1729 return Result; 1730 } 1731 case TargetOpcode::G_FPTRUNC: { 1732 bool Unused; 1733 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 1734 Result.convert(getFltSemanticForLLT(DstTy), APFloat::rmNearestTiesToEven, 1735 &Unused); 1736 return Result; 1737 } 1738 case TargetOpcode::G_FSQRT: { 1739 bool Unused; 1740 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 1741 &Unused); 1742 Result = APFloat(sqrt(Result.convertToDouble())); 1743 break; 1744 } 1745 case TargetOpcode::G_FLOG2: { 1746 bool Unused; 1747 Result.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, 1748 &Unused); 1749 Result = APFloat(log2(Result.convertToDouble())); 1750 break; 1751 } 1752 } 1753 // Convert `APFloat` to appropriate IEEE type depending on `DstTy`. Otherwise, 1754 // `buildFConstant` will assert on size mismatch. Only `G_FSQRT`, and 1755 // `G_FLOG2` reach here. 
1756 bool Unused; 1757 Result.convert(Val.getSemantics(), APFloat::rmNearestTiesToEven, &Unused); 1758 return Result; 1759 } 1760 1761 void CombinerHelper::applyCombineConstantFoldFpUnary( 1762 MachineInstr &MI, const ConstantFP *Cst) const { 1763 APFloat Folded = constantFoldFpUnary(MI, MRI, Cst->getValue()); 1764 const ConstantFP *NewCst = ConstantFP::get(Builder.getContext(), Folded); 1765 Builder.buildFConstant(MI.getOperand(0), *NewCst); 1766 MI.eraseFromParent(); 1767 } 1768 1769 bool CombinerHelper::matchPtrAddImmedChain(MachineInstr &MI, 1770 PtrAddChain &MatchInfo) const { 1771 // We're trying to match the following pattern: 1772 // %t1 = G_PTR_ADD %base, G_CONSTANT imm1 1773 // %root = G_PTR_ADD %t1, G_CONSTANT imm2 1774 // --> 1775 // %root = G_PTR_ADD %base, G_CONSTANT (imm1 + imm2) 1776 1777 if (MI.getOpcode() != TargetOpcode::G_PTR_ADD) 1778 return false; 1779 1780 Register Add2 = MI.getOperand(1).getReg(); 1781 Register Imm1 = MI.getOperand(2).getReg(); 1782 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI); 1783 if (!MaybeImmVal) 1784 return false; 1785 1786 MachineInstr *Add2Def = MRI.getVRegDef(Add2); 1787 if (!Add2Def || Add2Def->getOpcode() != TargetOpcode::G_PTR_ADD) 1788 return false; 1789 1790 Register Base = Add2Def->getOperand(1).getReg(); 1791 Register Imm2 = Add2Def->getOperand(2).getReg(); 1792 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI); 1793 if (!MaybeImm2Val) 1794 return false; 1795 1796 // Check if the new combined immediate forms an illegal addressing mode. 1797 // Do not combine if it was legal before but would get illegal. 1798 // To do so, we need to find a load/store user of the pointer to get 1799 // the access type. 1800 Type *AccessTy = nullptr; 1801 auto &MF = *MI.getMF(); 1802 for (auto &UseMI : MRI.use_nodbg_instructions(MI.getOperand(0).getReg())) { 1803 if (auto *LdSt = dyn_cast<GLoadStore>(&UseMI)) { 1804 AccessTy = getTypeForLLT(MRI.getType(LdSt->getReg(0)), 1805 MF.getFunction().getContext()); 1806 break; 1807 } 1808 } 1809 TargetLoweringBase::AddrMode AMNew; 1810 APInt CombinedImm = MaybeImmVal->Value + MaybeImm2Val->Value; 1811 AMNew.BaseOffs = CombinedImm.getSExtValue(); 1812 if (AccessTy) { 1813 AMNew.HasBaseReg = true; 1814 TargetLoweringBase::AddrMode AMOld; 1815 AMOld.BaseOffs = MaybeImmVal->Value.getSExtValue(); 1816 AMOld.HasBaseReg = true; 1817 unsigned AS = MRI.getType(Add2).getAddressSpace(); 1818 const auto &TLI = *MF.getSubtarget().getTargetLowering(); 1819 if (TLI.isLegalAddressingMode(MF.getDataLayout(), AMOld, AccessTy, AS) && 1820 !TLI.isLegalAddressingMode(MF.getDataLayout(), AMNew, AccessTy, AS)) 1821 return false; 1822 } 1823 1824 // Pass the combined immediate to the apply function. 
1825 MatchInfo.Imm = AMNew.BaseOffs; 1826 MatchInfo.Base = Base; 1827 MatchInfo.Bank = getRegBank(Imm2); 1828 return true; 1829 } 1830 1831 void CombinerHelper::applyPtrAddImmedChain(MachineInstr &MI, 1832 PtrAddChain &MatchInfo) const { 1833 assert(MI.getOpcode() == TargetOpcode::G_PTR_ADD && "Expected G_PTR_ADD"); 1834 MachineIRBuilder MIB(MI); 1835 LLT OffsetTy = MRI.getType(MI.getOperand(2).getReg()); 1836 auto NewOffset = MIB.buildConstant(OffsetTy, MatchInfo.Imm); 1837 setRegBank(NewOffset.getReg(0), MatchInfo.Bank); 1838 Observer.changingInstr(MI); 1839 MI.getOperand(1).setReg(MatchInfo.Base); 1840 MI.getOperand(2).setReg(NewOffset.getReg(0)); 1841 Observer.changedInstr(MI); 1842 } 1843 1844 bool CombinerHelper::matchShiftImmedChain(MachineInstr &MI, 1845 RegisterImmPair &MatchInfo) const { 1846 // We're trying to match the following pattern with any of 1847 // G_SHL/G_ASHR/G_LSHR/G_SSHLSAT/G_USHLSAT shift instructions: 1848 // %t1 = SHIFT %base, G_CONSTANT imm1 1849 // %root = SHIFT %t1, G_CONSTANT imm2 1850 // --> 1851 // %root = SHIFT %base, G_CONSTANT (imm1 + imm2) 1852 1853 unsigned Opcode = MI.getOpcode(); 1854 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || 1855 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || 1856 Opcode == TargetOpcode::G_USHLSAT) && 1857 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT"); 1858 1859 Register Shl2 = MI.getOperand(1).getReg(); 1860 Register Imm1 = MI.getOperand(2).getReg(); 1861 auto MaybeImmVal = getIConstantVRegValWithLookThrough(Imm1, MRI); 1862 if (!MaybeImmVal) 1863 return false; 1864 1865 MachineInstr *Shl2Def = MRI.getUniqueVRegDef(Shl2); 1866 if (Shl2Def->getOpcode() != Opcode) 1867 return false; 1868 1869 Register Base = Shl2Def->getOperand(1).getReg(); 1870 Register Imm2 = Shl2Def->getOperand(2).getReg(); 1871 auto MaybeImm2Val = getIConstantVRegValWithLookThrough(Imm2, MRI); 1872 if (!MaybeImm2Val) 1873 return false; 1874 1875 // Pass the combined immediate to the apply function. 1876 MatchInfo.Imm = 1877 (MaybeImmVal->Value.getZExtValue() + MaybeImm2Val->Value).getZExtValue(); 1878 MatchInfo.Reg = Base; 1879 1880 // There is no simple replacement for a saturating unsigned left shift that 1881 // exceeds the scalar size. 1882 if (Opcode == TargetOpcode::G_USHLSAT && 1883 MatchInfo.Imm >= MRI.getType(Shl2).getScalarSizeInBits()) 1884 return false; 1885 1886 return true; 1887 } 1888 1889 void CombinerHelper::applyShiftImmedChain(MachineInstr &MI, 1890 RegisterImmPair &MatchInfo) const { 1891 unsigned Opcode = MI.getOpcode(); 1892 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR || 1893 Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_SSHLSAT || 1894 Opcode == TargetOpcode::G_USHLSAT) && 1895 "Expected G_SHL, G_ASHR, G_LSHR, G_SSHLSAT or G_USHLSAT"); 1896 1897 LLT Ty = MRI.getType(MI.getOperand(1).getReg()); 1898 unsigned const ScalarSizeInBits = Ty.getScalarSizeInBits(); 1899 auto Imm = MatchInfo.Imm; 1900 1901 if (Imm >= ScalarSizeInBits) { 1902 // Any logical shift that exceeds scalar size will produce zero. 1903 if (Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR) { 1904 Builder.buildConstant(MI.getOperand(0), 0); 1905 MI.eraseFromParent(); 1906 return; 1907 } 1908 // Arithmetic shift and saturating signed left shift have no effect beyond 1909 // scalar size. 
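    // For example, on s32 a chain like (ashr (ashr x, 20), 20) reaches this
    // point with Imm == 40 and is clamped so that the result is (ashr x, 31).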
1910 Imm = ScalarSizeInBits - 1; 1911 } 1912 1913 LLT ImmTy = MRI.getType(MI.getOperand(2).getReg()); 1914 Register NewImm = Builder.buildConstant(ImmTy, Imm).getReg(0); 1915 Observer.changingInstr(MI); 1916 MI.getOperand(1).setReg(MatchInfo.Reg); 1917 MI.getOperand(2).setReg(NewImm); 1918 Observer.changedInstr(MI); 1919 } 1920 1921 bool CombinerHelper::matchShiftOfShiftedLogic( 1922 MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const { 1923 // We're trying to match the following pattern with any of 1924 // G_SHL/G_ASHR/G_LSHR/G_USHLSAT/G_SSHLSAT shift instructions in combination 1925 // with any of G_AND/G_OR/G_XOR logic instructions. 1926 // %t1 = SHIFT %X, G_CONSTANT C0 1927 // %t2 = LOGIC %t1, %Y 1928 // %root = SHIFT %t2, G_CONSTANT C1 1929 // --> 1930 // %t3 = SHIFT %X, G_CONSTANT (C0+C1) 1931 // %t4 = SHIFT %Y, G_CONSTANT C1 1932 // %root = LOGIC %t3, %t4 1933 unsigned ShiftOpcode = MI.getOpcode(); 1934 assert((ShiftOpcode == TargetOpcode::G_SHL || 1935 ShiftOpcode == TargetOpcode::G_ASHR || 1936 ShiftOpcode == TargetOpcode::G_LSHR || 1937 ShiftOpcode == TargetOpcode::G_USHLSAT || 1938 ShiftOpcode == TargetOpcode::G_SSHLSAT) && 1939 "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT"); 1940 1941 // Match a one-use bitwise logic op. 1942 Register LogicDest = MI.getOperand(1).getReg(); 1943 if (!MRI.hasOneNonDBGUse(LogicDest)) 1944 return false; 1945 1946 MachineInstr *LogicMI = MRI.getUniqueVRegDef(LogicDest); 1947 unsigned LogicOpcode = LogicMI->getOpcode(); 1948 if (LogicOpcode != TargetOpcode::G_AND && LogicOpcode != TargetOpcode::G_OR && 1949 LogicOpcode != TargetOpcode::G_XOR) 1950 return false; 1951 1952 // Find a matching one-use shift by constant. 1953 const Register C1 = MI.getOperand(2).getReg(); 1954 auto MaybeImmVal = getIConstantVRegValWithLookThrough(C1, MRI); 1955 if (!MaybeImmVal || MaybeImmVal->Value == 0) 1956 return false; 1957 1958 const uint64_t C1Val = MaybeImmVal->Value.getZExtValue(); 1959 1960 auto matchFirstShift = [&](const MachineInstr *MI, uint64_t &ShiftVal) { 1961 // Shift should match previous one and should be a one-use. 1962 if (MI->getOpcode() != ShiftOpcode || 1963 !MRI.hasOneNonDBGUse(MI->getOperand(0).getReg())) 1964 return false; 1965 1966 // Must be a constant. 1967 auto MaybeImmVal = 1968 getIConstantVRegValWithLookThrough(MI->getOperand(2).getReg(), MRI); 1969 if (!MaybeImmVal) 1970 return false; 1971 1972 ShiftVal = MaybeImmVal->Value.getSExtValue(); 1973 return true; 1974 }; 1975 1976 // Logic ops are commutative, so check each operand for a match. 1977 Register LogicMIReg1 = LogicMI->getOperand(1).getReg(); 1978 MachineInstr *LogicMIOp1 = MRI.getUniqueVRegDef(LogicMIReg1); 1979 Register LogicMIReg2 = LogicMI->getOperand(2).getReg(); 1980 MachineInstr *LogicMIOp2 = MRI.getUniqueVRegDef(LogicMIReg2); 1981 uint64_t C0Val; 1982 1983 if (matchFirstShift(LogicMIOp1, C0Val)) { 1984 MatchInfo.LogicNonShiftReg = LogicMIReg2; 1985 MatchInfo.Shift2 = LogicMIOp1; 1986 } else if (matchFirstShift(LogicMIOp2, C0Val)) { 1987 MatchInfo.LogicNonShiftReg = LogicMIReg1; 1988 MatchInfo.Shift2 = LogicMIOp2; 1989 } else 1990 return false; 1991 1992 MatchInfo.ValSum = C0Val + C1Val; 1993 1994 // The fold is not valid if the sum of the shift values exceeds bitwidth. 
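  // (Each original shift amount may be in range even though C0+C1 is not,
  // e.g. two shifts by 20 on an s32 value, and a shift by >= bitwidth is not
  // a valid replacement.)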
1995   if (MatchInfo.ValSum >= MRI.getType(LogicDest).getScalarSizeInBits())
1996     return false;
1997
1998   MatchInfo.Logic = LogicMI;
1999   return true;
2000 }
2001
2002 void CombinerHelper::applyShiftOfShiftedLogic(
2003     MachineInstr &MI, ShiftOfShiftedLogic &MatchInfo) const {
2004   unsigned Opcode = MI.getOpcode();
2005   assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_ASHR ||
2006           Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_USHLSAT ||
2007           Opcode == TargetOpcode::G_SSHLSAT) &&
2008          "Expected G_SHL, G_ASHR, G_LSHR, G_USHLSAT and G_SSHLSAT");
2009
2010   LLT ShlType = MRI.getType(MI.getOperand(2).getReg());
2011   LLT DestType = MRI.getType(MI.getOperand(0).getReg());
2012
2013   Register Const = Builder.buildConstant(ShlType, MatchInfo.ValSum).getReg(0);
2014
2015   Register Shift1Base = MatchInfo.Shift2->getOperand(1).getReg();
2016   Register Shift1 =
2017       Builder.buildInstr(Opcode, {DestType}, {Shift1Base, Const}).getReg(0);
2018
2019   // If LogicNonShiftReg is the same as Shift1Base and the constant of shift1
2020   // is the same as the constant of MatchInfo.Shift2, CSEMIRBuilder will reuse
2021   // the old shift1 when building shift2. Erasing MatchInfo.Shift2 at the end
2022   // would then actually erase the new shift1 and cause a crash later on, so
2023   // erase it here instead.
2024   MatchInfo.Shift2->eraseFromParent();
2025
2026   Register Shift2Const = MI.getOperand(2).getReg();
2027   Register Shift2 = Builder
2028                         .buildInstr(Opcode, {DestType},
2029                                     {MatchInfo.LogicNonShiftReg, Shift2Const})
2030                         .getReg(0);
2031
2032   Register Dest = MI.getOperand(0).getReg();
2033   Builder.buildInstr(MatchInfo.Logic->getOpcode(), {Dest}, {Shift1, Shift2});
2034
2035   // This was one use so it's safe to remove it.
2036   MatchInfo.Logic->eraseFromParent();
2037
2038   MI.eraseFromParent();
2039 }
2040
2041 bool CombinerHelper::matchCommuteShift(MachineInstr &MI,
2042                                        BuildFnTy &MatchInfo) const {
2043   assert(MI.getOpcode() == TargetOpcode::G_SHL && "Expected G_SHL");
2044   // Combine (shl (add x, c1), c2) -> (add (shl x, c2), c1 << c2)
2045   // Combine (shl (or x, c1), c2) -> (or (shl x, c2), c1 << c2)
2046   auto &Shl = cast<GenericMachineInstr>(MI);
2047   Register DstReg = Shl.getReg(0);
2048   Register SrcReg = Shl.getReg(1);
2049   Register ShiftReg = Shl.getReg(2);
2050   Register X, C1;
2051
2052   if (!getTargetLowering().isDesirableToCommuteWithShift(MI, !isPreLegalize()))
2053     return false;
2054
2055   if (!mi_match(SrcReg, MRI,
2056                 m_OneNonDBGUse(m_any_of(m_GAdd(m_Reg(X), m_Reg(C1)),
2057                                         m_GOr(m_Reg(X), m_Reg(C1))))))
2058     return false;
2059
2060   APInt C1Val, C2Val;
2061   if (!mi_match(C1, MRI, m_ICstOrSplat(C1Val)) ||
2062       !mi_match(ShiftReg, MRI, m_ICstOrSplat(C2Val)))
2063     return false;
2064
2065   auto *SrcDef = MRI.getVRegDef(SrcReg);
2066   assert((SrcDef->getOpcode() == TargetOpcode::G_ADD ||
2067           SrcDef->getOpcode() == TargetOpcode::G_OR) && "Unexpected op");
2068   LLT SrcTy = MRI.getType(SrcReg);
2069   MatchInfo = [=](MachineIRBuilder &B) {
2070     auto S1 = B.buildShl(SrcTy, X, ShiftReg);
2071     auto S2 = B.buildShl(SrcTy, C1, ShiftReg);
2072     B.buildInstr(SrcDef->getOpcode(), {DstReg}, {S1, S2});
2073   };
2074   return true;
2075 }
2076
2077 bool CombinerHelper::matchCombineMulToShl(MachineInstr &MI,
2078                                           unsigned &ShiftVal) const {
2079   assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2080   auto MaybeImmVal =
2081       getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI);
2082   if (!MaybeImmVal)
2083     return false;
2084
2085   ShiftVal = MaybeImmVal->Value.exactLogBase2();
2086   return (static_cast<int32_t>(ShiftVal) != -1);
2087 }
2088
2089 void CombinerHelper::applyCombineMulToShl(MachineInstr &MI,
2090                                           unsigned &ShiftVal) const {
2091   assert(MI.getOpcode() == TargetOpcode::G_MUL && "Expected a G_MUL");
2092   MachineIRBuilder MIB(MI);
2093   LLT ShiftTy = MRI.getType(MI.getOperand(0).getReg());
2094   auto ShiftCst = MIB.buildConstant(ShiftTy, ShiftVal);
2095   Observer.changingInstr(MI);
2096   MI.setDesc(MIB.getTII().get(TargetOpcode::G_SHL));
2097   MI.getOperand(2).setReg(ShiftCst.getReg(0));
2098   if (ShiftVal == ShiftTy.getScalarSizeInBits() - 1)
2099     MI.clearFlag(MachineInstr::MIFlag::NoSWrap);
2100   Observer.changedInstr(MI);
2101 }
2102
2103 bool CombinerHelper::matchCombineSubToAdd(MachineInstr &MI,
2104                                           BuildFnTy &MatchInfo) const {
2105   GSub &Sub = cast<GSub>(MI);
2106
2107   LLT Ty = MRI.getType(Sub.getReg(0));
2108
2109   if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {Ty}}))
2110     return false;
2111
2112   if (!isConstantLegalOrBeforeLegalizer(Ty))
2113     return false;
2114
2115   APInt Imm = getIConstantFromReg(Sub.getRHSReg(), MRI);
2116
2117   MatchInfo = [=, &MI](MachineIRBuilder &B) {
2118     auto NegCst = B.buildConstant(Ty, -Imm);
2119     Observer.changingInstr(MI);
2120     MI.setDesc(B.getTII().get(TargetOpcode::G_ADD));
2121     MI.getOperand(2).setReg(NegCst.getReg(0));
2122     MI.clearFlag(MachineInstr::MIFlag::NoUWrap);
2123     if (Imm.isMinSignedValue())
2124       MI.clearFlags(MachineInstr::MIFlag::NoSWrap);
2125     Observer.changedInstr(MI);
2126   };
2127   return true;
2128 }
2129
2130 // shl ([sza]ext x), y => zext (shl x, y), if shift does not overflow source
2131 bool CombinerHelper::matchCombineShlOfExtend(MachineInstr &MI,
2132                                              RegisterImmPair &MatchData) const {
2133   assert(MI.getOpcode() == TargetOpcode::G_SHL && VT);
2134   if (!getTargetLowering().isDesirableToPullExtFromShl(MI))
2135     return false;
2136
2137   Register LHS = MI.getOperand(1).getReg();
2138
2139   Register ExtSrc;
2140   if (!mi_match(LHS, MRI, m_GAnyExt(m_Reg(ExtSrc))) &&
2141       !mi_match(LHS, MRI, m_GZExt(m_Reg(ExtSrc))) &&
2142       !mi_match(LHS, MRI, m_GSExt(m_Reg(ExtSrc))))
2143     return false;
2144
2145   Register RHS = MI.getOperand(2).getReg();
2146   MachineInstr *MIShiftAmt = MRI.getVRegDef(RHS);
2147   auto MaybeShiftAmtVal = isConstantOrConstantSplatVector(*MIShiftAmt, MRI);
2148   if (!MaybeShiftAmtVal)
2149     return false;
2150
2151   if (LI) {
2152     LLT SrcTy = MRI.getType(ExtSrc);
2153
2154     // We only really care about the legality with the shifted value. We can
2155     // pick any type for the constant shift amount, so ask the target what to
2156     // use. Otherwise we would have to guess and hope it is reported as legal.
2157 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(SrcTy); 2158 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SHL, {SrcTy, ShiftAmtTy}})) 2159 return false; 2160 } 2161 2162 int64_t ShiftAmt = MaybeShiftAmtVal->getSExtValue(); 2163 MatchData.Reg = ExtSrc; 2164 MatchData.Imm = ShiftAmt; 2165 2166 unsigned MinLeadingZeros = VT->getKnownZeroes(ExtSrc).countl_one(); 2167 unsigned SrcTySize = MRI.getType(ExtSrc).getScalarSizeInBits(); 2168 return MinLeadingZeros >= ShiftAmt && ShiftAmt < SrcTySize; 2169 } 2170 2171 void CombinerHelper::applyCombineShlOfExtend( 2172 MachineInstr &MI, const RegisterImmPair &MatchData) const { 2173 Register ExtSrcReg = MatchData.Reg; 2174 int64_t ShiftAmtVal = MatchData.Imm; 2175 2176 LLT ExtSrcTy = MRI.getType(ExtSrcReg); 2177 auto ShiftAmt = Builder.buildConstant(ExtSrcTy, ShiftAmtVal); 2178 auto NarrowShift = 2179 Builder.buildShl(ExtSrcTy, ExtSrcReg, ShiftAmt, MI.getFlags()); 2180 Builder.buildZExt(MI.getOperand(0), NarrowShift); 2181 MI.eraseFromParent(); 2182 } 2183 2184 bool CombinerHelper::matchCombineMergeUnmerge(MachineInstr &MI, 2185 Register &MatchInfo) const { 2186 GMerge &Merge = cast<GMerge>(MI); 2187 SmallVector<Register, 16> MergedValues; 2188 for (unsigned I = 0; I < Merge.getNumSources(); ++I) 2189 MergedValues.emplace_back(Merge.getSourceReg(I)); 2190 2191 auto *Unmerge = getOpcodeDef<GUnmerge>(MergedValues[0], MRI); 2192 if (!Unmerge || Unmerge->getNumDefs() != Merge.getNumSources()) 2193 return false; 2194 2195 for (unsigned I = 0; I < MergedValues.size(); ++I) 2196 if (MergedValues[I] != Unmerge->getReg(I)) 2197 return false; 2198 2199 MatchInfo = Unmerge->getSourceReg(); 2200 return true; 2201 } 2202 2203 static Register peekThroughBitcast(Register Reg, 2204 const MachineRegisterInfo &MRI) { 2205 while (mi_match(Reg, MRI, m_GBitcast(m_Reg(Reg)))) 2206 ; 2207 2208 return Reg; 2209 } 2210 2211 bool CombinerHelper::matchCombineUnmergeMergeToPlainValues( 2212 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { 2213 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2214 "Expected an unmerge"); 2215 auto &Unmerge = cast<GUnmerge>(MI); 2216 Register SrcReg = peekThroughBitcast(Unmerge.getSourceReg(), MRI); 2217 2218 auto *SrcInstr = getOpcodeDef<GMergeLikeInstr>(SrcReg, MRI); 2219 if (!SrcInstr) 2220 return false; 2221 2222 // Check the source type of the merge. 2223 LLT SrcMergeTy = MRI.getType(SrcInstr->getSourceReg(0)); 2224 LLT Dst0Ty = MRI.getType(Unmerge.getReg(0)); 2225 bool SameSize = Dst0Ty.getSizeInBits() == SrcMergeTy.getSizeInBits(); 2226 if (SrcMergeTy != Dst0Ty && !SameSize) 2227 return false; 2228 // They are the same now (modulo a bitcast). 2229 // We can collect all the src registers. 
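  // For illustration (register names are invented):
  //   %m:_(s64) = G_MERGE_VALUES %x:_(s32), %y:_(s32)
  //   %a:_(s32), %b:_(s32) = G_UNMERGE_VALUES %m:_(s64)
  // collects %x and %y so that %a and %b can be replaced by them directly (or
  // via a cast when only the sizes match).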
2230 for (unsigned Idx = 0; Idx < SrcInstr->getNumSources(); ++Idx) 2231 Operands.push_back(SrcInstr->getSourceReg(Idx)); 2232 return true; 2233 } 2234 2235 void CombinerHelper::applyCombineUnmergeMergeToPlainValues( 2236 MachineInstr &MI, SmallVectorImpl<Register> &Operands) const { 2237 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2238 "Expected an unmerge"); 2239 assert((MI.getNumOperands() - 1 == Operands.size()) && 2240 "Not enough operands to replace all defs"); 2241 unsigned NumElems = MI.getNumOperands() - 1; 2242 2243 LLT SrcTy = MRI.getType(Operands[0]); 2244 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 2245 bool CanReuseInputDirectly = DstTy == SrcTy; 2246 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2247 Register DstReg = MI.getOperand(Idx).getReg(); 2248 Register SrcReg = Operands[Idx]; 2249 2250 // This combine may run after RegBankSelect, so we need to be aware of 2251 // register banks. 2252 const auto &DstCB = MRI.getRegClassOrRegBank(DstReg); 2253 if (!DstCB.isNull() && DstCB != MRI.getRegClassOrRegBank(SrcReg)) { 2254 SrcReg = Builder.buildCopy(MRI.getType(SrcReg), SrcReg).getReg(0); 2255 MRI.setRegClassOrRegBank(SrcReg, DstCB); 2256 } 2257 2258 if (CanReuseInputDirectly) 2259 replaceRegWith(MRI, DstReg, SrcReg); 2260 else 2261 Builder.buildCast(DstReg, SrcReg); 2262 } 2263 MI.eraseFromParent(); 2264 } 2265 2266 bool CombinerHelper::matchCombineUnmergeConstant( 2267 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { 2268 unsigned SrcIdx = MI.getNumOperands() - 1; 2269 Register SrcReg = MI.getOperand(SrcIdx).getReg(); 2270 MachineInstr *SrcInstr = MRI.getVRegDef(SrcReg); 2271 if (SrcInstr->getOpcode() != TargetOpcode::G_CONSTANT && 2272 SrcInstr->getOpcode() != TargetOpcode::G_FCONSTANT) 2273 return false; 2274 // Break down the big constant in smaller ones. 2275 const MachineOperand &CstVal = SrcInstr->getOperand(1); 2276 APInt Val = SrcInstr->getOpcode() == TargetOpcode::G_CONSTANT 2277 ? CstVal.getCImm()->getValue() 2278 : CstVal.getFPImm()->getValueAPF().bitcastToAPInt(); 2279 2280 LLT Dst0Ty = MRI.getType(MI.getOperand(0).getReg()); 2281 unsigned ShiftAmt = Dst0Ty.getSizeInBits(); 2282 // Unmerge a constant. 
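  // For example, unmerging an s32 G_CONSTANT 0x0A0B0C0D into four s8 values
  // produces 0x0D, 0x0C, 0x0B, 0x0A: the first definition receives the lowest
  // bits.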
2283 for (unsigned Idx = 0; Idx != SrcIdx; ++Idx) { 2284 Csts.emplace_back(Val.trunc(ShiftAmt)); 2285 Val = Val.lshr(ShiftAmt); 2286 } 2287 2288 return true; 2289 } 2290 2291 void CombinerHelper::applyCombineUnmergeConstant( 2292 MachineInstr &MI, SmallVectorImpl<APInt> &Csts) const { 2293 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2294 "Expected an unmerge"); 2295 assert((MI.getNumOperands() - 1 == Csts.size()) && 2296 "Not enough operands to replace all defs"); 2297 unsigned NumElems = MI.getNumOperands() - 1; 2298 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2299 Register DstReg = MI.getOperand(Idx).getReg(); 2300 Builder.buildConstant(DstReg, Csts[Idx]); 2301 } 2302 2303 MI.eraseFromParent(); 2304 } 2305 2306 bool CombinerHelper::matchCombineUnmergeUndef( 2307 MachineInstr &MI, 2308 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 2309 unsigned SrcIdx = MI.getNumOperands() - 1; 2310 Register SrcReg = MI.getOperand(SrcIdx).getReg(); 2311 MatchInfo = [&MI](MachineIRBuilder &B) { 2312 unsigned NumElems = MI.getNumOperands() - 1; 2313 for (unsigned Idx = 0; Idx < NumElems; ++Idx) { 2314 Register DstReg = MI.getOperand(Idx).getReg(); 2315 B.buildUndef(DstReg); 2316 } 2317 }; 2318 return isa<GImplicitDef>(MRI.getVRegDef(SrcReg)); 2319 } 2320 2321 bool CombinerHelper::matchCombineUnmergeWithDeadLanesToTrunc( 2322 MachineInstr &MI) const { 2323 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2324 "Expected an unmerge"); 2325 if (MRI.getType(MI.getOperand(0).getReg()).isVector() || 2326 MRI.getType(MI.getOperand(MI.getNumDefs()).getReg()).isVector()) 2327 return false; 2328 // Check that all the lanes are dead except the first one. 2329 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { 2330 if (!MRI.use_nodbg_empty(MI.getOperand(Idx).getReg())) 2331 return false; 2332 } 2333 return true; 2334 } 2335 2336 void CombinerHelper::applyCombineUnmergeWithDeadLanesToTrunc( 2337 MachineInstr &MI) const { 2338 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg(); 2339 Register Dst0Reg = MI.getOperand(0).getReg(); 2340 Builder.buildTrunc(Dst0Reg, SrcReg); 2341 MI.eraseFromParent(); 2342 } 2343 2344 bool CombinerHelper::matchCombineUnmergeZExtToZExt(MachineInstr &MI) const { 2345 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2346 "Expected an unmerge"); 2347 Register Dst0Reg = MI.getOperand(0).getReg(); 2348 LLT Dst0Ty = MRI.getType(Dst0Reg); 2349 // G_ZEXT on vector applies to each lane, so it will 2350 // affect all destinations. Therefore we won't be able 2351 // to simplify the unmerge to just the first definition. 2352 if (Dst0Ty.isVector()) 2353 return false; 2354 Register SrcReg = MI.getOperand(MI.getNumDefs()).getReg(); 2355 LLT SrcTy = MRI.getType(SrcReg); 2356 if (SrcTy.isVector()) 2357 return false; 2358 2359 Register ZExtSrcReg; 2360 if (!mi_match(SrcReg, MRI, m_GZExt(m_Reg(ZExtSrcReg)))) 2361 return false; 2362 2363 // Finally we can replace the first definition with 2364 // a zext of the source if the definition is big enough to hold 2365 // all of ZExtSrc bits. 
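  // For illustration (types are invented):
  //   %d0:_(s32), %d1:_(s32) = G_UNMERGE_VALUES (G_ZEXT %x:_(s16))
  // qualifies, and the apply function below turns %d0 into G_ZEXT %x and %d1
  // into the constant 0.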
2366 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg); 2367 return ZExtSrcTy.getSizeInBits() <= Dst0Ty.getSizeInBits(); 2368 } 2369 2370 void CombinerHelper::applyCombineUnmergeZExtToZExt(MachineInstr &MI) const { 2371 assert(MI.getOpcode() == TargetOpcode::G_UNMERGE_VALUES && 2372 "Expected an unmerge"); 2373 2374 Register Dst0Reg = MI.getOperand(0).getReg(); 2375 2376 MachineInstr *ZExtInstr = 2377 MRI.getVRegDef(MI.getOperand(MI.getNumDefs()).getReg()); 2378 assert(ZExtInstr && ZExtInstr->getOpcode() == TargetOpcode::G_ZEXT && 2379 "Expecting a G_ZEXT"); 2380 2381 Register ZExtSrcReg = ZExtInstr->getOperand(1).getReg(); 2382 LLT Dst0Ty = MRI.getType(Dst0Reg); 2383 LLT ZExtSrcTy = MRI.getType(ZExtSrcReg); 2384 2385 if (Dst0Ty.getSizeInBits() > ZExtSrcTy.getSizeInBits()) { 2386 Builder.buildZExt(Dst0Reg, ZExtSrcReg); 2387 } else { 2388 assert(Dst0Ty.getSizeInBits() == ZExtSrcTy.getSizeInBits() && 2389 "ZExt src doesn't fit in destination"); 2390 replaceRegWith(MRI, Dst0Reg, ZExtSrcReg); 2391 } 2392 2393 Register ZeroReg; 2394 for (unsigned Idx = 1, EndIdx = MI.getNumDefs(); Idx != EndIdx; ++Idx) { 2395 if (!ZeroReg) 2396 ZeroReg = Builder.buildConstant(Dst0Ty, 0).getReg(0); 2397 replaceRegWith(MRI, MI.getOperand(Idx).getReg(), ZeroReg); 2398 } 2399 MI.eraseFromParent(); 2400 } 2401 2402 bool CombinerHelper::matchCombineShiftToUnmerge(MachineInstr &MI, 2403 unsigned TargetShiftSize, 2404 unsigned &ShiftVal) const { 2405 assert((MI.getOpcode() == TargetOpcode::G_SHL || 2406 MI.getOpcode() == TargetOpcode::G_LSHR || 2407 MI.getOpcode() == TargetOpcode::G_ASHR) && "Expected a shift"); 2408 2409 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 2410 if (Ty.isVector()) // TODO: 2411 return false; 2412 2413 // Don't narrow further than the requested size. 2414 unsigned Size = Ty.getSizeInBits(); 2415 if (Size <= TargetShiftSize) 2416 return false; 2417 2418 auto MaybeImmVal = 2419 getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI); 2420 if (!MaybeImmVal) 2421 return false; 2422 2423 ShiftVal = MaybeImmVal->Value.getSExtValue(); 2424 return ShiftVal >= Size / 2 && ShiftVal < Size; 2425 } 2426 2427 void CombinerHelper::applyCombineShiftToUnmerge( 2428 MachineInstr &MI, const unsigned &ShiftVal) const { 2429 Register DstReg = MI.getOperand(0).getReg(); 2430 Register SrcReg = MI.getOperand(1).getReg(); 2431 LLT Ty = MRI.getType(SrcReg); 2432 unsigned Size = Ty.getSizeInBits(); 2433 unsigned HalfSize = Size / 2; 2434 assert(ShiftVal >= HalfSize); 2435 2436 LLT HalfTy = LLT::scalar(HalfSize); 2437 2438 auto Unmerge = Builder.buildUnmerge(HalfTy, SrcReg); 2439 unsigned NarrowShiftAmt = ShiftVal - HalfSize; 2440 2441 if (MI.getOpcode() == TargetOpcode::G_LSHR) { 2442 Register Narrowed = Unmerge.getReg(1); 2443 2444 // dst = G_LSHR s64:x, C for C >= 32 2445 // => 2446 // lo, hi = G_UNMERGE_VALUES x 2447 // dst = G_MERGE_VALUES (G_LSHR hi, C - 32), 0 2448 2449 if (NarrowShiftAmt != 0) { 2450 Narrowed = Builder.buildLShr(HalfTy, Narrowed, 2451 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0); 2452 } 2453 2454 auto Zero = Builder.buildConstant(HalfTy, 0); 2455 Builder.buildMergeLikeInstr(DstReg, {Narrowed, Zero}); 2456 } else if (MI.getOpcode() == TargetOpcode::G_SHL) { 2457 Register Narrowed = Unmerge.getReg(0); 2458 // dst = G_SHL s64:x, C for C >= 32 2459 // => 2460 // lo, hi = G_UNMERGE_VALUES x 2461 // dst = G_MERGE_VALUES 0, (G_SHL hi, C - 32) 2462 if (NarrowShiftAmt != 0) { 2463 Narrowed = Builder.buildShl(HalfTy, Narrowed, 2464 Builder.buildConstant(HalfTy, NarrowShiftAmt)).getReg(0); 
2465 } 2466 2467 auto Zero = Builder.buildConstant(HalfTy, 0); 2468 Builder.buildMergeLikeInstr(DstReg, {Zero, Narrowed}); 2469 } else { 2470 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 2471 auto Hi = Builder.buildAShr( 2472 HalfTy, Unmerge.getReg(1), 2473 Builder.buildConstant(HalfTy, HalfSize - 1)); 2474 2475 if (ShiftVal == HalfSize) { 2476 // (G_ASHR i64:x, 32) -> 2477 // G_MERGE_VALUES hi_32(x), (G_ASHR hi_32(x), 31) 2478 Builder.buildMergeLikeInstr(DstReg, {Unmerge.getReg(1), Hi}); 2479 } else if (ShiftVal == Size - 1) { 2480 // Don't need a second shift. 2481 // (G_ASHR i64:x, 63) -> 2482 // %narrowed = (G_ASHR hi_32(x), 31) 2483 // G_MERGE_VALUES %narrowed, %narrowed 2484 Builder.buildMergeLikeInstr(DstReg, {Hi, Hi}); 2485 } else { 2486 auto Lo = Builder.buildAShr( 2487 HalfTy, Unmerge.getReg(1), 2488 Builder.buildConstant(HalfTy, ShiftVal - HalfSize)); 2489 2490 // (G_ASHR i64:x, C) ->, for C >= 32 2491 // G_MERGE_VALUES (G_ASHR hi_32(x), C - 32), (G_ASHR hi_32(x), 31) 2492 Builder.buildMergeLikeInstr(DstReg, {Lo, Hi}); 2493 } 2494 } 2495 2496 MI.eraseFromParent(); 2497 } 2498 2499 bool CombinerHelper::tryCombineShiftToUnmerge( 2500 MachineInstr &MI, unsigned TargetShiftAmount) const { 2501 unsigned ShiftAmt; 2502 if (matchCombineShiftToUnmerge(MI, TargetShiftAmount, ShiftAmt)) { 2503 applyCombineShiftToUnmerge(MI, ShiftAmt); 2504 return true; 2505 } 2506 2507 return false; 2508 } 2509 2510 bool CombinerHelper::matchCombineI2PToP2I(MachineInstr &MI, 2511 Register &Reg) const { 2512 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR"); 2513 Register DstReg = MI.getOperand(0).getReg(); 2514 LLT DstTy = MRI.getType(DstReg); 2515 Register SrcReg = MI.getOperand(1).getReg(); 2516 return mi_match(SrcReg, MRI, 2517 m_GPtrToInt(m_all_of(m_SpecificType(DstTy), m_Reg(Reg)))); 2518 } 2519 2520 void CombinerHelper::applyCombineI2PToP2I(MachineInstr &MI, 2521 Register &Reg) const { 2522 assert(MI.getOpcode() == TargetOpcode::G_INTTOPTR && "Expected a G_INTTOPTR"); 2523 Register DstReg = MI.getOperand(0).getReg(); 2524 Builder.buildCopy(DstReg, Reg); 2525 MI.eraseFromParent(); 2526 } 2527 2528 void CombinerHelper::applyCombineP2IToI2P(MachineInstr &MI, 2529 Register &Reg) const { 2530 assert(MI.getOpcode() == TargetOpcode::G_PTRTOINT && "Expected a G_PTRTOINT"); 2531 Register DstReg = MI.getOperand(0).getReg(); 2532 Builder.buildZExtOrTrunc(DstReg, Reg); 2533 MI.eraseFromParent(); 2534 } 2535 2536 bool CombinerHelper::matchCombineAddP2IToPtrAdd( 2537 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { 2538 assert(MI.getOpcode() == TargetOpcode::G_ADD); 2539 Register LHS = MI.getOperand(1).getReg(); 2540 Register RHS = MI.getOperand(2).getReg(); 2541 LLT IntTy = MRI.getType(LHS); 2542 2543 // G_PTR_ADD always has the pointer in the LHS, so we may need to commute the 2544 // instruction. 2545 PtrReg.second = false; 2546 for (Register SrcReg : {LHS, RHS}) { 2547 if (mi_match(SrcReg, MRI, m_GPtrToInt(m_Reg(PtrReg.first)))) { 2548 // Don't handle cases where the integer is implicitly converted to the 2549 // pointer width. 
2550 LLT PtrTy = MRI.getType(PtrReg.first); 2551 if (PtrTy.getScalarSizeInBits() == IntTy.getScalarSizeInBits()) 2552 return true; 2553 } 2554 2555 PtrReg.second = true; 2556 } 2557 2558 return false; 2559 } 2560 2561 void CombinerHelper::applyCombineAddP2IToPtrAdd( 2562 MachineInstr &MI, std::pair<Register, bool> &PtrReg) const { 2563 Register Dst = MI.getOperand(0).getReg(); 2564 Register LHS = MI.getOperand(1).getReg(); 2565 Register RHS = MI.getOperand(2).getReg(); 2566 2567 const bool DoCommute = PtrReg.second; 2568 if (DoCommute) 2569 std::swap(LHS, RHS); 2570 LHS = PtrReg.first; 2571 2572 LLT PtrTy = MRI.getType(LHS); 2573 2574 auto PtrAdd = Builder.buildPtrAdd(PtrTy, LHS, RHS); 2575 Builder.buildPtrToInt(Dst, PtrAdd); 2576 MI.eraseFromParent(); 2577 } 2578 2579 bool CombinerHelper::matchCombineConstPtrAddToI2P(MachineInstr &MI, 2580 APInt &NewCst) const { 2581 auto &PtrAdd = cast<GPtrAdd>(MI); 2582 Register LHS = PtrAdd.getBaseReg(); 2583 Register RHS = PtrAdd.getOffsetReg(); 2584 MachineRegisterInfo &MRI = Builder.getMF().getRegInfo(); 2585 2586 if (auto RHSCst = getIConstantVRegVal(RHS, MRI)) { 2587 APInt Cst; 2588 if (mi_match(LHS, MRI, m_GIntToPtr(m_ICst(Cst)))) { 2589 auto DstTy = MRI.getType(PtrAdd.getReg(0)); 2590 // G_INTTOPTR uses zero-extension 2591 NewCst = Cst.zextOrTrunc(DstTy.getSizeInBits()); 2592 NewCst += RHSCst->sextOrTrunc(DstTy.getSizeInBits()); 2593 return true; 2594 } 2595 } 2596 2597 return false; 2598 } 2599 2600 void CombinerHelper::applyCombineConstPtrAddToI2P(MachineInstr &MI, 2601 APInt &NewCst) const { 2602 auto &PtrAdd = cast<GPtrAdd>(MI); 2603 Register Dst = PtrAdd.getReg(0); 2604 2605 Builder.buildConstant(Dst, NewCst); 2606 PtrAdd.eraseFromParent(); 2607 } 2608 2609 bool CombinerHelper::matchCombineAnyExtTrunc(MachineInstr &MI, 2610 Register &Reg) const { 2611 assert(MI.getOpcode() == TargetOpcode::G_ANYEXT && "Expected a G_ANYEXT"); 2612 Register DstReg = MI.getOperand(0).getReg(); 2613 Register SrcReg = MI.getOperand(1).getReg(); 2614 Register OriginalSrcReg = getSrcRegIgnoringCopies(SrcReg, MRI); 2615 if (OriginalSrcReg.isValid()) 2616 SrcReg = OriginalSrcReg; 2617 LLT DstTy = MRI.getType(DstReg); 2618 return mi_match(SrcReg, MRI, 2619 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy)))) && 2620 canReplaceReg(DstReg, Reg, MRI); 2621 } 2622 2623 bool CombinerHelper::matchCombineZextTrunc(MachineInstr &MI, 2624 Register &Reg) const { 2625 assert(MI.getOpcode() == TargetOpcode::G_ZEXT && "Expected a G_ZEXT"); 2626 Register DstReg = MI.getOperand(0).getReg(); 2627 Register SrcReg = MI.getOperand(1).getReg(); 2628 LLT DstTy = MRI.getType(DstReg); 2629 if (mi_match(SrcReg, MRI, 2630 m_GTrunc(m_all_of(m_Reg(Reg), m_SpecificType(DstTy)))) && 2631 canReplaceReg(DstReg, Reg, MRI)) { 2632 unsigned DstSize = DstTy.getScalarSizeInBits(); 2633 unsigned SrcSize = MRI.getType(SrcReg).getScalarSizeInBits(); 2634 return VT->getKnownBits(Reg).countMinLeadingZeros() >= DstSize - SrcSize; 2635 } 2636 return false; 2637 } 2638 2639 static LLT getMidVTForTruncRightShiftCombine(LLT ShiftTy, LLT TruncTy) { 2640 const unsigned ShiftSize = ShiftTy.getScalarSizeInBits(); 2641 const unsigned TruncSize = TruncTy.getScalarSizeInBits(); 2642 2643 // ShiftTy > 32 > TruncTy -> 32 2644 if (ShiftSize > 32 && TruncSize < 32) 2645 return ShiftTy.changeElementSize(32); 2646 2647 // TODO: We could also reduce to 16 bits, but that's more target-dependent. 2648 // Some targets like it, some don't, some only like it under certain 2649 // conditions/processor versions, etc. 
2650 // A TL hook might be needed for this. 2651 2652 // Don't combine 2653 return ShiftTy; 2654 } 2655 2656 bool CombinerHelper::matchCombineTruncOfShift( 2657 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { 2658 assert(MI.getOpcode() == TargetOpcode::G_TRUNC && "Expected a G_TRUNC"); 2659 Register DstReg = MI.getOperand(0).getReg(); 2660 Register SrcReg = MI.getOperand(1).getReg(); 2661 2662 if (!MRI.hasOneNonDBGUse(SrcReg)) 2663 return false; 2664 2665 LLT SrcTy = MRI.getType(SrcReg); 2666 LLT DstTy = MRI.getType(DstReg); 2667 2668 MachineInstr *SrcMI = getDefIgnoringCopies(SrcReg, MRI); 2669 const auto &TL = getTargetLowering(); 2670 2671 LLT NewShiftTy; 2672 switch (SrcMI->getOpcode()) { 2673 default: 2674 return false; 2675 case TargetOpcode::G_SHL: { 2676 NewShiftTy = DstTy; 2677 2678 // Make sure new shift amount is legal. 2679 KnownBits Known = VT->getKnownBits(SrcMI->getOperand(2).getReg()); 2680 if (Known.getMaxValue().uge(NewShiftTy.getScalarSizeInBits())) 2681 return false; 2682 break; 2683 } 2684 case TargetOpcode::G_LSHR: 2685 case TargetOpcode::G_ASHR: { 2686 // For right shifts, we conservatively do not do the transform if the TRUNC 2687 // has any STORE users. The reason is that if we change the type of the 2688 // shift, we may break the truncstore combine. 2689 // 2690 // TODO: Fix truncstore combine to handle (trunc(lshr (trunc x), k)). 2691 for (auto &User : MRI.use_instructions(DstReg)) 2692 if (User.getOpcode() == TargetOpcode::G_STORE) 2693 return false; 2694 2695 NewShiftTy = getMidVTForTruncRightShiftCombine(SrcTy, DstTy); 2696 if (NewShiftTy == SrcTy) 2697 return false; 2698 2699 // Make sure we won't lose information by truncating the high bits. 2700 KnownBits Known = VT->getKnownBits(SrcMI->getOperand(2).getReg()); 2701 if (Known.getMaxValue().ugt(NewShiftTy.getScalarSizeInBits() - 2702 DstTy.getScalarSizeInBits())) 2703 return false; 2704 break; 2705 } 2706 } 2707 2708 if (!isLegalOrBeforeLegalizer( 2709 {SrcMI->getOpcode(), 2710 {NewShiftTy, TL.getPreferredShiftAmountTy(NewShiftTy)}})) 2711 return false; 2712 2713 MatchInfo = std::make_pair(SrcMI, NewShiftTy); 2714 return true; 2715 } 2716 2717 void CombinerHelper::applyCombineTruncOfShift( 2718 MachineInstr &MI, std::pair<MachineInstr *, LLT> &MatchInfo) const { 2719 MachineInstr *ShiftMI = MatchInfo.first; 2720 LLT NewShiftTy = MatchInfo.second; 2721 2722 Register Dst = MI.getOperand(0).getReg(); 2723 LLT DstTy = MRI.getType(Dst); 2724 2725 Register ShiftAmt = ShiftMI->getOperand(2).getReg(); 2726 Register ShiftSrc = ShiftMI->getOperand(1).getReg(); 2727 ShiftSrc = Builder.buildTrunc(NewShiftTy, ShiftSrc).getReg(0); 2728 2729 Register NewShift = 2730 Builder 2731 .buildInstr(ShiftMI->getOpcode(), {NewShiftTy}, {ShiftSrc, ShiftAmt}) 2732 .getReg(0); 2733 2734 if (NewShiftTy == DstTy) 2735 replaceRegWith(MRI, Dst, NewShift); 2736 else 2737 Builder.buildTrunc(Dst, NewShift); 2738 2739 eraseInst(MI); 2740 } 2741 2742 bool CombinerHelper::matchAnyExplicitUseIsUndef(MachineInstr &MI) const { 2743 return any_of(MI.explicit_uses(), [this](const MachineOperand &MO) { 2744 return MO.isReg() && 2745 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 2746 }); 2747 } 2748 2749 bool CombinerHelper::matchAllExplicitUsesAreUndef(MachineInstr &MI) const { 2750 return all_of(MI.explicit_uses(), [this](const MachineOperand &MO) { 2751 return !MO.isReg() || 2752 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 2753 }); 2754 } 2755 2756 bool 
CombinerHelper::matchUndefShuffleVectorMask(MachineInstr &MI) const { 2757 assert(MI.getOpcode() == TargetOpcode::G_SHUFFLE_VECTOR); 2758 ArrayRef<int> Mask = MI.getOperand(3).getShuffleMask(); 2759 return all_of(Mask, [](int Elt) { return Elt < 0; }); 2760 } 2761 2762 bool CombinerHelper::matchUndefStore(MachineInstr &MI) const { 2763 assert(MI.getOpcode() == TargetOpcode::G_STORE); 2764 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(0).getReg(), 2765 MRI); 2766 } 2767 2768 bool CombinerHelper::matchUndefSelectCmp(MachineInstr &MI) const { 2769 assert(MI.getOpcode() == TargetOpcode::G_SELECT); 2770 return getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MI.getOperand(1).getReg(), 2771 MRI); 2772 } 2773 2774 bool CombinerHelper::matchInsertExtractVecEltOutOfBounds( 2775 MachineInstr &MI) const { 2776 assert((MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT || 2777 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT) && 2778 "Expected an insert/extract element op"); 2779 LLT VecTy = MRI.getType(MI.getOperand(1).getReg()); 2780 if (VecTy.isScalableVector()) 2781 return false; 2782 2783 unsigned IdxIdx = 2784 MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT ? 2 : 3; 2785 auto Idx = getIConstantVRegVal(MI.getOperand(IdxIdx).getReg(), MRI); 2786 if (!Idx) 2787 return false; 2788 return Idx->getZExtValue() >= VecTy.getNumElements(); 2789 } 2790 2791 bool CombinerHelper::matchConstantSelectCmp(MachineInstr &MI, 2792 unsigned &OpIdx) const { 2793 GSelect &SelMI = cast<GSelect>(MI); 2794 auto Cst = 2795 isConstantOrConstantSplatVector(*MRI.getVRegDef(SelMI.getCondReg()), MRI); 2796 if (!Cst) 2797 return false; 2798 OpIdx = Cst->isZero() ? 3 : 2; 2799 return true; 2800 } 2801 2802 void CombinerHelper::eraseInst(MachineInstr &MI) const { MI.eraseFromParent(); } 2803 2804 bool CombinerHelper::matchEqualDefs(const MachineOperand &MOP1, 2805 const MachineOperand &MOP2) const { 2806 if (!MOP1.isReg() || !MOP2.isReg()) 2807 return false; 2808 auto InstAndDef1 = getDefSrcRegIgnoringCopies(MOP1.getReg(), MRI); 2809 if (!InstAndDef1) 2810 return false; 2811 auto InstAndDef2 = getDefSrcRegIgnoringCopies(MOP2.getReg(), MRI); 2812 if (!InstAndDef2) 2813 return false; 2814 MachineInstr *I1 = InstAndDef1->MI; 2815 MachineInstr *I2 = InstAndDef2->MI; 2816 2817 // Handle a case like this: 2818 // 2819 // %0:_(s64), %1:_(s64) = G_UNMERGE_VALUES %2:_(<2 x s64>) 2820 // 2821 // Even though %0 and %1 are produced by the same instruction they are not 2822 // the same values. 2823 if (I1 == I2) 2824 return MOP1.getReg() == MOP2.getReg(); 2825 2826 // If we have an instruction which loads or stores, we can't guarantee that 2827 // it is identical. 2828 // 2829 // For example, we may have 2830 // 2831 // %x1 = G_LOAD %addr (load N from @somewhere) 2832 // ... 2833 // call @foo 2834 // ... 2835 // %x2 = G_LOAD %addr (load N from @somewhere) 2836 // ... 2837 // %or = G_OR %x1, %x2 2838 // 2839 // It's possible that @foo will modify whatever lives at the address we're 2840 // loading from. To be safe, let's just assume that all loads and stores 2841 // are different (unless we have something which is guaranteed to not 2842 // change.) 2843 if (I1->mayLoadOrStore() && !I1->isDereferenceableInvariantLoad()) 2844 return false; 2845 2846 // If both instructions are loads or stores, they are equal only if both 2847 // are dereferenceable invariant loads with the same number of bits. 
2848 if (I1->mayLoadOrStore() && I2->mayLoadOrStore()) { 2849 GLoadStore *LS1 = dyn_cast<GLoadStore>(I1); 2850 GLoadStore *LS2 = dyn_cast<GLoadStore>(I2); 2851 if (!LS1 || !LS2) 2852 return false; 2853 2854 if (!I2->isDereferenceableInvariantLoad() || 2855 (LS1->getMemSizeInBits() != LS2->getMemSizeInBits())) 2856 return false; 2857 } 2858 2859 // Check for physical registers on the instructions first to avoid cases 2860 // like this: 2861 // 2862 // %a = COPY $physreg 2863 // ... 2864 // SOMETHING implicit-def $physreg 2865 // ... 2866 // %b = COPY $physreg 2867 // 2868 // These copies are not equivalent. 2869 if (any_of(I1->uses(), [](const MachineOperand &MO) { 2870 return MO.isReg() && MO.getReg().isPhysical(); 2871 })) { 2872 // Check if we have a case like this: 2873 // 2874 // %a = COPY $physreg 2875 // %b = COPY %a 2876 // 2877 // In this case, I1 and I2 will both be equal to %a = COPY $physreg. 2878 // From that, we know that they must have the same value, since they must 2879 // have come from the same COPY. 2880 return I1->isIdenticalTo(*I2); 2881 } 2882 2883 // We don't have any physical registers, so we don't necessarily need the 2884 // same vreg defs. 2885 // 2886 // On the off-chance that there's some target instruction feeding into the 2887 // instruction, let's use produceSameValue instead of isIdenticalTo. 2888 if (Builder.getTII().produceSameValue(*I1, *I2, &MRI)) { 2889 // Handle instructions with multiple defs that produce same values. Values 2890 // are same for operands with same index. 2891 // %0:_(s8), %1:_(s8), %2:_(s8), %3:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) 2892 // %5:_(s8), %6:_(s8), %7:_(s8), %8:_(s8) = G_UNMERGE_VALUES %4:_(<4 x s8>) 2893 // I1 and I2 are different instructions but produce same values, 2894 // %1 and %6 are same, %1 and %7 are not the same value. 
2895     return I1->findRegisterDefOperandIdx(InstAndDef1->Reg, /*TRI=*/nullptr) ==
2896            I2->findRegisterDefOperandIdx(InstAndDef2->Reg, /*TRI=*/nullptr);
2897   }
2898   return false;
2899 }
2900
2901 bool CombinerHelper::matchConstantOp(const MachineOperand &MOP,
2902                                      int64_t C) const {
2903   if (!MOP.isReg())
2904     return false;
2905   auto *MI = MRI.getVRegDef(MOP.getReg());
2906   auto MaybeCst = isConstantOrConstantSplatVector(*MI, MRI);
2907   return MaybeCst && MaybeCst->getBitWidth() <= 64 &&
2908          MaybeCst->getSExtValue() == C;
2909 }
2910
2911 bool CombinerHelper::matchConstantFPOp(const MachineOperand &MOP,
2912                                        double C) const {
2913   if (!MOP.isReg())
2914     return false;
2915   std::optional<FPValueAndVReg> MaybeCst;
2916   if (!mi_match(MOP.getReg(), MRI, m_GFCstOrSplat(MaybeCst)))
2917     return false;
2918
2919   return MaybeCst->Value.isExactlyValue(C);
2920 }
2921
2922 void CombinerHelper::replaceSingleDefInstWithOperand(MachineInstr &MI,
2923                                                      unsigned OpIdx) const {
2924   assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2925   Register OldReg = MI.getOperand(0).getReg();
2926   Register Replacement = MI.getOperand(OpIdx).getReg();
2927   assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2928   replaceRegWith(MRI, OldReg, Replacement);
2929   MI.eraseFromParent();
2930 }
2931
2932 void CombinerHelper::replaceSingleDefInstWithReg(MachineInstr &MI,
2933                                                  Register Replacement) const {
2934   assert(MI.getNumExplicitDefs() == 1 && "Expected one explicit def?");
2935   Register OldReg = MI.getOperand(0).getReg();
2936   assert(canReplaceReg(OldReg, Replacement, MRI) && "Cannot replace register?");
2937   replaceRegWith(MRI, OldReg, Replacement);
2938   MI.eraseFromParent();
2939 }
2940
2941 bool CombinerHelper::matchConstantLargerBitWidth(MachineInstr &MI,
2942                                                  unsigned ConstIdx) const {
2943   Register ConstReg = MI.getOperand(ConstIdx).getReg();
2944   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2945
2946   // Get the shift amount
2947   auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI);
2948   if (!VRegAndVal)
2949     return false;
2950
2951   // Return true if shift amount >= Bitwidth
2952   return (VRegAndVal->Value.uge(DstTy.getSizeInBits()));
2953 }
2954
2955 void CombinerHelper::applyFunnelShiftConstantModulo(MachineInstr &MI) const {
2956   assert((MI.getOpcode() == TargetOpcode::G_FSHL ||
2957           MI.getOpcode() == TargetOpcode::G_FSHR) &&
2958          "This is not a funnel shift operation");
2959
2960   Register ConstReg = MI.getOperand(3).getReg();
2961   LLT ConstTy = MRI.getType(ConstReg);
2962   LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
2963
2964   auto VRegAndVal = getIConstantVRegValWithLookThrough(ConstReg, MRI);
2965   assert((VRegAndVal) && "Value is not a constant");
2966
2967   // Calculate the new Shift Amount = Old Shift Amount % BitWidth
2968   APInt NewConst = VRegAndVal->Value.urem(
2969       APInt(ConstTy.getSizeInBits(), DstTy.getScalarSizeInBits()));
2970
2971   auto NewConstInstr = Builder.buildConstant(ConstTy, NewConst.getZExtValue());
2972   Builder.buildInstr(
2973       MI.getOpcode(), {MI.getOperand(0)},
2974       {MI.getOperand(1), MI.getOperand(2), NewConstInstr.getReg(0)});
2975
2976   MI.eraseFromParent();
2977 }
2978
2979 bool CombinerHelper::matchSelectSameVal(MachineInstr &MI) const {
2980   assert(MI.getOpcode() == TargetOpcode::G_SELECT);
2981   // Match (cond ?
x : x) 2982 return matchEqualDefs(MI.getOperand(2), MI.getOperand(3)) && 2983 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(2).getReg(), 2984 MRI); 2985 } 2986 2987 bool CombinerHelper::matchBinOpSameVal(MachineInstr &MI) const { 2988 return matchEqualDefs(MI.getOperand(1), MI.getOperand(2)) && 2989 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(), 2990 MRI); 2991 } 2992 2993 bool CombinerHelper::matchOperandIsZero(MachineInstr &MI, 2994 unsigned OpIdx) const { 2995 return matchConstantOp(MI.getOperand(OpIdx), 0) && 2996 canReplaceReg(MI.getOperand(0).getReg(), MI.getOperand(OpIdx).getReg(), 2997 MRI); 2998 } 2999 3000 bool CombinerHelper::matchOperandIsUndef(MachineInstr &MI, 3001 unsigned OpIdx) const { 3002 MachineOperand &MO = MI.getOperand(OpIdx); 3003 return MO.isReg() && 3004 getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, MO.getReg(), MRI); 3005 } 3006 3007 bool CombinerHelper::matchOperandIsKnownToBeAPowerOfTwo(MachineInstr &MI, 3008 unsigned OpIdx) const { 3009 MachineOperand &MO = MI.getOperand(OpIdx); 3010 return isKnownToBeAPowerOfTwo(MO.getReg(), MRI, VT); 3011 } 3012 3013 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, 3014 double C) const { 3015 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 3016 Builder.buildFConstant(MI.getOperand(0), C); 3017 MI.eraseFromParent(); 3018 } 3019 3020 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, 3021 int64_t C) const { 3022 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 3023 Builder.buildConstant(MI.getOperand(0), C); 3024 MI.eraseFromParent(); 3025 } 3026 3027 void CombinerHelper::replaceInstWithConstant(MachineInstr &MI, APInt C) const { 3028 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 3029 Builder.buildConstant(MI.getOperand(0), C); 3030 MI.eraseFromParent(); 3031 } 3032 3033 void CombinerHelper::replaceInstWithFConstant(MachineInstr &MI, 3034 ConstantFP *CFP) const { 3035 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 3036 Builder.buildFConstant(MI.getOperand(0), CFP->getValueAPF()); 3037 MI.eraseFromParent(); 3038 } 3039 3040 void CombinerHelper::replaceInstWithUndef(MachineInstr &MI) const { 3041 assert(MI.getNumDefs() == 1 && "Expected only one def?"); 3042 Builder.buildUndef(MI.getOperand(0)); 3043 MI.eraseFromParent(); 3044 } 3045 3046 bool CombinerHelper::matchSimplifyAddToSub( 3047 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { 3048 Register LHS = MI.getOperand(1).getReg(); 3049 Register RHS = MI.getOperand(2).getReg(); 3050 Register &NewLHS = std::get<0>(MatchInfo); 3051 Register &NewRHS = std::get<1>(MatchInfo); 3052 3053 // Helper lambda to check for opportunities for 3054 // ((0-A) + B) -> B - A 3055 // (A + (0-B)) -> A - B 3056 auto CheckFold = [&](Register &MaybeSub, Register &MaybeNewLHS) { 3057 if (!mi_match(MaybeSub, MRI, m_Neg(m_Reg(NewRHS)))) 3058 return false; 3059 NewLHS = MaybeNewLHS; 3060 return true; 3061 }; 3062 3063 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); 3064 } 3065 3066 bool CombinerHelper::matchCombineInsertVecElts( 3067 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { 3068 assert(MI.getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT && 3069 "Invalid opcode"); 3070 Register DstReg = MI.getOperand(0).getReg(); 3071 LLT DstTy = MRI.getType(DstReg); 3072 assert(DstTy.isVector() && "Invalid G_INSERT_VECTOR_ELT?"); 3073 3074 if (DstTy.isScalableVector()) 3075 return false; 3076 3077 unsigned NumElts = DstTy.getNumElements(); 3078 // If this MI is part of a sequence of 
insert_vec_elts, then 3079 // don't do the combine in the middle of the sequence. 3080 if (MRI.hasOneUse(DstReg) && MRI.use_instr_begin(DstReg)->getOpcode() == 3081 TargetOpcode::G_INSERT_VECTOR_ELT) 3082 return false; 3083 MachineInstr *CurrInst = &MI; 3084 MachineInstr *TmpInst; 3085 int64_t IntImm; 3086 Register TmpReg; 3087 MatchInfo.resize(NumElts); 3088 while (mi_match( 3089 CurrInst->getOperand(0).getReg(), MRI, 3090 m_GInsertVecElt(m_MInstr(TmpInst), m_Reg(TmpReg), m_ICst(IntImm)))) { 3091 if (IntImm >= NumElts || IntImm < 0) 3092 return false; 3093 if (!MatchInfo[IntImm]) 3094 MatchInfo[IntImm] = TmpReg; 3095 CurrInst = TmpInst; 3096 } 3097 // Variable index. 3098 if (CurrInst->getOpcode() == TargetOpcode::G_INSERT_VECTOR_ELT) 3099 return false; 3100 if (TmpInst->getOpcode() == TargetOpcode::G_BUILD_VECTOR) { 3101 for (unsigned I = 1; I < TmpInst->getNumOperands(); ++I) { 3102 if (!MatchInfo[I - 1].isValid()) 3103 MatchInfo[I - 1] = TmpInst->getOperand(I).getReg(); 3104 } 3105 return true; 3106 } 3107 // If we didn't end in a G_IMPLICIT_DEF and the source is not fully 3108 // overwritten, bail out. 3109 return TmpInst->getOpcode() == TargetOpcode::G_IMPLICIT_DEF || 3110 all_of(MatchInfo, [](Register Reg) { return !!Reg; }); 3111 } 3112 3113 void CombinerHelper::applyCombineInsertVecElts( 3114 MachineInstr &MI, SmallVectorImpl<Register> &MatchInfo) const { 3115 Register UndefReg; 3116 auto GetUndef = [&]() { 3117 if (UndefReg) 3118 return UndefReg; 3119 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 3120 UndefReg = Builder.buildUndef(DstTy.getScalarType()).getReg(0); 3121 return UndefReg; 3122 }; 3123 for (Register &Reg : MatchInfo) { 3124 if (!Reg) 3125 Reg = GetUndef(); 3126 } 3127 Builder.buildBuildVector(MI.getOperand(0).getReg(), MatchInfo); 3128 MI.eraseFromParent(); 3129 } 3130 3131 void CombinerHelper::applySimplifyAddToSub( 3132 MachineInstr &MI, std::tuple<Register, Register> &MatchInfo) const { 3133 Register SubLHS, SubRHS; 3134 std::tie(SubLHS, SubRHS) = MatchInfo; 3135 Builder.buildSub(MI.getOperand(0).getReg(), SubLHS, SubRHS); 3136 MI.eraseFromParent(); 3137 } 3138 3139 bool CombinerHelper::matchHoistLogicOpWithSameOpcodeHands( 3140 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { 3141 // Matches: logic (hand x, ...), (hand y, ...) -> hand (logic x, y), ... 3142 // 3143 // Creates the new hand + logic instruction (but does not insert them.) 3144 // 3145 // On success, MatchInfo is populated with the new instructions. These are 3146 // inserted in applyHoistLogicOpWithSameOpcodeHands. 3147 unsigned LogicOpcode = MI.getOpcode(); 3148 assert(LogicOpcode == TargetOpcode::G_AND || 3149 LogicOpcode == TargetOpcode::G_OR || 3150 LogicOpcode == TargetOpcode::G_XOR); 3151 MachineIRBuilder MIB(MI); 3152 Register Dst = MI.getOperand(0).getReg(); 3153 Register LHSReg = MI.getOperand(1).getReg(); 3154 Register RHSReg = MI.getOperand(2).getReg(); 3155 3156 // Don't recompute anything. 3157 if (!MRI.hasOneNonDBGUse(LHSReg) || !MRI.hasOneNonDBGUse(RHSReg)) 3158 return false; 3159 3160 // Make sure we have (hand x, ...), (hand y, ...) 
3161 MachineInstr *LeftHandInst = getDefIgnoringCopies(LHSReg, MRI); 3162 MachineInstr *RightHandInst = getDefIgnoringCopies(RHSReg, MRI); 3163 if (!LeftHandInst || !RightHandInst) 3164 return false; 3165 unsigned HandOpcode = LeftHandInst->getOpcode(); 3166 if (HandOpcode != RightHandInst->getOpcode()) 3167 return false; 3168 if (LeftHandInst->getNumOperands() < 2 || 3169 !LeftHandInst->getOperand(1).isReg() || 3170 RightHandInst->getNumOperands() < 2 || 3171 !RightHandInst->getOperand(1).isReg()) 3172 return false; 3173 3174 // Make sure the types match up, and if we're doing this post-legalization, 3175 // we end up with legal types. 3176 Register X = LeftHandInst->getOperand(1).getReg(); 3177 Register Y = RightHandInst->getOperand(1).getReg(); 3178 LLT XTy = MRI.getType(X); 3179 LLT YTy = MRI.getType(Y); 3180 if (!XTy.isValid() || XTy != YTy) 3181 return false; 3182 3183 // Optional extra source register. 3184 Register ExtraHandOpSrcReg; 3185 switch (HandOpcode) { 3186 default: 3187 return false; 3188 case TargetOpcode::G_ANYEXT: 3189 case TargetOpcode::G_SEXT: 3190 case TargetOpcode::G_ZEXT: { 3191 // Match: logic (ext X), (ext Y) --> ext (logic X, Y) 3192 break; 3193 } 3194 case TargetOpcode::G_TRUNC: { 3195 // Match: logic (trunc X), (trunc Y) -> trunc (logic X, Y) 3196 const MachineFunction *MF = MI.getMF(); 3197 LLVMContext &Ctx = MF->getFunction().getContext(); 3198 3199 LLT DstTy = MRI.getType(Dst); 3200 const TargetLowering &TLI = getTargetLowering(); 3201 3202 // Be extra careful sinking truncate. If it's free, there's no benefit in 3203 // widening a binop. 3204 if (TLI.isZExtFree(DstTy, XTy, Ctx) && TLI.isTruncateFree(XTy, DstTy, Ctx)) 3205 return false; 3206 break; 3207 } 3208 case TargetOpcode::G_AND: 3209 case TargetOpcode::G_ASHR: 3210 case TargetOpcode::G_LSHR: 3211 case TargetOpcode::G_SHL: { 3212 // Match: logic (binop x, z), (binop y, z) -> binop (logic x, y), z 3213 MachineOperand &ZOp = LeftHandInst->getOperand(2); 3214 if (!matchEqualDefs(ZOp, RightHandInst->getOperand(2))) 3215 return false; 3216 ExtraHandOpSrcReg = ZOp.getReg(); 3217 break; 3218 } 3219 } 3220 3221 if (!isLegalOrBeforeLegalizer({LogicOpcode, {XTy, YTy}})) 3222 return false; 3223 3224 // Record the steps to build the new instructions. 
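  // Nothing is emitted here; the recorded steps are materialized later by the
  // apply step (see applyBuildInstructionSteps below).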
3225 // 3226 // Steps to build (logic x, y) 3227 auto NewLogicDst = MRI.createGenericVirtualRegister(XTy); 3228 OperandBuildSteps LogicBuildSteps = { 3229 [=](MachineInstrBuilder &MIB) { MIB.addDef(NewLogicDst); }, 3230 [=](MachineInstrBuilder &MIB) { MIB.addReg(X); }, 3231 [=](MachineInstrBuilder &MIB) { MIB.addReg(Y); }}; 3232 InstructionBuildSteps LogicSteps(LogicOpcode, LogicBuildSteps); 3233 3234 // Steps to build hand (logic x, y), ...z 3235 OperandBuildSteps HandBuildSteps = { 3236 [=](MachineInstrBuilder &MIB) { MIB.addDef(Dst); }, 3237 [=](MachineInstrBuilder &MIB) { MIB.addReg(NewLogicDst); }}; 3238 if (ExtraHandOpSrcReg.isValid()) 3239 HandBuildSteps.push_back( 3240 [=](MachineInstrBuilder &MIB) { MIB.addReg(ExtraHandOpSrcReg); }); 3241 InstructionBuildSteps HandSteps(HandOpcode, HandBuildSteps); 3242 3243 MatchInfo = InstructionStepsMatchInfo({LogicSteps, HandSteps}); 3244 return true; 3245 } 3246 3247 void CombinerHelper::applyBuildInstructionSteps( 3248 MachineInstr &MI, InstructionStepsMatchInfo &MatchInfo) const { 3249 assert(MatchInfo.InstrsToBuild.size() && 3250 "Expected at least one instr to build?"); 3251 for (auto &InstrToBuild : MatchInfo.InstrsToBuild) { 3252 assert(InstrToBuild.Opcode && "Expected a valid opcode?"); 3253 assert(InstrToBuild.OperandFns.size() && "Expected at least one operand?"); 3254 MachineInstrBuilder Instr = Builder.buildInstr(InstrToBuild.Opcode); 3255 for (auto &OperandFn : InstrToBuild.OperandFns) 3256 OperandFn(Instr); 3257 } 3258 MI.eraseFromParent(); 3259 } 3260 3261 bool CombinerHelper::matchAshrShlToSextInreg( 3262 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { 3263 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 3264 int64_t ShlCst, AshrCst; 3265 Register Src; 3266 if (!mi_match(MI.getOperand(0).getReg(), MRI, 3267 m_GAShr(m_GShl(m_Reg(Src), m_ICstOrSplat(ShlCst)), 3268 m_ICstOrSplat(AshrCst)))) 3269 return false; 3270 if (ShlCst != AshrCst) 3271 return false; 3272 if (!isLegalOrBeforeLegalizer( 3273 {TargetOpcode::G_SEXT_INREG, {MRI.getType(Src)}})) 3274 return false; 3275 MatchInfo = std::make_tuple(Src, ShlCst); 3276 return true; 3277 } 3278 3279 void CombinerHelper::applyAshShlToSextInreg( 3280 MachineInstr &MI, std::tuple<Register, int64_t> &MatchInfo) const { 3281 assert(MI.getOpcode() == TargetOpcode::G_ASHR); 3282 Register Src; 3283 int64_t ShiftAmt; 3284 std::tie(Src, ShiftAmt) = MatchInfo; 3285 unsigned Size = MRI.getType(Src).getScalarSizeInBits(); 3286 Builder.buildSExtInReg(MI.getOperand(0).getReg(), Src, Size - ShiftAmt); 3287 MI.eraseFromParent(); 3288 } 3289 3290 /// and(and(x, C1), C2) -> C1&C2 ? 
and(x, C1&C2) : 0 3291 bool CombinerHelper::matchOverlappingAnd( 3292 MachineInstr &MI, 3293 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 3294 assert(MI.getOpcode() == TargetOpcode::G_AND); 3295 3296 Register Dst = MI.getOperand(0).getReg(); 3297 LLT Ty = MRI.getType(Dst); 3298 3299 Register R; 3300 int64_t C1; 3301 int64_t C2; 3302 if (!mi_match( 3303 Dst, MRI, 3304 m_GAnd(m_GAnd(m_Reg(R), m_ICst(C1)), m_ICst(C2)))) 3305 return false; 3306 3307 MatchInfo = [=](MachineIRBuilder &B) { 3308 if (C1 & C2) { 3309 B.buildAnd(Dst, R, B.buildConstant(Ty, C1 & C2)); 3310 return; 3311 } 3312 auto Zero = B.buildConstant(Ty, 0); 3313 replaceRegWith(MRI, Dst, Zero->getOperand(0).getReg()); 3314 }; 3315 return true; 3316 } 3317 3318 bool CombinerHelper::matchRedundantAnd(MachineInstr &MI, 3319 Register &Replacement) const { 3320 // Given 3321 // 3322 // %y:_(sN) = G_SOMETHING 3323 // %x:_(sN) = G_SOMETHING 3324 // %res:_(sN) = G_AND %x, %y 3325 // 3326 // Eliminate the G_AND when it is known that x & y == x or x & y == y. 3327 // 3328 // Patterns like this can appear as a result of legalization. E.g. 3329 // 3330 // %cmp:_(s32) = G_ICMP intpred(pred), %x(s32), %y 3331 // %one:_(s32) = G_CONSTANT i32 1 3332 // %and:_(s32) = G_AND %cmp, %one 3333 // 3334 // In this case, G_ICMP only produces a single bit, so x & 1 == x. 3335 assert(MI.getOpcode() == TargetOpcode::G_AND); 3336 if (!VT) 3337 return false; 3338 3339 Register AndDst = MI.getOperand(0).getReg(); 3340 Register LHS = MI.getOperand(1).getReg(); 3341 Register RHS = MI.getOperand(2).getReg(); 3342 3343 // Check the RHS (maybe a constant) first, and if we have no KnownBits there, 3344 // we can't do anything. If we do, then it depends on whether we have 3345 // KnownBits on the LHS. 3346 KnownBits RHSBits = VT->getKnownBits(RHS); 3347 if (RHSBits.isUnknown()) 3348 return false; 3349 3350 KnownBits LHSBits = VT->getKnownBits(LHS); 3351 3352 // Check that x & Mask == x. 3353 // x & 1 == x, always 3354 // x & 0 == x, only if x is also 0 3355 // Meaning Mask has no effect if every bit is either one in Mask or zero in x. 3356 // 3357 // Check if we can replace AndDst with the LHS of the G_AND 3358 if (canReplaceReg(AndDst, LHS, MRI) && 3359 (LHSBits.Zero | RHSBits.One).isAllOnes()) { 3360 Replacement = LHS; 3361 return true; 3362 } 3363 3364 // Check if we can replace AndDst with the RHS of the G_AND 3365 if (canReplaceReg(AndDst, RHS, MRI) && 3366 (LHSBits.One | RHSBits.Zero).isAllOnes()) { 3367 Replacement = RHS; 3368 return true; 3369 } 3370 3371 return false; 3372 } 3373 3374 bool CombinerHelper::matchRedundantOr(MachineInstr &MI, 3375 Register &Replacement) const { 3376 // Given 3377 // 3378 // %y:_(sN) = G_SOMETHING 3379 // %x:_(sN) = G_SOMETHING 3380 // %res:_(sN) = G_OR %x, %y 3381 // 3382 // Eliminate the G_OR when it is known that x | y == x or x | y == y. 3383 assert(MI.getOpcode() == TargetOpcode::G_OR); 3384 if (!VT) 3385 return false; 3386 3387 Register OrDst = MI.getOperand(0).getReg(); 3388 Register LHS = MI.getOperand(1).getReg(); 3389 Register RHS = MI.getOperand(2).getReg(); 3390 3391 KnownBits LHSBits = VT->getKnownBits(LHS); 3392 KnownBits RHSBits = VT->getKnownBits(RHS); 3393 3394 // Check that x | Mask == x. 3395 // x | 0 == x, always 3396 // x | 1 == x, only if x is also 1 3397 // Meaning Mask has no effect if every bit is either zero in Mask or one in x. 
3398 // 3399 // Check if we can replace OrDst with the LHS of the G_OR 3400 if (canReplaceReg(OrDst, LHS, MRI) && 3401 (LHSBits.One | RHSBits.Zero).isAllOnes()) { 3402 Replacement = LHS; 3403 return true; 3404 } 3405 3406 // Check if we can replace OrDst with the RHS of the G_OR 3407 if (canReplaceReg(OrDst, RHS, MRI) && 3408 (LHSBits.Zero | RHSBits.One).isAllOnes()) { 3409 Replacement = RHS; 3410 return true; 3411 } 3412 3413 return false; 3414 } 3415 3416 bool CombinerHelper::matchRedundantSExtInReg(MachineInstr &MI) const { 3417 // If the input is already sign extended, just drop the extension. 3418 Register Src = MI.getOperand(1).getReg(); 3419 unsigned ExtBits = MI.getOperand(2).getImm(); 3420 unsigned TypeSize = MRI.getType(Src).getScalarSizeInBits(); 3421 return VT->computeNumSignBits(Src) >= (TypeSize - ExtBits + 1); 3422 } 3423 3424 static bool isConstValidTrue(const TargetLowering &TLI, unsigned ScalarSizeBits, 3425 int64_t Cst, bool IsVector, bool IsFP) { 3426 // For i1, Cst will always be -1 regardless of boolean contents. 3427 return (ScalarSizeBits == 1 && Cst == -1) || 3428 isConstTrueVal(TLI, Cst, IsVector, IsFP); 3429 } 3430 3431 // This combine tries to reduce the number of scalarised G_TRUNC instructions by 3432 // using vector truncates instead 3433 // 3434 // EXAMPLE: 3435 // %a(i32), %b(i32) = G_UNMERGE_VALUES %src(<2 x i32>) 3436 // %T_a(i16) = G_TRUNC %a(i32) 3437 // %T_b(i16) = G_TRUNC %b(i32) 3438 // %Undef(i16) = G_IMPLICIT_DEF(i16) 3439 // %dst(v4i16) = G_BUILD_VECTORS %T_a(i16), %T_b(i16), %Undef(i16), %Undef(i16) 3440 // 3441 // ===> 3442 // %Undef(<2 x i32>) = G_IMPLICIT_DEF(<2 x i32>) 3443 // %Mid(<4 x s32>) = G_CONCAT_VECTORS %src(<2 x i32>), %Undef(<2 x i32>) 3444 // %dst(<4 x s16>) = G_TRUNC %Mid(<4 x s32>) 3445 // 3446 // Only matches sources made up of G_TRUNCs followed by G_IMPLICIT_DEFs 3447 bool CombinerHelper::matchUseVectorTruncate(MachineInstr &MI, 3448 Register &MatchInfo) const { 3449 auto BuildMI = cast<GBuildVector>(&MI); 3450 unsigned NumOperands = BuildMI->getNumSources(); 3451 LLT DstTy = MRI.getType(BuildMI->getReg(0)); 3452 3453 // Check the G_BUILD_VECTOR sources 3454 unsigned I; 3455 MachineInstr *UnmergeMI = nullptr; 3456 3457 // Check all source TRUNCs come from the same UNMERGE instruction 3458 for (I = 0; I < NumOperands; ++I) { 3459 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I)); 3460 auto SrcMIOpc = SrcMI->getOpcode(); 3461 3462 // Check if the G_TRUNC instructions all come from the same MI 3463 if (SrcMIOpc == TargetOpcode::G_TRUNC) { 3464 if (!UnmergeMI) { 3465 UnmergeMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg()); 3466 if (UnmergeMI->getOpcode() != TargetOpcode::G_UNMERGE_VALUES) 3467 return false; 3468 } else { 3469 auto UnmergeSrcMI = MRI.getVRegDef(SrcMI->getOperand(1).getReg()); 3470 if (UnmergeMI != UnmergeSrcMI) 3471 return false; 3472 } 3473 } else { 3474 break; 3475 } 3476 } 3477 if (I < 2) 3478 return false; 3479 3480 // Check the remaining source elements are only G_IMPLICIT_DEF 3481 for (; I < NumOperands; ++I) { 3482 auto SrcMI = MRI.getVRegDef(BuildMI->getSourceReg(I)); 3483 auto SrcMIOpc = SrcMI->getOpcode(); 3484 3485 if (SrcMIOpc != TargetOpcode::G_IMPLICIT_DEF) 3486 return false; 3487 } 3488 3489 // Check the size of unmerge source 3490 MatchInfo = cast<GUnmerge>(UnmergeMI)->getSourceReg(); 3491 LLT UnmergeSrcTy = MRI.getType(MatchInfo); 3492 if (!DstTy.getElementCount().isKnownMultipleOf(UnmergeSrcTy.getNumElements())) 3493 return false; 3494 3495 // Check the unmerge source and destination element 
types match 3496 LLT UnmergeSrcEltTy = UnmergeSrcTy.getElementType(); 3497 Register UnmergeDstReg = UnmergeMI->getOperand(0).getReg(); 3498 LLT UnmergeDstEltTy = MRI.getType(UnmergeDstReg); 3499 if (UnmergeSrcEltTy != UnmergeDstEltTy) 3500 return false; 3501 3502 // Only generate legal instructions post-legalizer 3503 if (!IsPreLegalize) { 3504 LLT MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType()); 3505 3506 if (DstTy.getElementCount() != UnmergeSrcTy.getElementCount() && 3507 !isLegal({TargetOpcode::G_CONCAT_VECTORS, {MidTy, UnmergeSrcTy}})) 3508 return false; 3509 3510 if (!isLegal({TargetOpcode::G_TRUNC, {DstTy, MidTy}})) 3511 return false; 3512 } 3513 3514 return true; 3515 } 3516 3517 void CombinerHelper::applyUseVectorTruncate(MachineInstr &MI, 3518 Register &MatchInfo) const { 3519 Register MidReg; 3520 auto BuildMI = cast<GBuildVector>(&MI); 3521 Register DstReg = BuildMI->getReg(0); 3522 LLT DstTy = MRI.getType(DstReg); 3523 LLT UnmergeSrcTy = MRI.getType(MatchInfo); 3524 unsigned DstTyNumElt = DstTy.getNumElements(); 3525 unsigned UnmergeSrcTyNumElt = UnmergeSrcTy.getNumElements(); 3526 3527 // No need to pad vector if only G_TRUNC is needed 3528 if (DstTyNumElt / UnmergeSrcTyNumElt == 1) { 3529 MidReg = MatchInfo; 3530 } else { 3531 Register UndefReg = Builder.buildUndef(UnmergeSrcTy).getReg(0); 3532 SmallVector<Register> ConcatRegs = {MatchInfo}; 3533 for (unsigned I = 1; I < DstTyNumElt / UnmergeSrcTyNumElt; ++I) 3534 ConcatRegs.push_back(UndefReg); 3535 3536 auto MidTy = DstTy.changeElementType(UnmergeSrcTy.getScalarType()); 3537 MidReg = Builder.buildConcatVectors(MidTy, ConcatRegs).getReg(0); 3538 } 3539 3540 Builder.buildTrunc(DstReg, MidReg); 3541 MI.eraseFromParent(); 3542 } 3543 3544 bool CombinerHelper::matchNotCmp( 3545 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const { 3546 assert(MI.getOpcode() == TargetOpcode::G_XOR); 3547 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 3548 const auto &TLI = *Builder.getMF().getSubtarget().getTargetLowering(); 3549 Register XorSrc; 3550 Register CstReg; 3551 // We match xor(src, true) here. 3552 if (!mi_match(MI.getOperand(0).getReg(), MRI, 3553 m_GXor(m_Reg(XorSrc), m_Reg(CstReg)))) 3554 return false; 3555 3556 if (!MRI.hasOneNonDBGUse(XorSrc)) 3557 return false; 3558 3559 // Check that XorSrc is the root of a tree of comparisons combined with ANDs 3560 // and ORs. The suffix of RegsToNegate starting from index I is used a work 3561 // list of tree nodes to visit. 3562 RegsToNegate.push_back(XorSrc); 3563 // Remember whether the comparisons are all integer or all floating point. 3564 bool IsInt = false; 3565 bool IsFP = false; 3566 for (unsigned I = 0; I < RegsToNegate.size(); ++I) { 3567 Register Reg = RegsToNegate[I]; 3568 if (!MRI.hasOneNonDBGUse(Reg)) 3569 return false; 3570 MachineInstr *Def = MRI.getVRegDef(Reg); 3571 switch (Def->getOpcode()) { 3572 default: 3573 // Don't match if the tree contains anything other than ANDs, ORs and 3574 // comparisons. 3575 return false; 3576 case TargetOpcode::G_ICMP: 3577 if (IsFP) 3578 return false; 3579 IsInt = true; 3580 // When we apply the combine we will invert the predicate. 3581 break; 3582 case TargetOpcode::G_FCMP: 3583 if (IsInt) 3584 return false; 3585 IsFP = true; 3586 // When we apply the combine we will invert the predicate. 
3587 break; 3588 case TargetOpcode::G_AND: 3589 case TargetOpcode::G_OR: 3590 // Implement De Morgan's laws: 3591 // ~(x & y) -> ~x | ~y 3592 // ~(x | y) -> ~x & ~y 3593 // When we apply the combine we will change the opcode and recursively 3594 // negate the operands. 3595 RegsToNegate.push_back(Def->getOperand(1).getReg()); 3596 RegsToNegate.push_back(Def->getOperand(2).getReg()); 3597 break; 3598 } 3599 } 3600 3601 // Now we know whether the comparisons are integer or floating point, check 3602 // the constant in the xor. 3603 int64_t Cst; 3604 if (Ty.isVector()) { 3605 MachineInstr *CstDef = MRI.getVRegDef(CstReg); 3606 auto MaybeCst = getIConstantSplatSExtVal(*CstDef, MRI); 3607 if (!MaybeCst) 3608 return false; 3609 if (!isConstValidTrue(TLI, Ty.getScalarSizeInBits(), *MaybeCst, true, IsFP)) 3610 return false; 3611 } else { 3612 if (!mi_match(CstReg, MRI, m_ICst(Cst))) 3613 return false; 3614 if (!isConstValidTrue(TLI, Ty.getSizeInBits(), Cst, false, IsFP)) 3615 return false; 3616 } 3617 3618 return true; 3619 } 3620 3621 void CombinerHelper::applyNotCmp( 3622 MachineInstr &MI, SmallVectorImpl<Register> &RegsToNegate) const { 3623 for (Register Reg : RegsToNegate) { 3624 MachineInstr *Def = MRI.getVRegDef(Reg); 3625 Observer.changingInstr(*Def); 3626 // For each comparison, invert the opcode. For each AND and OR, change the 3627 // opcode. 3628 switch (Def->getOpcode()) { 3629 default: 3630 llvm_unreachable("Unexpected opcode"); 3631 case TargetOpcode::G_ICMP: 3632 case TargetOpcode::G_FCMP: { 3633 MachineOperand &PredOp = Def->getOperand(1); 3634 CmpInst::Predicate NewP = CmpInst::getInversePredicate( 3635 (CmpInst::Predicate)PredOp.getPredicate()); 3636 PredOp.setPredicate(NewP); 3637 break; 3638 } 3639 case TargetOpcode::G_AND: 3640 Def->setDesc(Builder.getTII().get(TargetOpcode::G_OR)); 3641 break; 3642 case TargetOpcode::G_OR: 3643 Def->setDesc(Builder.getTII().get(TargetOpcode::G_AND)); 3644 break; 3645 } 3646 Observer.changedInstr(*Def); 3647 } 3648 3649 replaceRegWith(MRI, MI.getOperand(0).getReg(), MI.getOperand(1).getReg()); 3650 MI.eraseFromParent(); 3651 } 3652 3653 bool CombinerHelper::matchXorOfAndWithSameReg( 3654 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { 3655 // Match (xor (and x, y), y) (or any of its commuted cases) 3656 assert(MI.getOpcode() == TargetOpcode::G_XOR); 3657 Register &X = MatchInfo.first; 3658 Register &Y = MatchInfo.second; 3659 Register AndReg = MI.getOperand(1).getReg(); 3660 Register SharedReg = MI.getOperand(2).getReg(); 3661 3662 // Find a G_AND on either side of the G_XOR. 3663 // Look for one of 3664 // 3665 // (xor (and x, y), SharedReg) 3666 // (xor SharedReg, (and x, y)) 3667 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) { 3668 std::swap(AndReg, SharedReg); 3669 if (!mi_match(AndReg, MRI, m_GAnd(m_Reg(X), m_Reg(Y)))) 3670 return false; 3671 } 3672 3673 // Only do this if we'll eliminate the G_AND. 3674 if (!MRI.hasOneNonDBGUse(AndReg)) 3675 return false; 3676 3677 // We can combine if SharedReg is the same as either the LHS or RHS of the 3678 // G_AND. 
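  // Canonicalize so that Y ends up as the operand shared with the G_XOR; the
  // apply step then rewrites the pattern to (and (not X), Y).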
3679 if (Y != SharedReg) 3680 std::swap(X, Y); 3681 return Y == SharedReg; 3682 } 3683 3684 void CombinerHelper::applyXorOfAndWithSameReg( 3685 MachineInstr &MI, std::pair<Register, Register> &MatchInfo) const { 3686 // Fold (xor (and x, y), y) -> (and (not x), y) 3687 Register X, Y; 3688 std::tie(X, Y) = MatchInfo; 3689 auto Not = Builder.buildNot(MRI.getType(X), X); 3690 Observer.changingInstr(MI); 3691 MI.setDesc(Builder.getTII().get(TargetOpcode::G_AND)); 3692 MI.getOperand(1).setReg(Not->getOperand(0).getReg()); 3693 MI.getOperand(2).setReg(Y); 3694 Observer.changedInstr(MI); 3695 } 3696 3697 bool CombinerHelper::matchPtrAddZero(MachineInstr &MI) const { 3698 auto &PtrAdd = cast<GPtrAdd>(MI); 3699 Register DstReg = PtrAdd.getReg(0); 3700 LLT Ty = MRI.getType(DstReg); 3701 const DataLayout &DL = Builder.getMF().getDataLayout(); 3702 3703 if (DL.isNonIntegralAddressSpace(Ty.getScalarType().getAddressSpace())) 3704 return false; 3705 3706 if (Ty.isPointer()) { 3707 auto ConstVal = getIConstantVRegVal(PtrAdd.getBaseReg(), MRI); 3708 return ConstVal && *ConstVal == 0; 3709 } 3710 3711 assert(Ty.isVector() && "Expecting a vector type"); 3712 const MachineInstr *VecMI = MRI.getVRegDef(PtrAdd.getBaseReg()); 3713 return isBuildVectorAllZeros(*VecMI, MRI); 3714 } 3715 3716 void CombinerHelper::applyPtrAddZero(MachineInstr &MI) const { 3717 auto &PtrAdd = cast<GPtrAdd>(MI); 3718 Builder.buildIntToPtr(PtrAdd.getReg(0), PtrAdd.getOffsetReg()); 3719 PtrAdd.eraseFromParent(); 3720 } 3721 3722 /// The second source operand is known to be a power of 2. 3723 void CombinerHelper::applySimplifyURemByPow2(MachineInstr &MI) const { 3724 Register DstReg = MI.getOperand(0).getReg(); 3725 Register Src0 = MI.getOperand(1).getReg(); 3726 Register Pow2Src1 = MI.getOperand(2).getReg(); 3727 LLT Ty = MRI.getType(DstReg); 3728 3729 // Fold (urem x, pow2) -> (and x, pow2-1) 3730 auto NegOne = Builder.buildConstant(Ty, -1); 3731 auto Add = Builder.buildAdd(Ty, Pow2Src1, NegOne); 3732 Builder.buildAnd(DstReg, Src0, Add); 3733 MI.eraseFromParent(); 3734 } 3735 3736 bool CombinerHelper::matchFoldBinOpIntoSelect(MachineInstr &MI, 3737 unsigned &SelectOpNo) const { 3738 Register LHS = MI.getOperand(1).getReg(); 3739 Register RHS = MI.getOperand(2).getReg(); 3740 3741 Register OtherOperandReg = RHS; 3742 SelectOpNo = 1; 3743 MachineInstr *Select = MRI.getVRegDef(LHS); 3744 3745 // Don't do this unless the old select is going away. We want to eliminate the 3746 // binary operator, not replace a binop with a select. 3747 if (Select->getOpcode() != TargetOpcode::G_SELECT || 3748 !MRI.hasOneNonDBGUse(LHS)) { 3749 OtherOperandReg = LHS; 3750 SelectOpNo = 2; 3751 Select = MRI.getVRegDef(RHS); 3752 if (Select->getOpcode() != TargetOpcode::G_SELECT || 3753 !MRI.hasOneNonDBGUse(RHS)) 3754 return false; 3755 } 3756 3757 MachineInstr *SelectLHS = MRI.getVRegDef(Select->getOperand(2).getReg()); 3758 MachineInstr *SelectRHS = MRI.getVRegDef(Select->getOperand(3).getReg()); 3759 3760 if (!isConstantOrConstantVector(*SelectLHS, MRI, 3761 /*AllowFP*/ true, 3762 /*AllowOpaqueConstants*/ false)) 3763 return false; 3764 if (!isConstantOrConstantVector(*SelectRHS, MRI, 3765 /*AllowFP*/ true, 3766 /*AllowOpaqueConstants*/ false)) 3767 return false; 3768 3769 unsigned BinOpcode = MI.getOpcode(); 3770 3771 // We know that one of the operands is a select of constants. Now verify that 3772 // the other binary operator operand is either a constant, or we can handle a 3773 // variable. 
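  // A non-constant operand is only tolerated for G_AND/G_OR when both arms of
  // the select are all-zeros or all-ones, since and/or against those values
  // folds regardless of what the variable operand holds.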
3774 bool CanFoldNonConst = 3775 (BinOpcode == TargetOpcode::G_AND || BinOpcode == TargetOpcode::G_OR) && 3776 (isNullOrNullSplat(*SelectLHS, MRI) || 3777 isAllOnesOrAllOnesSplat(*SelectLHS, MRI)) && 3778 (isNullOrNullSplat(*SelectRHS, MRI) || 3779 isAllOnesOrAllOnesSplat(*SelectRHS, MRI)); 3780 if (CanFoldNonConst) 3781 return true; 3782 3783 return isConstantOrConstantVector(*MRI.getVRegDef(OtherOperandReg), MRI, 3784 /*AllowFP*/ true, 3785 /*AllowOpaqueConstants*/ false); 3786 } 3787 3788 /// \p SelectOperand is the operand in binary operator \p MI that is the select 3789 /// to fold. 3790 void CombinerHelper::applyFoldBinOpIntoSelect( 3791 MachineInstr &MI, const unsigned &SelectOperand) const { 3792 Register Dst = MI.getOperand(0).getReg(); 3793 Register LHS = MI.getOperand(1).getReg(); 3794 Register RHS = MI.getOperand(2).getReg(); 3795 MachineInstr *Select = MRI.getVRegDef(MI.getOperand(SelectOperand).getReg()); 3796 3797 Register SelectCond = Select->getOperand(1).getReg(); 3798 Register SelectTrue = Select->getOperand(2).getReg(); 3799 Register SelectFalse = Select->getOperand(3).getReg(); 3800 3801 LLT Ty = MRI.getType(Dst); 3802 unsigned BinOpcode = MI.getOpcode(); 3803 3804 Register FoldTrue, FoldFalse; 3805 3806 // We have a select-of-constants followed by a binary operator with a 3807 // constant. Eliminate the binop by pulling the constant math into the select. 3808 // Example: add (select Cond, CT, CF), CBO --> select Cond, CT + CBO, CF + CBO 3809 if (SelectOperand == 1) { 3810 // TODO: SelectionDAG verifies this actually constant folds before 3811 // committing to the combine. 3812 3813 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {SelectTrue, RHS}).getReg(0); 3814 FoldFalse = 3815 Builder.buildInstr(BinOpcode, {Ty}, {SelectFalse, RHS}).getReg(0); 3816 } else { 3817 FoldTrue = Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectTrue}).getReg(0); 3818 FoldFalse = 3819 Builder.buildInstr(BinOpcode, {Ty}, {LHS, SelectFalse}).getReg(0); 3820 } 3821 3822 Builder.buildSelect(Dst, SelectCond, FoldTrue, FoldFalse, MI.getFlags()); 3823 MI.eraseFromParent(); 3824 } 3825 3826 std::optional<SmallVector<Register, 8>> 3827 CombinerHelper::findCandidatesForLoadOrCombine(const MachineInstr *Root) const { 3828 assert(Root->getOpcode() == TargetOpcode::G_OR && "Expected G_OR only!"); 3829 // We want to detect if Root is part of a tree which represents a bunch 3830 // of loads being merged into a larger load. We'll try to recognize patterns 3831 // like, for example: 3832 // 3833 // Reg Reg 3834 // \ / 3835 // OR_1 Reg 3836 // \ / 3837 // OR_2 3838 // \ Reg 3839 // .. / 3840 // Root 3841 // 3842 // Reg Reg Reg Reg 3843 // \ / \ / 3844 // OR_1 OR_2 3845 // \ / 3846 // \ / 3847 // ... 3848 // Root 3849 // 3850 // Each "Reg" may have been produced by a load + some arithmetic. This 3851 // function will save each of them. 3852 SmallVector<Register, 8> RegsToVisit; 3853 SmallVector<const MachineInstr *, 7> Ors = {Root}; 3854 3855 // In the "worst" case, we're dealing with a load for each byte. So, there 3856 // are at most #bytes - 1 ORs. 3857 const unsigned MaxIter = 3858 MRI.getType(Root->getOperand(0).getReg()).getSizeInBytes() - 1; 3859 for (unsigned Iter = 0; Iter < MaxIter; ++Iter) { 3860 if (Ors.empty()) 3861 break; 3862 const MachineInstr *Curr = Ors.pop_back_val(); 3863 Register OrLHS = Curr->getOperand(1).getReg(); 3864 Register OrRHS = Curr->getOperand(2).getReg(); 3865 3866 // In the combine, we want to elimate the entire tree. 
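    // That means each operand register in the tree must have exactly one
    // non-debug use; otherwise part of the tree would remain live after the
    // combine.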
3867 if (!MRI.hasOneNonDBGUse(OrLHS) || !MRI.hasOneNonDBGUse(OrRHS)) 3868 return std::nullopt; 3869 3870 // If it's a G_OR, save it and continue to walk. If it's not, then it's 3871 // something that may be a load + arithmetic. 3872 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrLHS, MRI)) 3873 Ors.push_back(Or); 3874 else 3875 RegsToVisit.push_back(OrLHS); 3876 if (const MachineInstr *Or = getOpcodeDef(TargetOpcode::G_OR, OrRHS, MRI)) 3877 Ors.push_back(Or); 3878 else 3879 RegsToVisit.push_back(OrRHS); 3880 } 3881 3882 // We're going to try and merge each register into a wider power-of-2 type, 3883 // so we ought to have an even number of registers. 3884 if (RegsToVisit.empty() || RegsToVisit.size() % 2 != 0) 3885 return std::nullopt; 3886 return RegsToVisit; 3887 } 3888 3889 /// Helper function for findLoadOffsetsForLoadOrCombine. 3890 /// 3891 /// Check if \p Reg is the result of loading a \p MemSizeInBits wide value, 3892 /// and then moving that value into a specific byte offset. 3893 /// 3894 /// e.g. x[i] << 24 3895 /// 3896 /// \returns The load instruction and the byte offset it is moved into. 3897 static std::optional<std::pair<GZExtLoad *, int64_t>> 3898 matchLoadAndBytePosition(Register Reg, unsigned MemSizeInBits, 3899 const MachineRegisterInfo &MRI) { 3900 assert(MRI.hasOneNonDBGUse(Reg) && 3901 "Expected Reg to only have one non-debug use?"); 3902 Register MaybeLoad; 3903 int64_t Shift; 3904 if (!mi_match(Reg, MRI, 3905 m_OneNonDBGUse(m_GShl(m_Reg(MaybeLoad), m_ICst(Shift))))) { 3906 Shift = 0; 3907 MaybeLoad = Reg; 3908 } 3909 3910 if (Shift % MemSizeInBits != 0) 3911 return std::nullopt; 3912 3913 // TODO: Handle other types of loads. 3914 auto *Load = getOpcodeDef<GZExtLoad>(MaybeLoad, MRI); 3915 if (!Load) 3916 return std::nullopt; 3917 3918 if (!Load->isUnordered() || Load->getMemSizeInBits() != MemSizeInBits) 3919 return std::nullopt; 3920 3921 return std::make_pair(Load, Shift / MemSizeInBits); 3922 } 3923 3924 std::optional<std::tuple<GZExtLoad *, int64_t, GZExtLoad *>> 3925 CombinerHelper::findLoadOffsetsForLoadOrCombine( 3926 SmallDenseMap<int64_t, int64_t, 8> &MemOffset2Idx, 3927 const SmallVector<Register, 8> &RegsToVisit, 3928 const unsigned MemSizeInBits) const { 3929 3930 // Each load found for the pattern. There should be one for each RegsToVisit. 3931 SmallSetVector<const MachineInstr *, 8> Loads; 3932 3933 // The lowest index used in any load. (The lowest "i" for each x[i].) 3934 int64_t LowestIdx = INT64_MAX; 3935 3936 // The load which uses the lowest index. 3937 GZExtLoad *LowestIdxLoad = nullptr; 3938 3939 // Keeps track of the load indices we see. We shouldn't see any indices twice. 3940 SmallSet<int64_t, 8> SeenIdx; 3941 3942 // Ensure each load is in the same MBB. 3943 // TODO: Support multiple MachineBasicBlocks. 3944 MachineBasicBlock *MBB = nullptr; 3945 const MachineMemOperand *MMO = nullptr; 3946 3947 // Earliest instruction-order load in the pattern. 3948 GZExtLoad *EarliestLoad = nullptr; 3949 3950 // Latest instruction-order load in the pattern. 3951 GZExtLoad *LatestLoad = nullptr; 3952 3953 // Base pointer which every load should share. 3954 Register BasePtr; 3955 3956 // We want to find a load for each register. Each load should have some 3957 // appropriate bit twiddling arithmetic. During this loop, we will also keep 3958 // track of the load which uses the lowest index. Later, we will check if we 3959 // can use its pointer in the final, combined load. 
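  // E.g. for a little-endian 32-bit value assembled from bytes, RegsToVisit
  // would hold the values a[0], a[1] << 8, a[2] << 16 and a[3] << 24 (see
  // matchLoadOrCombine below).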
3960 for (auto Reg : RegsToVisit) { 3961 // Find the load, and find the position that it will end up in (e.g. a 3962 // shifted) value. 3963 auto LoadAndPos = matchLoadAndBytePosition(Reg, MemSizeInBits, MRI); 3964 if (!LoadAndPos) 3965 return std::nullopt; 3966 GZExtLoad *Load; 3967 int64_t DstPos; 3968 std::tie(Load, DstPos) = *LoadAndPos; 3969 3970 // TODO: Handle multiple MachineBasicBlocks. Currently not handled because 3971 // it is difficult to check for stores/calls/etc between loads. 3972 MachineBasicBlock *LoadMBB = Load->getParent(); 3973 if (!MBB) 3974 MBB = LoadMBB; 3975 if (LoadMBB != MBB) 3976 return std::nullopt; 3977 3978 // Make sure that the MachineMemOperands of every seen load are compatible. 3979 auto &LoadMMO = Load->getMMO(); 3980 if (!MMO) 3981 MMO = &LoadMMO; 3982 if (MMO->getAddrSpace() != LoadMMO.getAddrSpace()) 3983 return std::nullopt; 3984 3985 // Find out what the base pointer and index for the load is. 3986 Register LoadPtr; 3987 int64_t Idx; 3988 if (!mi_match(Load->getOperand(1).getReg(), MRI, 3989 m_GPtrAdd(m_Reg(LoadPtr), m_ICst(Idx)))) { 3990 LoadPtr = Load->getOperand(1).getReg(); 3991 Idx = 0; 3992 } 3993 3994 // Don't combine things like a[i], a[i] -> a bigger load. 3995 if (!SeenIdx.insert(Idx).second) 3996 return std::nullopt; 3997 3998 // Every load must share the same base pointer; don't combine things like: 3999 // 4000 // a[i], b[i + 1] -> a bigger load. 4001 if (!BasePtr.isValid()) 4002 BasePtr = LoadPtr; 4003 if (BasePtr != LoadPtr) 4004 return std::nullopt; 4005 4006 if (Idx < LowestIdx) { 4007 LowestIdx = Idx; 4008 LowestIdxLoad = Load; 4009 } 4010 4011 // Keep track of the byte offset that this load ends up at. If we have seen 4012 // the byte offset, then stop here. We do not want to combine: 4013 // 4014 // a[i] << 16, a[i + k] << 16 -> a bigger load. 4015 if (!MemOffset2Idx.try_emplace(DstPos, Idx).second) 4016 return std::nullopt; 4017 Loads.insert(Load); 4018 4019 // Keep track of the position of the earliest/latest loads in the pattern. 4020 // We will check that there are no load fold barriers between them later 4021 // on. 4022 // 4023 // FIXME: Is there a better way to check for load fold barriers? 4024 if (!EarliestLoad || dominates(*Load, *EarliestLoad)) 4025 EarliestLoad = Load; 4026 if (!LatestLoad || dominates(*LatestLoad, *Load)) 4027 LatestLoad = Load; 4028 } 4029 4030 // We found a load for each register. Let's check if each load satisfies the 4031 // pattern. 4032 assert(Loads.size() == RegsToVisit.size() && 4033 "Expected to find a load for each register?"); 4034 assert(EarliestLoad != LatestLoad && EarliestLoad && 4035 LatestLoad && "Expected at least two loads?"); 4036 4037 // Check if there are any stores, calls, etc. between any of the loads. If 4038 // there are, then we can't safely perform the combine. 4039 // 4040 // MaxIter is chosen based off the (worst case) number of iterations it 4041 // typically takes to succeed in the LLVM test suite plus some padding. 4042 // 4043 // FIXME: Is there a better way to check for load fold barriers? 
4044 const unsigned MaxIter = 20; 4045 unsigned Iter = 0; 4046 for (const auto &MI : instructionsWithoutDebug(EarliestLoad->getIterator(), 4047 LatestLoad->getIterator())) { 4048 if (Loads.count(&MI)) 4049 continue; 4050 if (MI.isLoadFoldBarrier()) 4051 return std::nullopt; 4052 if (Iter++ == MaxIter) 4053 return std::nullopt; 4054 } 4055 4056 return std::make_tuple(LowestIdxLoad, LowestIdx, LatestLoad); 4057 } 4058 4059 bool CombinerHelper::matchLoadOrCombine( 4060 MachineInstr &MI, 4061 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4062 assert(MI.getOpcode() == TargetOpcode::G_OR); 4063 MachineFunction &MF = *MI.getMF(); 4064 // Assuming a little-endian target, transform: 4065 // s8 *a = ... 4066 // s32 val = a[0] | (a[1] << 8) | (a[2] << 16) | (a[3] << 24) 4067 // => 4068 // s32 val = *((i32)a) 4069 // 4070 // s8 *a = ... 4071 // s32 val = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3] 4072 // => 4073 // s32 val = BSWAP(*((s32)a)) 4074 Register Dst = MI.getOperand(0).getReg(); 4075 LLT Ty = MRI.getType(Dst); 4076 if (Ty.isVector()) 4077 return false; 4078 4079 // We need to combine at least two loads into this type. Since the smallest 4080 // possible load is into a byte, we need at least a 16-bit wide type. 4081 const unsigned WideMemSizeInBits = Ty.getSizeInBits(); 4082 if (WideMemSizeInBits < 16 || WideMemSizeInBits % 8 != 0) 4083 return false; 4084 4085 // Match a collection of non-OR instructions in the pattern. 4086 auto RegsToVisit = findCandidatesForLoadOrCombine(&MI); 4087 if (!RegsToVisit) 4088 return false; 4089 4090 // We have a collection of non-OR instructions. Figure out how wide each of 4091 // the small loads should be based off of the number of potential loads we 4092 // found. 4093 const unsigned NarrowMemSizeInBits = WideMemSizeInBits / RegsToVisit->size(); 4094 if (NarrowMemSizeInBits % 8 != 0) 4095 return false; 4096 4097 // Check if each register feeding into each OR is a load from the same 4098 // base pointer + some arithmetic. 4099 // 4100 // e.g. a[0], a[1] << 8, a[2] << 16, etc. 4101 // 4102 // Also verify that each of these ends up putting a[i] into the same memory 4103 // offset as a load into a wide type would. 4104 SmallDenseMap<int64_t, int64_t, 8> MemOffset2Idx; 4105 GZExtLoad *LowestIdxLoad, *LatestLoad; 4106 int64_t LowestIdx; 4107 auto MaybeLoadInfo = findLoadOffsetsForLoadOrCombine( 4108 MemOffset2Idx, *RegsToVisit, NarrowMemSizeInBits); 4109 if (!MaybeLoadInfo) 4110 return false; 4111 std::tie(LowestIdxLoad, LowestIdx, LatestLoad) = *MaybeLoadInfo; 4112 4113 // We have a bunch of loads being OR'd together. Using the addresses + offsets 4114 // we found before, check if this corresponds to a big or little endian byte 4115 // pattern. If it does, then we can represent it using a load + possibly a 4116 // BSWAP. 4117 bool IsBigEndianTarget = MF.getDataLayout().isBigEndian(); 4118 std::optional<bool> IsBigEndian = isBigEndian(MemOffset2Idx, LowestIdx); 4119 if (!IsBigEndian) 4120 return false; 4121 bool NeedsBSwap = IsBigEndianTarget != *IsBigEndian; 4122 if (NeedsBSwap && !isLegalOrBeforeLegalizer({TargetOpcode::G_BSWAP, {Ty}})) 4123 return false; 4124 4125 // Make sure that the load from the lowest index produces offset 0 in the 4126 // final value. 
4127 // 4128 // This ensures that we won't combine something like this: 4129 // 4130 // load x[i] -> byte 2 4131 // load x[i+1] -> byte 0 ---> wide_load x[i] 4132 // load x[i+2] -> byte 1 4133 const unsigned NumLoadsInTy = WideMemSizeInBits / NarrowMemSizeInBits; 4134 const unsigned ZeroByteOffset = 4135 *IsBigEndian 4136 ? bigEndianByteAt(NumLoadsInTy, 0) 4137 : littleEndianByteAt(NumLoadsInTy, 0); 4138 auto ZeroOffsetIdx = MemOffset2Idx.find(ZeroByteOffset); 4139 if (ZeroOffsetIdx == MemOffset2Idx.end() || 4140 ZeroOffsetIdx->second != LowestIdx) 4141 return false; 4142 4143 // We wil reuse the pointer from the load which ends up at byte offset 0. It 4144 // may not use index 0. 4145 Register Ptr = LowestIdxLoad->getPointerReg(); 4146 const MachineMemOperand &MMO = LowestIdxLoad->getMMO(); 4147 LegalityQuery::MemDesc MMDesc(MMO); 4148 MMDesc.MemoryTy = Ty; 4149 if (!isLegalOrBeforeLegalizer( 4150 {TargetOpcode::G_LOAD, {Ty, MRI.getType(Ptr)}, {MMDesc}})) 4151 return false; 4152 auto PtrInfo = MMO.getPointerInfo(); 4153 auto *NewMMO = MF.getMachineMemOperand(&MMO, PtrInfo, WideMemSizeInBits / 8); 4154 4155 // Load must be allowed and fast on the target. 4156 LLVMContext &C = MF.getFunction().getContext(); 4157 auto &DL = MF.getDataLayout(); 4158 unsigned Fast = 0; 4159 if (!getTargetLowering().allowsMemoryAccess(C, DL, Ty, *NewMMO, &Fast) || 4160 !Fast) 4161 return false; 4162 4163 MatchInfo = [=](MachineIRBuilder &MIB) { 4164 MIB.setInstrAndDebugLoc(*LatestLoad); 4165 Register LoadDst = NeedsBSwap ? MRI.cloneVirtualRegister(Dst) : Dst; 4166 MIB.buildLoad(LoadDst, Ptr, *NewMMO); 4167 if (NeedsBSwap) 4168 MIB.buildBSwap(Dst, LoadDst); 4169 }; 4170 return true; 4171 } 4172 4173 bool CombinerHelper::matchExtendThroughPhis(MachineInstr &MI, 4174 MachineInstr *&ExtMI) const { 4175 auto &PHI = cast<GPhi>(MI); 4176 Register DstReg = PHI.getReg(0); 4177 4178 // TODO: Extending a vector may be expensive, don't do this until heuristics 4179 // are better. 4180 if (MRI.getType(DstReg).isVector()) 4181 return false; 4182 4183 // Try to match a phi, whose only use is an extend. 4184 if (!MRI.hasOneNonDBGUse(DstReg)) 4185 return false; 4186 ExtMI = &*MRI.use_instr_nodbg_begin(DstReg); 4187 switch (ExtMI->getOpcode()) { 4188 case TargetOpcode::G_ANYEXT: 4189 return true; // G_ANYEXT is usually free. 4190 case TargetOpcode::G_ZEXT: 4191 case TargetOpcode::G_SEXT: 4192 break; 4193 default: 4194 return false; 4195 } 4196 4197 // If the target is likely to fold this extend away, don't propagate. 4198 if (Builder.getTII().isExtendLikelyToBeFolded(*ExtMI, MRI)) 4199 return false; 4200 4201 // We don't want to propagate the extends unless there's a good chance that 4202 // they'll be optimized in some way. 4203 // Collect the unique incoming values. 4204 SmallPtrSet<MachineInstr *, 4> InSrcs; 4205 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { 4206 auto *DefMI = getDefIgnoringCopies(PHI.getIncomingValue(I), MRI); 4207 switch (DefMI->getOpcode()) { 4208 case TargetOpcode::G_LOAD: 4209 case TargetOpcode::G_TRUNC: 4210 case TargetOpcode::G_SEXT: 4211 case TargetOpcode::G_ZEXT: 4212 case TargetOpcode::G_ANYEXT: 4213 case TargetOpcode::G_CONSTANT: 4214 InSrcs.insert(DefMI); 4215 // Don't try to propagate if there are too many places to create new 4216 // extends, chances are it'll increase code size. 
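      // At most two distinct incoming defs are accepted; beyond that, the
      // extends we would create are likely to cost more than the one removed.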
4217 if (InSrcs.size() > 2) 4218 return false; 4219 break; 4220 default: 4221 return false; 4222 } 4223 } 4224 return true; 4225 } 4226 4227 void CombinerHelper::applyExtendThroughPhis(MachineInstr &MI, 4228 MachineInstr *&ExtMI) const { 4229 auto &PHI = cast<GPhi>(MI); 4230 Register DstReg = ExtMI->getOperand(0).getReg(); 4231 LLT ExtTy = MRI.getType(DstReg); 4232 4233 // Propagate the extension into the block of each incoming reg's block. 4234 // Use a SetVector here because PHIs can have duplicate edges, and we want 4235 // deterministic iteration order. 4236 SmallSetVector<MachineInstr *, 8> SrcMIs; 4237 SmallDenseMap<MachineInstr *, MachineInstr *, 8> OldToNewSrcMap; 4238 for (unsigned I = 0; I < PHI.getNumIncomingValues(); ++I) { 4239 auto SrcReg = PHI.getIncomingValue(I); 4240 auto *SrcMI = MRI.getVRegDef(SrcReg); 4241 if (!SrcMIs.insert(SrcMI)) 4242 continue; 4243 4244 // Build an extend after each src inst. 4245 auto *MBB = SrcMI->getParent(); 4246 MachineBasicBlock::iterator InsertPt = ++SrcMI->getIterator(); 4247 if (InsertPt != MBB->end() && InsertPt->isPHI()) 4248 InsertPt = MBB->getFirstNonPHI(); 4249 4250 Builder.setInsertPt(*SrcMI->getParent(), InsertPt); 4251 Builder.setDebugLoc(MI.getDebugLoc()); 4252 auto NewExt = Builder.buildExtOrTrunc(ExtMI->getOpcode(), ExtTy, SrcReg); 4253 OldToNewSrcMap[SrcMI] = NewExt; 4254 } 4255 4256 // Create a new phi with the extended inputs. 4257 Builder.setInstrAndDebugLoc(MI); 4258 auto NewPhi = Builder.buildInstrNoInsert(TargetOpcode::G_PHI); 4259 NewPhi.addDef(DstReg); 4260 for (const MachineOperand &MO : llvm::drop_begin(MI.operands())) { 4261 if (!MO.isReg()) { 4262 NewPhi.addMBB(MO.getMBB()); 4263 continue; 4264 } 4265 auto *NewSrc = OldToNewSrcMap[MRI.getVRegDef(MO.getReg())]; 4266 NewPhi.addUse(NewSrc->getOperand(0).getReg()); 4267 } 4268 Builder.insertInstr(NewPhi); 4269 ExtMI->eraseFromParent(); 4270 } 4271 4272 bool CombinerHelper::matchExtractVecEltBuildVec(MachineInstr &MI, 4273 Register &Reg) const { 4274 assert(MI.getOpcode() == TargetOpcode::G_EXTRACT_VECTOR_ELT); 4275 // If we have a constant index, look for a G_BUILD_VECTOR source 4276 // and find the source register that the index maps to. 4277 Register SrcVec = MI.getOperand(1).getReg(); 4278 LLT SrcTy = MRI.getType(SrcVec); 4279 if (SrcTy.isScalableVector()) 4280 return false; 4281 4282 auto Cst = getIConstantVRegValWithLookThrough(MI.getOperand(2).getReg(), MRI); 4283 if (!Cst || Cst->Value.getZExtValue() >= SrcTy.getNumElements()) 4284 return false; 4285 4286 unsigned VecIdx = Cst->Value.getZExtValue(); 4287 4288 // Check if we have a build_vector or build_vector_trunc with an optional 4289 // trunc in front. 4290 MachineInstr *SrcVecMI = MRI.getVRegDef(SrcVec); 4291 if (SrcVecMI->getOpcode() == TargetOpcode::G_TRUNC) { 4292 SrcVecMI = MRI.getVRegDef(SrcVecMI->getOperand(1).getReg()); 4293 } 4294 4295 if (SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR && 4296 SrcVecMI->getOpcode() != TargetOpcode::G_BUILD_VECTOR_TRUNC) 4297 return false; 4298 4299 EVT Ty(getMVTForLLT(SrcTy)); 4300 if (!MRI.hasOneNonDBGUse(SrcVec) && 4301 !getTargetLowering().aggressivelyPreferBuildVectorSources(Ty)) 4302 return false; 4303 4304 Reg = SrcVecMI->getOperand(VecIdx + 1).getReg(); 4305 return true; 4306 } 4307 4308 void CombinerHelper::applyExtractVecEltBuildVec(MachineInstr &MI, 4309 Register &Reg) const { 4310 // Check the type of the register, since it may have come from a 4311 // G_BUILD_VECTOR_TRUNC. 
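  // If the source came from a G_BUILD_VECTOR_TRUNC, its operands are wider
  // than the destination element type, so the extracted value must be
  // truncated back down; otherwise it can be used directly.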
4312 LLT ScalarTy = MRI.getType(Reg); 4313 Register DstReg = MI.getOperand(0).getReg(); 4314 LLT DstTy = MRI.getType(DstReg); 4315 4316 if (ScalarTy != DstTy) { 4317 assert(ScalarTy.getSizeInBits() > DstTy.getSizeInBits()); 4318 Builder.buildTrunc(DstReg, Reg); 4319 MI.eraseFromParent(); 4320 return; 4321 } 4322 replaceSingleDefInstWithReg(MI, Reg); 4323 } 4324 4325 bool CombinerHelper::matchExtractAllEltsFromBuildVector( 4326 MachineInstr &MI, 4327 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { 4328 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); 4329 // This combine tries to find build_vector's which have every source element 4330 // extracted using G_EXTRACT_VECTOR_ELT. This can happen when transforms like 4331 // the masked load scalarization is run late in the pipeline. There's already 4332 // a combine for a similar pattern starting from the extract, but that 4333 // doesn't attempt to do it if there are multiple uses of the build_vector, 4334 // which in this case is true. Starting the combine from the build_vector 4335 // feels more natural than trying to find sibling nodes of extracts. 4336 // E.g. 4337 // %vec(<4 x s32>) = G_BUILD_VECTOR %s1(s32), %s2, %s3, %s4 4338 // %ext1 = G_EXTRACT_VECTOR_ELT %vec, 0 4339 // %ext2 = G_EXTRACT_VECTOR_ELT %vec, 1 4340 // %ext3 = G_EXTRACT_VECTOR_ELT %vec, 2 4341 // %ext4 = G_EXTRACT_VECTOR_ELT %vec, 3 4342 // ==> 4343 // replace ext{1,2,3,4} with %s{1,2,3,4} 4344 4345 Register DstReg = MI.getOperand(0).getReg(); 4346 LLT DstTy = MRI.getType(DstReg); 4347 unsigned NumElts = DstTy.getNumElements(); 4348 4349 SmallBitVector ExtractedElts(NumElts); 4350 for (MachineInstr &II : MRI.use_nodbg_instructions(DstReg)) { 4351 if (II.getOpcode() != TargetOpcode::G_EXTRACT_VECTOR_ELT) 4352 return false; 4353 auto Cst = getIConstantVRegVal(II.getOperand(2).getReg(), MRI); 4354 if (!Cst) 4355 return false; 4356 unsigned Idx = Cst->getZExtValue(); 4357 if (Idx >= NumElts) 4358 return false; // Out of range. 4359 ExtractedElts.set(Idx); 4360 SrcDstPairs.emplace_back( 4361 std::make_pair(MI.getOperand(Idx + 1).getReg(), &II)); 4362 } 4363 // Match if every element was extracted. 4364 return ExtractedElts.all(); 4365 } 4366 4367 void CombinerHelper::applyExtractAllEltsFromBuildVector( 4368 MachineInstr &MI, 4369 SmallVectorImpl<std::pair<Register, MachineInstr *>> &SrcDstPairs) const { 4370 assert(MI.getOpcode() == TargetOpcode::G_BUILD_VECTOR); 4371 for (auto &Pair : SrcDstPairs) { 4372 auto *ExtMI = Pair.second; 4373 replaceRegWith(MRI, ExtMI->getOperand(0).getReg(), Pair.first); 4374 ExtMI->eraseFromParent(); 4375 } 4376 MI.eraseFromParent(); 4377 } 4378 4379 void CombinerHelper::applyBuildFn( 4380 MachineInstr &MI, 4381 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4382 applyBuildFnNoErase(MI, MatchInfo); 4383 MI.eraseFromParent(); 4384 } 4385 4386 void CombinerHelper::applyBuildFnNoErase( 4387 MachineInstr &MI, 4388 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4389 MatchInfo(Builder); 4390 } 4391 4392 bool CombinerHelper::matchOrShiftToFunnelShift(MachineInstr &MI, 4393 BuildFnTy &MatchInfo) const { 4394 assert(MI.getOpcode() == TargetOpcode::G_OR); 4395 4396 Register Dst = MI.getOperand(0).getReg(); 4397 LLT Ty = MRI.getType(Dst); 4398 unsigned BitWidth = Ty.getScalarSizeInBits(); 4399 4400 Register ShlSrc, ShlAmt, LShrSrc, LShrAmt, Amt; 4401 unsigned FshOpc = 0; 4402 4403 // Match (or (shl ...), (lshr ...)). 4404 if (!mi_match(Dst, MRI, 4405 // m_GOr() handles the commuted version as well. 
4406 m_GOr(m_GShl(m_Reg(ShlSrc), m_Reg(ShlAmt)), 4407 m_GLShr(m_Reg(LShrSrc), m_Reg(LShrAmt))))) 4408 return false; 4409 4410 // Given constants C0 and C1 such that C0 + C1 is bit-width: 4411 // (or (shl x, C0), (lshr y, C1)) -> (fshl x, y, C0) or (fshr x, y, C1) 4412 int64_t CstShlAmt, CstLShrAmt; 4413 if (mi_match(ShlAmt, MRI, m_ICstOrSplat(CstShlAmt)) && 4414 mi_match(LShrAmt, MRI, m_ICstOrSplat(CstLShrAmt)) && 4415 CstShlAmt + CstLShrAmt == BitWidth) { 4416 FshOpc = TargetOpcode::G_FSHR; 4417 Amt = LShrAmt; 4418 4419 } else if (mi_match(LShrAmt, MRI, 4420 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) && 4421 ShlAmt == Amt) { 4422 // (or (shl x, amt), (lshr y, (sub bw, amt))) -> (fshl x, y, amt) 4423 FshOpc = TargetOpcode::G_FSHL; 4424 4425 } else if (mi_match(ShlAmt, MRI, 4426 m_GSub(m_SpecificICstOrSplat(BitWidth), m_Reg(Amt))) && 4427 LShrAmt == Amt) { 4428 // (or (shl x, (sub bw, amt)), (lshr y, amt)) -> (fshr x, y, amt) 4429 FshOpc = TargetOpcode::G_FSHR; 4430 4431 } else { 4432 return false; 4433 } 4434 4435 LLT AmtTy = MRI.getType(Amt); 4436 if (!isLegalOrBeforeLegalizer({FshOpc, {Ty, AmtTy}})) 4437 return false; 4438 4439 MatchInfo = [=](MachineIRBuilder &B) { 4440 B.buildInstr(FshOpc, {Dst}, {ShlSrc, LShrSrc, Amt}); 4441 }; 4442 return true; 4443 } 4444 4445 /// Match an FSHL or FSHR that can be combined to a ROTR or ROTL rotate. 4446 bool CombinerHelper::matchFunnelShiftToRotate(MachineInstr &MI) const { 4447 unsigned Opc = MI.getOpcode(); 4448 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); 4449 Register X = MI.getOperand(1).getReg(); 4450 Register Y = MI.getOperand(2).getReg(); 4451 if (X != Y) 4452 return false; 4453 unsigned RotateOpc = 4454 Opc == TargetOpcode::G_FSHL ? TargetOpcode::G_ROTL : TargetOpcode::G_ROTR; 4455 return isLegalOrBeforeLegalizer({RotateOpc, {MRI.getType(X), MRI.getType(Y)}}); 4456 } 4457 4458 void CombinerHelper::applyFunnelShiftToRotate(MachineInstr &MI) const { 4459 unsigned Opc = MI.getOpcode(); 4460 assert(Opc == TargetOpcode::G_FSHL || Opc == TargetOpcode::G_FSHR); 4461 bool IsFSHL = Opc == TargetOpcode::G_FSHL; 4462 Observer.changingInstr(MI); 4463 MI.setDesc(Builder.getTII().get(IsFSHL ? 
TargetOpcode::G_ROTL 4464 : TargetOpcode::G_ROTR)); 4465 MI.removeOperand(2); 4466 Observer.changedInstr(MI); 4467 } 4468 4469 // Fold (rot x, c) -> (rot x, c % BitSize) 4470 bool CombinerHelper::matchRotateOutOfRange(MachineInstr &MI) const { 4471 assert(MI.getOpcode() == TargetOpcode::G_ROTL || 4472 MI.getOpcode() == TargetOpcode::G_ROTR); 4473 unsigned Bitsize = 4474 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits(); 4475 Register AmtReg = MI.getOperand(2).getReg(); 4476 bool OutOfRange = false; 4477 auto MatchOutOfRange = [Bitsize, &OutOfRange](const Constant *C) { 4478 if (auto *CI = dyn_cast<ConstantInt>(C)) 4479 OutOfRange |= CI->getValue().uge(Bitsize); 4480 return true; 4481 }; 4482 return matchUnaryPredicate(MRI, AmtReg, MatchOutOfRange) && OutOfRange; 4483 } 4484 4485 void CombinerHelper::applyRotateOutOfRange(MachineInstr &MI) const { 4486 assert(MI.getOpcode() == TargetOpcode::G_ROTL || 4487 MI.getOpcode() == TargetOpcode::G_ROTR); 4488 unsigned Bitsize = 4489 MRI.getType(MI.getOperand(0).getReg()).getScalarSizeInBits(); 4490 Register Amt = MI.getOperand(2).getReg(); 4491 LLT AmtTy = MRI.getType(Amt); 4492 auto Bits = Builder.buildConstant(AmtTy, Bitsize); 4493 Amt = Builder.buildURem(AmtTy, MI.getOperand(2).getReg(), Bits).getReg(0); 4494 Observer.changingInstr(MI); 4495 MI.getOperand(2).setReg(Amt); 4496 Observer.changedInstr(MI); 4497 } 4498 4499 bool CombinerHelper::matchICmpToTrueFalseKnownBits(MachineInstr &MI, 4500 int64_t &MatchInfo) const { 4501 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 4502 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 4503 4504 // We want to avoid calling KnownBits on the LHS if possible, as this combine 4505 // has no filter and runs on every G_ICMP instruction. We can avoid calling 4506 // KnownBits on the LHS in two cases: 4507 // 4508 // - The RHS is unknown: Constants are always on RHS. If the RHS is unknown 4509 // we cannot do any transforms so we can safely bail out early. 4510 // - The RHS is zero: we don't need to know the LHS to do unsigned <0 and 4511 // >=0. 4512 auto KnownRHS = VT->getKnownBits(MI.getOperand(3).getReg()); 4513 if (KnownRHS.isUnknown()) 4514 return false; 4515 4516 std::optional<bool> KnownVal; 4517 if (KnownRHS.isZero()) { 4518 // ? uge 0 -> always true 4519 // ? ult 0 -> always false 4520 if (Pred == CmpInst::ICMP_UGE) 4521 KnownVal = true; 4522 else if (Pred == CmpInst::ICMP_ULT) 4523 KnownVal = false; 4524 } 4525 4526 if (!KnownVal) { 4527 auto KnownLHS = VT->getKnownBits(MI.getOperand(2).getReg()); 4528 KnownVal = ICmpInst::compare(KnownLHS, KnownRHS, Pred); 4529 } 4530 4531 if (!KnownVal) 4532 return false; 4533 MatchInfo = 4534 *KnownVal 4535 ? getICmpTrueVal(getTargetLowering(), 4536 /*IsVector = */ 4537 MRI.getType(MI.getOperand(0).getReg()).isVector(), 4538 /* IsFP = */ false) 4539 : 0; 4540 return true; 4541 } 4542 4543 bool CombinerHelper::matchICmpToLHSKnownBits( 4544 MachineInstr &MI, 4545 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4546 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 4547 // Given: 4548 // 4549 // %x = G_WHATEVER (... x is known to be 0 or 1 ...) 4550 // %cmp = G_ICMP ne %x, 0 4551 // 4552 // Or: 4553 // 4554 // %x = G_WHATEVER (... x is known to be 0 or 1 ...) 4555 // %cmp = G_ICMP eq %x, 1 4556 // 4557 // We can replace %cmp with %x assuming true is 1 on the target. 
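  // If %x and %cmp differ in size, the replacement is built through a COPY,
  // G_TRUNC or G_ZEXT, and that operation must be legal (or we must still be
  // before the legalizer).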
4558 auto Pred = static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate()); 4559 if (!CmpInst::isEquality(Pred)) 4560 return false; 4561 Register Dst = MI.getOperand(0).getReg(); 4562 LLT DstTy = MRI.getType(Dst); 4563 if (getICmpTrueVal(getTargetLowering(), DstTy.isVector(), 4564 /* IsFP = */ false) != 1) 4565 return false; 4566 int64_t OneOrZero = Pred == CmpInst::ICMP_EQ; 4567 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICst(OneOrZero))) 4568 return false; 4569 Register LHS = MI.getOperand(2).getReg(); 4570 auto KnownLHS = VT->getKnownBits(LHS); 4571 if (KnownLHS.getMinValue() != 0 || KnownLHS.getMaxValue() != 1) 4572 return false; 4573 // Make sure replacing Dst with the LHS is a legal operation. 4574 LLT LHSTy = MRI.getType(LHS); 4575 unsigned LHSSize = LHSTy.getSizeInBits(); 4576 unsigned DstSize = DstTy.getSizeInBits(); 4577 unsigned Op = TargetOpcode::COPY; 4578 if (DstSize != LHSSize) 4579 Op = DstSize < LHSSize ? TargetOpcode::G_TRUNC : TargetOpcode::G_ZEXT; 4580 if (!isLegalOrBeforeLegalizer({Op, {DstTy, LHSTy}})) 4581 return false; 4582 MatchInfo = [=](MachineIRBuilder &B) { B.buildInstr(Op, {Dst}, {LHS}); }; 4583 return true; 4584 } 4585 4586 // Replace (and (or x, c1), c2) with (and x, c2) iff c1 & c2 == 0 4587 bool CombinerHelper::matchAndOrDisjointMask( 4588 MachineInstr &MI, 4589 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4590 assert(MI.getOpcode() == TargetOpcode::G_AND); 4591 4592 // Ignore vector types to simplify matching the two constants. 4593 // TODO: do this for vectors and scalars via a demanded bits analysis. 4594 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 4595 if (Ty.isVector()) 4596 return false; 4597 4598 Register Src; 4599 Register AndMaskReg; 4600 int64_t AndMaskBits; 4601 int64_t OrMaskBits; 4602 if (!mi_match(MI, MRI, 4603 m_GAnd(m_GOr(m_Reg(Src), m_ICst(OrMaskBits)), 4604 m_all_of(m_ICst(AndMaskBits), m_Reg(AndMaskReg))))) 4605 return false; 4606 4607 // Check if OrMask could turn on any bits in Src. 4608 if (AndMaskBits & OrMaskBits) 4609 return false; 4610 4611 MatchInfo = [=, &MI](MachineIRBuilder &B) { 4612 Observer.changingInstr(MI); 4613 // Canonicalize the result to have the constant on the RHS. 4614 if (MI.getOperand(1).getReg() == AndMaskReg) 4615 MI.getOperand(2).setReg(AndMaskReg); 4616 MI.getOperand(1).setReg(Src); 4617 Observer.changedInstr(MI); 4618 }; 4619 return true; 4620 } 4621 4622 /// Form a G_SBFX from a G_SEXT_INREG fed by a right shift. 
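/// E.g. (illustrative only):
///   %sh:_(s32) = G_LSHR %x:_(s32), 4
///   %dst:_(s32) = G_SEXT_INREG %sh:_(s32), 8
/// can be replaced by a single G_SBFX that extracts 8 bits of %x starting at
/// bit 4 and sign-extends the result.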
4623 bool CombinerHelper::matchBitfieldExtractFromSExtInReg( 4624 MachineInstr &MI, 4625 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4626 assert(MI.getOpcode() == TargetOpcode::G_SEXT_INREG); 4627 Register Dst = MI.getOperand(0).getReg(); 4628 Register Src = MI.getOperand(1).getReg(); 4629 LLT Ty = MRI.getType(Src); 4630 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4631 if (!LI || !LI->isLegalOrCustom({TargetOpcode::G_SBFX, {Ty, ExtractTy}})) 4632 return false; 4633 int64_t Width = MI.getOperand(2).getImm(); 4634 Register ShiftSrc; 4635 int64_t ShiftImm; 4636 if (!mi_match( 4637 Src, MRI, 4638 m_OneNonDBGUse(m_any_of(m_GAShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)), 4639 m_GLShr(m_Reg(ShiftSrc), m_ICst(ShiftImm)))))) 4640 return false; 4641 if (ShiftImm < 0 || ShiftImm + Width > Ty.getScalarSizeInBits()) 4642 return false; 4643 4644 MatchInfo = [=](MachineIRBuilder &B) { 4645 auto Cst1 = B.buildConstant(ExtractTy, ShiftImm); 4646 auto Cst2 = B.buildConstant(ExtractTy, Width); 4647 B.buildSbfx(Dst, ShiftSrc, Cst1, Cst2); 4648 }; 4649 return true; 4650 } 4651 4652 /// Form a G_UBFX from "(a srl b) & mask", where b and mask are constants. 4653 bool CombinerHelper::matchBitfieldExtractFromAnd(MachineInstr &MI, 4654 BuildFnTy &MatchInfo) const { 4655 GAnd *And = cast<GAnd>(&MI); 4656 Register Dst = And->getReg(0); 4657 LLT Ty = MRI.getType(Dst); 4658 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4659 // Note that isLegalOrBeforeLegalizer is stricter and does not take custom 4660 // into account. 4661 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}})) 4662 return false; 4663 4664 int64_t AndImm, LSBImm; 4665 Register ShiftSrc; 4666 const unsigned Size = Ty.getScalarSizeInBits(); 4667 if (!mi_match(And->getReg(0), MRI, 4668 m_GAnd(m_OneNonDBGUse(m_GLShr(m_Reg(ShiftSrc), m_ICst(LSBImm))), 4669 m_ICst(AndImm)))) 4670 return false; 4671 4672 // The mask is a mask of the low bits iff imm & (imm+1) == 0. 4673 auto MaybeMask = static_cast<uint64_t>(AndImm); 4674 if (MaybeMask & (MaybeMask + 1)) 4675 return false; 4676 4677 // LSB must fit within the register. 4678 if (static_cast<uint64_t>(LSBImm) >= Size) 4679 return false; 4680 4681 uint64_t Width = APInt(Size, AndImm).countr_one(); 4682 MatchInfo = [=](MachineIRBuilder &B) { 4683 auto WidthCst = B.buildConstant(ExtractTy, Width); 4684 auto LSBCst = B.buildConstant(ExtractTy, LSBImm); 4685 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {ShiftSrc, LSBCst, WidthCst}); 4686 }; 4687 return true; 4688 } 4689 4690 bool CombinerHelper::matchBitfieldExtractFromShr( 4691 MachineInstr &MI, 4692 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4693 const unsigned Opcode = MI.getOpcode(); 4694 assert(Opcode == TargetOpcode::G_ASHR || Opcode == TargetOpcode::G_LSHR); 4695 4696 const Register Dst = MI.getOperand(0).getReg(); 4697 4698 const unsigned ExtrOpcode = Opcode == TargetOpcode::G_ASHR 4699 ? 
TargetOpcode::G_SBFX 4700 : TargetOpcode::G_UBFX; 4701 4702 // Check if the type we would use for the extract is legal 4703 LLT Ty = MRI.getType(Dst); 4704 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4705 if (!LI || !LI->isLegalOrCustom({ExtrOpcode, {Ty, ExtractTy}})) 4706 return false; 4707 4708 Register ShlSrc; 4709 int64_t ShrAmt; 4710 int64_t ShlAmt; 4711 const unsigned Size = Ty.getScalarSizeInBits(); 4712 4713 // Try to match shr (shl x, c1), c2 4714 if (!mi_match(Dst, MRI, 4715 m_BinOp(Opcode, 4716 m_OneNonDBGUse(m_GShl(m_Reg(ShlSrc), m_ICst(ShlAmt))), 4717 m_ICst(ShrAmt)))) 4718 return false; 4719 4720 // Make sure that the shift sizes can fit a bitfield extract 4721 if (ShlAmt < 0 || ShlAmt > ShrAmt || ShrAmt >= Size) 4722 return false; 4723 4724 // Skip this combine if the G_SEXT_INREG combine could handle it 4725 if (Opcode == TargetOpcode::G_ASHR && ShlAmt == ShrAmt) 4726 return false; 4727 4728 // Calculate start position and width of the extract 4729 const int64_t Pos = ShrAmt - ShlAmt; 4730 const int64_t Width = Size - ShrAmt; 4731 4732 MatchInfo = [=](MachineIRBuilder &B) { 4733 auto WidthCst = B.buildConstant(ExtractTy, Width); 4734 auto PosCst = B.buildConstant(ExtractTy, Pos); 4735 B.buildInstr(ExtrOpcode, {Dst}, {ShlSrc, PosCst, WidthCst}); 4736 }; 4737 return true; 4738 } 4739 4740 bool CombinerHelper::matchBitfieldExtractFromShrAnd( 4741 MachineInstr &MI, 4742 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 4743 const unsigned Opcode = MI.getOpcode(); 4744 assert(Opcode == TargetOpcode::G_LSHR || Opcode == TargetOpcode::G_ASHR); 4745 4746 const Register Dst = MI.getOperand(0).getReg(); 4747 LLT Ty = MRI.getType(Dst); 4748 LLT ExtractTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 4749 if (LI && !LI->isLegalOrCustom({TargetOpcode::G_UBFX, {Ty, ExtractTy}})) 4750 return false; 4751 4752 // Try to match shr (and x, c1), c2 4753 Register AndSrc; 4754 int64_t ShrAmt; 4755 int64_t SMask; 4756 if (!mi_match(Dst, MRI, 4757 m_BinOp(Opcode, 4758 m_OneNonDBGUse(m_GAnd(m_Reg(AndSrc), m_ICst(SMask))), 4759 m_ICst(ShrAmt)))) 4760 return false; 4761 4762 const unsigned Size = Ty.getScalarSizeInBits(); 4763 if (ShrAmt < 0 || ShrAmt >= Size) 4764 return false; 4765 4766 // If the shift subsumes the mask, emit the 0 directly. 4767 if (0 == (SMask >> ShrAmt)) { 4768 MatchInfo = [=](MachineIRBuilder &B) { 4769 B.buildConstant(Dst, 0); 4770 }; 4771 return true; 4772 } 4773 4774 // Check that ubfx can do the extraction, with no holes in the mask. 4775 uint64_t UMask = SMask; 4776 UMask |= maskTrailingOnes<uint64_t>(ShrAmt); 4777 UMask &= maskTrailingOnes<uint64_t>(Size); 4778 if (!isMask_64(UMask)) 4779 return false; 4780 4781 // Calculate start position and width of the extract. 4782 const int64_t Pos = ShrAmt; 4783 const int64_t Width = llvm::countr_one(UMask) - ShrAmt; 4784 4785 // It's preferable to keep the shift, rather than form G_SBFX. 4786 // TODO: remove the G_AND via demanded bits analysis. 
4787 if (Opcode == TargetOpcode::G_ASHR && Width + ShrAmt == Size) 4788 return false; 4789 4790 MatchInfo = [=](MachineIRBuilder &B) { 4791 auto WidthCst = B.buildConstant(ExtractTy, Width); 4792 auto PosCst = B.buildConstant(ExtractTy, Pos); 4793 B.buildInstr(TargetOpcode::G_UBFX, {Dst}, {AndSrc, PosCst, WidthCst}); 4794 }; 4795 return true; 4796 } 4797 4798 bool CombinerHelper::reassociationCanBreakAddressingModePattern( 4799 MachineInstr &MI) const { 4800 auto &PtrAdd = cast<GPtrAdd>(MI); 4801 4802 Register Src1Reg = PtrAdd.getBaseReg(); 4803 auto *Src1Def = getOpcodeDef<GPtrAdd>(Src1Reg, MRI); 4804 if (!Src1Def) 4805 return false; 4806 4807 Register Src2Reg = PtrAdd.getOffsetReg(); 4808 4809 if (MRI.hasOneNonDBGUse(Src1Reg)) 4810 return false; 4811 4812 auto C1 = getIConstantVRegVal(Src1Def->getOffsetReg(), MRI); 4813 if (!C1) 4814 return false; 4815 auto C2 = getIConstantVRegVal(Src2Reg, MRI); 4816 if (!C2) 4817 return false; 4818 4819 const APInt &C1APIntVal = *C1; 4820 const APInt &C2APIntVal = *C2; 4821 const int64_t CombinedValue = (C1APIntVal + C2APIntVal).getSExtValue(); 4822 4823 for (auto &UseMI : MRI.use_nodbg_instructions(PtrAdd.getReg(0))) { 4824 // This combine may end up running before ptrtoint/inttoptr combines 4825 // manage to eliminate redundant conversions, so try to look through them. 4826 MachineInstr *ConvUseMI = &UseMI; 4827 unsigned ConvUseOpc = ConvUseMI->getOpcode(); 4828 while (ConvUseOpc == TargetOpcode::G_INTTOPTR || 4829 ConvUseOpc == TargetOpcode::G_PTRTOINT) { 4830 Register DefReg = ConvUseMI->getOperand(0).getReg(); 4831 if (!MRI.hasOneNonDBGUse(DefReg)) 4832 break; 4833 ConvUseMI = &*MRI.use_instr_nodbg_begin(DefReg); 4834 ConvUseOpc = ConvUseMI->getOpcode(); 4835 } 4836 auto *LdStMI = dyn_cast<GLoadStore>(ConvUseMI); 4837 if (!LdStMI) 4838 continue; 4839 // Is x[offset2] already not a legal addressing mode? If so then 4840 // reassociating the constants breaks nothing (we test offset2 because 4841 // that's the one we hope to fold into the load or store). 4842 TargetLoweringBase::AddrMode AM; 4843 AM.HasBaseReg = true; 4844 AM.BaseOffs = C2APIntVal.getSExtValue(); 4845 unsigned AS = MRI.getType(LdStMI->getPointerReg()).getAddressSpace(); 4846 Type *AccessTy = getTypeForLLT(LdStMI->getMMO().getMemoryType(), 4847 PtrAdd.getMF()->getFunction().getContext()); 4848 const auto &TLI = *PtrAdd.getMF()->getSubtarget().getTargetLowering(); 4849 if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM, 4850 AccessTy, AS)) 4851 continue; 4852 4853 // Would x[offset1+offset2] still be a legal addressing mode? 
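  // (Illustrative, target-dependent example: if a target can only fold
  //  immediate offsets in [-256, 255] into a load, x[offset2] with
  //  offset2 == 8 folds cleanly, while the combined offset1 + offset2 == 4096
  //  would have to be materialized separately.)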
4854     AM.BaseOffs = CombinedValue;
4855     if (!TLI.isLegalAddressingMode(PtrAdd.getMF()->getDataLayout(), AM,
4856                                    AccessTy, AS))
4857       return true;
4858   }
4859 
4860   return false;
4861 }
4862 
4863 bool CombinerHelper::matchReassocConstantInnerRHS(GPtrAdd &MI,
4864                                                   MachineInstr *RHS,
4865                                                   BuildFnTy &MatchInfo) const {
4866   // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4867   Register Src1Reg = MI.getOperand(1).getReg();
4868   if (RHS->getOpcode() != TargetOpcode::G_ADD)
4869     return false;
4870   auto C2 = getIConstantVRegVal(RHS->getOperand(2).getReg(), MRI);
4871   if (!C2)
4872     return false;
4873 
4874   MatchInfo = [=, &MI](MachineIRBuilder &B) {
4875     LLT PtrTy = MRI.getType(MI.getOperand(0).getReg());
4876 
4877     auto NewBase =
4878         Builder.buildPtrAdd(PtrTy, Src1Reg, RHS->getOperand(1).getReg());
4879     Observer.changingInstr(MI);
4880     MI.getOperand(1).setReg(NewBase.getReg(0));
4881     MI.getOperand(2).setReg(RHS->getOperand(2).getReg());
4882     Observer.changedInstr(MI);
4883   };
4884   return !reassociationCanBreakAddressingModePattern(MI);
4885 }
4886 
4887 bool CombinerHelper::matchReassocConstantInnerLHS(GPtrAdd &MI,
4888                                                   MachineInstr *LHS,
4889                                                   MachineInstr *RHS,
4890                                                   BuildFnTy &MatchInfo) const {
4891   // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4892   // if and only if (G_PTR_ADD X, C) has one use.
4893   Register LHSBase;
4894   std::optional<ValueAndVReg> LHSCstOff;
4895   if (!mi_match(MI.getBaseReg(), MRI,
4896                 m_OneNonDBGUse(m_GPtrAdd(m_Reg(LHSBase), m_GCst(LHSCstOff)))))
4897     return false;
4898 
4899   auto *LHSPtrAdd = cast<GPtrAdd>(LHS);
4900   MatchInfo = [=, &MI](MachineIRBuilder &B) {
4901     // When we change LHSPtrAdd's offset register we might cause it to use a reg
4902     // before its def. Sink the instruction right before the outer PTR_ADD to
4903     // ensure this doesn't happen.
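    // (The m_OneNonDBGUse check above guarantees that MI is the only user of
    //  LHSPtrAdd, so sinking it down to MI cannot skip over another use.)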
4904     LHSPtrAdd->moveBefore(&MI);
4905     Register RHSReg = MI.getOffsetReg();
4906     // Setting the matched vreg directly could cause a type mismatch if it was found through an extend/trunc, so build a fresh constant of the offset type.
4907     auto NewCst = B.buildConstant(MRI.getType(RHSReg), LHSCstOff->Value);
4908     Observer.changingInstr(MI);
4909     MI.getOperand(2).setReg(NewCst.getReg(0));
4910     Observer.changedInstr(MI);
4911     Observer.changingInstr(*LHSPtrAdd);
4912     LHSPtrAdd->getOperand(2).setReg(RHSReg);
4913     Observer.changedInstr(*LHSPtrAdd);
4914   };
4915   return !reassociationCanBreakAddressingModePattern(MI);
4916 }
4917 
4918 bool CombinerHelper::matchReassocFoldConstantsInSubTree(
4919     GPtrAdd &MI, MachineInstr *LHS, MachineInstr *RHS,
4920     BuildFnTy &MatchInfo) const {
4921   // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4922   auto *LHSPtrAdd = dyn_cast<GPtrAdd>(LHS);
4923   if (!LHSPtrAdd)
4924     return false;
4925 
4926   Register Src2Reg = MI.getOperand(2).getReg();
4927   Register LHSSrc1 = LHSPtrAdd->getBaseReg();
4928   Register LHSSrc2 = LHSPtrAdd->getOffsetReg();
4929   auto C1 = getIConstantVRegVal(LHSSrc2, MRI);
4930   if (!C1)
4931     return false;
4932   auto C2 = getIConstantVRegVal(Src2Reg, MRI);
4933   if (!C2)
4934     return false;
4935 
4936   MatchInfo = [=, &MI](MachineIRBuilder &B) {
4937     auto NewCst = B.buildConstant(MRI.getType(Src2Reg), *C1 + *C2);
4938     Observer.changingInstr(MI);
4939     MI.getOperand(1).setReg(LHSSrc1);
4940     MI.getOperand(2).setReg(NewCst.getReg(0));
4941     Observer.changedInstr(MI);
4942   };
4943   return !reassociationCanBreakAddressingModePattern(MI);
4944 }
4945 
4946 bool CombinerHelper::matchReassocPtrAdd(MachineInstr &MI,
4947                                         BuildFnTy &MatchInfo) const {
4948   auto &PtrAdd = cast<GPtrAdd>(MI);
4949   // We're trying to match a few pointer computation patterns here for
4950   // re-association opportunities.
4951   // 1) Isolating a constant operand to be on the RHS, e.g.:
4952   // G_PTR_ADD(BASE, G_ADD(X, C)) -> G_PTR_ADD(G_PTR_ADD(BASE, X), C)
4953   //
4954   // 2) Folding two constants in each sub-tree as long as such folding
4955   // doesn't break a legal addressing mode.
4956   // G_PTR_ADD(G_PTR_ADD(BASE, C1), C2) -> G_PTR_ADD(BASE, C1+C2)
4957   //
4958   // 3) Move a constant from the LHS of an inner op to the RHS of the outer.
4959   // G_PTR_ADD(G_PTR_ADD(X, C), Y) -> G_PTR_ADD(G_PTR_ADD(X, Y), C)
4960   // iff (G_PTR_ADD X, C) has one use.
4961   MachineInstr *LHS = MRI.getVRegDef(PtrAdd.getBaseReg());
4962   MachineInstr *RHS = MRI.getVRegDef(PtrAdd.getOffsetReg());
4963 
4964   // Try to match example 2.
4965   if (matchReassocFoldConstantsInSubTree(PtrAdd, LHS, RHS, MatchInfo))
4966     return true;
4967 
4968   // Try to match example 3.
4969   if (matchReassocConstantInnerLHS(PtrAdd, LHS, RHS, MatchInfo))
4970     return true;
4971 
4972   // Try to match example 1.
4973   if (matchReassocConstantInnerRHS(PtrAdd, RHS, MatchInfo))
4974     return true;
4975 
4976   return false;
4977 }
4978 bool CombinerHelper::tryReassocBinOp(unsigned Opc, Register DstReg,
4979                                      Register OpLHS, Register OpRHS,
4980                                      BuildFnTy &MatchInfo) const {
4981   LLT OpRHSTy = MRI.getType(OpRHS);
4982   MachineInstr *OpLHSDef = MRI.getVRegDef(OpLHS);
4983 
4984   if (OpLHSDef->getOpcode() != Opc)
4985     return false;
4986 
4987   MachineInstr *OpRHSDef = MRI.getVRegDef(OpRHS);
4988   Register OpLHSLHS = OpLHSDef->getOperand(1).getReg();
4989   Register OpLHSRHS = OpLHSDef->getOperand(2).getReg();
4990 
4991   // If the inner op is (X op C), pull the constant out so it can be folded with
4992   // other constants in the expression tree. Folding is not guaranteed so we
4993   // might have (C1 op C2).
In that case do not pull a constant out because it 4994 // won't help and can lead to infinite loops. 4995 if (isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSRHS), MRI) && 4996 !isConstantOrConstantSplatVector(*MRI.getVRegDef(OpLHSLHS), MRI)) { 4997 if (isConstantOrConstantSplatVector(*OpRHSDef, MRI)) { 4998 // (Opc (Opc X, C1), C2) -> (Opc X, (Opc C1, C2)) 4999 MatchInfo = [=](MachineIRBuilder &B) { 5000 auto NewCst = B.buildInstr(Opc, {OpRHSTy}, {OpLHSRHS, OpRHS}); 5001 B.buildInstr(Opc, {DstReg}, {OpLHSLHS, NewCst}); 5002 }; 5003 return true; 5004 } 5005 if (getTargetLowering().isReassocProfitable(MRI, OpLHS, OpRHS)) { 5006 // Reassociate: (op (op x, c1), y) -> (op (op x, y), c1) 5007 // iff (op x, c1) has one use 5008 MatchInfo = [=](MachineIRBuilder &B) { 5009 auto NewLHSLHS = B.buildInstr(Opc, {OpRHSTy}, {OpLHSLHS, OpRHS}); 5010 B.buildInstr(Opc, {DstReg}, {NewLHSLHS, OpLHSRHS}); 5011 }; 5012 return true; 5013 } 5014 } 5015 5016 return false; 5017 } 5018 5019 bool CombinerHelper::matchReassocCommBinOp(MachineInstr &MI, 5020 BuildFnTy &MatchInfo) const { 5021 // We don't check if the reassociation will break a legal addressing mode 5022 // here since pointer arithmetic is handled by G_PTR_ADD. 5023 unsigned Opc = MI.getOpcode(); 5024 Register DstReg = MI.getOperand(0).getReg(); 5025 Register LHSReg = MI.getOperand(1).getReg(); 5026 Register RHSReg = MI.getOperand(2).getReg(); 5027 5028 if (tryReassocBinOp(Opc, DstReg, LHSReg, RHSReg, MatchInfo)) 5029 return true; 5030 if (tryReassocBinOp(Opc, DstReg, RHSReg, LHSReg, MatchInfo)) 5031 return true; 5032 return false; 5033 } 5034 5035 bool CombinerHelper::matchConstantFoldCastOp(MachineInstr &MI, 5036 APInt &MatchInfo) const { 5037 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 5038 Register SrcOp = MI.getOperand(1).getReg(); 5039 5040 if (auto MaybeCst = ConstantFoldCastOp(MI.getOpcode(), DstTy, SrcOp, MRI)) { 5041 MatchInfo = *MaybeCst; 5042 return true; 5043 } 5044 5045 return false; 5046 } 5047 5048 bool CombinerHelper::matchConstantFoldBinOp(MachineInstr &MI, 5049 APInt &MatchInfo) const { 5050 Register Op1 = MI.getOperand(1).getReg(); 5051 Register Op2 = MI.getOperand(2).getReg(); 5052 auto MaybeCst = ConstantFoldBinOp(MI.getOpcode(), Op1, Op2, MRI); 5053 if (!MaybeCst) 5054 return false; 5055 MatchInfo = *MaybeCst; 5056 return true; 5057 } 5058 5059 bool CombinerHelper::matchConstantFoldFPBinOp(MachineInstr &MI, 5060 ConstantFP *&MatchInfo) const { 5061 Register Op1 = MI.getOperand(1).getReg(); 5062 Register Op2 = MI.getOperand(2).getReg(); 5063 auto MaybeCst = ConstantFoldFPBinOp(MI.getOpcode(), Op1, Op2, MRI); 5064 if (!MaybeCst) 5065 return false; 5066 MatchInfo = 5067 ConstantFP::get(MI.getMF()->getFunction().getContext(), *MaybeCst); 5068 return true; 5069 } 5070 5071 bool CombinerHelper::matchConstantFoldFMA(MachineInstr &MI, 5072 ConstantFP *&MatchInfo) const { 5073 assert(MI.getOpcode() == TargetOpcode::G_FMA || 5074 MI.getOpcode() == TargetOpcode::G_FMAD); 5075 auto [_, Op1, Op2, Op3] = MI.getFirst4Regs(); 5076 5077 const ConstantFP *Op3Cst = getConstantFPVRegVal(Op3, MRI); 5078 if (!Op3Cst) 5079 return false; 5080 5081 const ConstantFP *Op2Cst = getConstantFPVRegVal(Op2, MRI); 5082 if (!Op2Cst) 5083 return false; 5084 5085 const ConstantFP *Op1Cst = getConstantFPVRegVal(Op1, MRI); 5086 if (!Op1Cst) 5087 return false; 5088 5089 APFloat Op1F = Op1Cst->getValueAPF(); 5090 Op1F.fusedMultiplyAdd(Op2Cst->getValueAPF(), Op3Cst->getValueAPF(), 5091 APFloat::rmNearestTiesToEven); 5092 MatchInfo = 
ConstantFP::get(MI.getMF()->getFunction().getContext(), Op1F); 5093 return true; 5094 } 5095 5096 bool CombinerHelper::matchNarrowBinopFeedingAnd( 5097 MachineInstr &MI, 5098 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5099 // Look for a binop feeding into an AND with a mask: 5100 // 5101 // %add = G_ADD %lhs, %rhs 5102 // %and = G_AND %add, 000...11111111 5103 // 5104 // Check if it's possible to perform the binop at a narrower width and zext 5105 // back to the original width like so: 5106 // 5107 // %narrow_lhs = G_TRUNC %lhs 5108 // %narrow_rhs = G_TRUNC %rhs 5109 // %narrow_add = G_ADD %narrow_lhs, %narrow_rhs 5110 // %new_add = G_ZEXT %narrow_add 5111 // %and = G_AND %new_add, 000...11111111 5112 // 5113 // This can allow later combines to eliminate the G_AND if it turns out 5114 // that the mask is irrelevant. 5115 assert(MI.getOpcode() == TargetOpcode::G_AND); 5116 Register Dst = MI.getOperand(0).getReg(); 5117 Register AndLHS = MI.getOperand(1).getReg(); 5118 Register AndRHS = MI.getOperand(2).getReg(); 5119 LLT WideTy = MRI.getType(Dst); 5120 5121 // If the potential binop has more than one use, then it's possible that one 5122 // of those uses will need its full width. 5123 if (!WideTy.isScalar() || !MRI.hasOneNonDBGUse(AndLHS)) 5124 return false; 5125 5126 // Check if the LHS feeding the AND is impacted by the high bits that we're 5127 // masking out. 5128 // 5129 // e.g. for 64-bit x, y: 5130 // 5131 // add_64(x, y) & 65535 == zext(add_16(trunc(x), trunc(y))) & 65535 5132 MachineInstr *LHSInst = getDefIgnoringCopies(AndLHS, MRI); 5133 if (!LHSInst) 5134 return false; 5135 unsigned LHSOpc = LHSInst->getOpcode(); 5136 switch (LHSOpc) { 5137 default: 5138 return false; 5139 case TargetOpcode::G_ADD: 5140 case TargetOpcode::G_SUB: 5141 case TargetOpcode::G_MUL: 5142 case TargetOpcode::G_AND: 5143 case TargetOpcode::G_OR: 5144 case TargetOpcode::G_XOR: 5145 break; 5146 } 5147 5148 // Find the mask on the RHS. 5149 auto Cst = getIConstantVRegValWithLookThrough(AndRHS, MRI); 5150 if (!Cst) 5151 return false; 5152 auto Mask = Cst->Value; 5153 if (!Mask.isMask()) 5154 return false; 5155 5156 // No point in combining if there's nothing to truncate. 5157 unsigned NarrowWidth = Mask.countr_one(); 5158 if (NarrowWidth == WideTy.getSizeInBits()) 5159 return false; 5160 LLT NarrowTy = LLT::scalar(NarrowWidth); 5161 5162 // Check if adding the zext + truncates could be harmful. 
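  // (For example, on a target where narrow operations need extra extends or
  //  moves, the narrow form may cost more than it saves; the isTruncateFree
  //  and isZExtFree hooks below are how the target expresses that.)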
5163 auto &MF = *MI.getMF(); 5164 const auto &TLI = getTargetLowering(); 5165 LLVMContext &Ctx = MF.getFunction().getContext(); 5166 if (!TLI.isTruncateFree(WideTy, NarrowTy, Ctx) || 5167 !TLI.isZExtFree(NarrowTy, WideTy, Ctx)) 5168 return false; 5169 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {NarrowTy, WideTy}}) || 5170 !isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {WideTy, NarrowTy}})) 5171 return false; 5172 Register BinOpLHS = LHSInst->getOperand(1).getReg(); 5173 Register BinOpRHS = LHSInst->getOperand(2).getReg(); 5174 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5175 auto NarrowLHS = Builder.buildTrunc(NarrowTy, BinOpLHS); 5176 auto NarrowRHS = Builder.buildTrunc(NarrowTy, BinOpRHS); 5177 auto NarrowBinOp = 5178 Builder.buildInstr(LHSOpc, {NarrowTy}, {NarrowLHS, NarrowRHS}); 5179 auto Ext = Builder.buildZExt(WideTy, NarrowBinOp); 5180 Observer.changingInstr(MI); 5181 MI.getOperand(1).setReg(Ext.getReg(0)); 5182 Observer.changedInstr(MI); 5183 }; 5184 return true; 5185 } 5186 5187 bool CombinerHelper::matchMulOBy2(MachineInstr &MI, 5188 BuildFnTy &MatchInfo) const { 5189 unsigned Opc = MI.getOpcode(); 5190 assert(Opc == TargetOpcode::G_UMULO || Opc == TargetOpcode::G_SMULO); 5191 5192 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(2))) 5193 return false; 5194 5195 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5196 Observer.changingInstr(MI); 5197 unsigned NewOpc = Opc == TargetOpcode::G_UMULO ? TargetOpcode::G_UADDO 5198 : TargetOpcode::G_SADDO; 5199 MI.setDesc(Builder.getTII().get(NewOpc)); 5200 MI.getOperand(3).setReg(MI.getOperand(2).getReg()); 5201 Observer.changedInstr(MI); 5202 }; 5203 return true; 5204 } 5205 5206 bool CombinerHelper::matchMulOBy0(MachineInstr &MI, 5207 BuildFnTy &MatchInfo) const { 5208 // (G_*MULO x, 0) -> 0 + no carry out 5209 assert(MI.getOpcode() == TargetOpcode::G_UMULO || 5210 MI.getOpcode() == TargetOpcode::G_SMULO); 5211 if (!mi_match(MI.getOperand(3).getReg(), MRI, m_SpecificICstOrSplat(0))) 5212 return false; 5213 Register Dst = MI.getOperand(0).getReg(); 5214 Register Carry = MI.getOperand(1).getReg(); 5215 if (!isConstantLegalOrBeforeLegalizer(MRI.getType(Dst)) || 5216 !isConstantLegalOrBeforeLegalizer(MRI.getType(Carry))) 5217 return false; 5218 MatchInfo = [=](MachineIRBuilder &B) { 5219 B.buildConstant(Dst, 0); 5220 B.buildConstant(Carry, 0); 5221 }; 5222 return true; 5223 } 5224 5225 bool CombinerHelper::matchAddEToAddO(MachineInstr &MI, 5226 BuildFnTy &MatchInfo) const { 5227 // (G_*ADDE x, y, 0) -> (G_*ADDO x, y) 5228 // (G_*SUBE x, y, 0) -> (G_*SUBO x, y) 5229 assert(MI.getOpcode() == TargetOpcode::G_UADDE || 5230 MI.getOpcode() == TargetOpcode::G_SADDE || 5231 MI.getOpcode() == TargetOpcode::G_USUBE || 5232 MI.getOpcode() == TargetOpcode::G_SSUBE); 5233 if (!mi_match(MI.getOperand(4).getReg(), MRI, m_SpecificICstOrSplat(0))) 5234 return false; 5235 MatchInfo = [&](MachineIRBuilder &B) { 5236 unsigned NewOpcode; 5237 switch (MI.getOpcode()) { 5238 case TargetOpcode::G_UADDE: 5239 NewOpcode = TargetOpcode::G_UADDO; 5240 break; 5241 case TargetOpcode::G_SADDE: 5242 NewOpcode = TargetOpcode::G_SADDO; 5243 break; 5244 case TargetOpcode::G_USUBE: 5245 NewOpcode = TargetOpcode::G_USUBO; 5246 break; 5247 case TargetOpcode::G_SSUBE: 5248 NewOpcode = TargetOpcode::G_SSUBO; 5249 break; 5250 } 5251 Observer.changingInstr(MI); 5252 MI.setDesc(B.getTII().get(NewOpcode)); 5253 MI.removeOperand(4); 5254 Observer.changedInstr(MI); 5255 }; 5256 return true; 5257 } 5258 5259 bool 
CombinerHelper::matchSubAddSameReg(MachineInstr &MI, 5260 BuildFnTy &MatchInfo) const { 5261 assert(MI.getOpcode() == TargetOpcode::G_SUB); 5262 Register Dst = MI.getOperand(0).getReg(); 5263 // (x + y) - z -> x (if y == z) 5264 // (x + y) - z -> y (if x == z) 5265 Register X, Y, Z; 5266 if (mi_match(Dst, MRI, m_GSub(m_GAdd(m_Reg(X), m_Reg(Y)), m_Reg(Z)))) { 5267 Register ReplaceReg; 5268 int64_t CstX, CstY; 5269 if (Y == Z || (mi_match(Y, MRI, m_ICstOrSplat(CstY)) && 5270 mi_match(Z, MRI, m_SpecificICstOrSplat(CstY)))) 5271 ReplaceReg = X; 5272 else if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5273 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX)))) 5274 ReplaceReg = Y; 5275 if (ReplaceReg) { 5276 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, ReplaceReg); }; 5277 return true; 5278 } 5279 } 5280 5281 // x - (y + z) -> 0 - y (if x == z) 5282 // x - (y + z) -> 0 - z (if x == y) 5283 if (mi_match(Dst, MRI, m_GSub(m_Reg(X), m_GAdd(m_Reg(Y), m_Reg(Z))))) { 5284 Register ReplaceReg; 5285 int64_t CstX; 5286 if (X == Z || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5287 mi_match(Z, MRI, m_SpecificICstOrSplat(CstX)))) 5288 ReplaceReg = Y; 5289 else if (X == Y || (mi_match(X, MRI, m_ICstOrSplat(CstX)) && 5290 mi_match(Y, MRI, m_SpecificICstOrSplat(CstX)))) 5291 ReplaceReg = Z; 5292 if (ReplaceReg) { 5293 MatchInfo = [=](MachineIRBuilder &B) { 5294 auto Zero = B.buildConstant(MRI.getType(Dst), 0); 5295 B.buildSub(Dst, Zero, ReplaceReg); 5296 }; 5297 return true; 5298 } 5299 } 5300 return false; 5301 } 5302 5303 MachineInstr *CombinerHelper::buildUDivorURemUsingMul(MachineInstr &MI) const { 5304 unsigned Opcode = MI.getOpcode(); 5305 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); 5306 auto &UDivorRem = cast<GenericMachineInstr>(MI); 5307 Register Dst = UDivorRem.getReg(0); 5308 Register LHS = UDivorRem.getReg(1); 5309 Register RHS = UDivorRem.getReg(2); 5310 LLT Ty = MRI.getType(Dst); 5311 LLT ScalarTy = Ty.getScalarType(); 5312 const unsigned EltBits = ScalarTy.getScalarSizeInBits(); 5313 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5314 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); 5315 5316 auto &MIB = Builder; 5317 5318 bool UseSRL = false; 5319 SmallVector<Register, 16> Shifts, Factors; 5320 auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI)); 5321 bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value(); 5322 5323 auto BuildExactUDIVPattern = [&](const Constant *C) { 5324 // Don't recompute inverses for each splat element. 5325 if (IsSplat && !Factors.empty()) { 5326 Shifts.push_back(Shifts[0]); 5327 Factors.push_back(Factors[0]); 5328 return true; 5329 } 5330 5331 auto *CI = cast<ConstantInt>(C); 5332 APInt Divisor = CI->getValue(); 5333 unsigned Shift = Divisor.countr_zero(); 5334 if (Shift) { 5335 Divisor.lshrInPlace(Shift); 5336 UseSRL = true; 5337 } 5338 5339 // Calculate the multiplicative inverse modulo BW. 5340 APInt Factor = Divisor.multiplicativeInverse(); 5341 Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0)); 5342 Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0)); 5343 return true; 5344 }; 5345 5346 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5347 // Collect all magic values from the build vector. 
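    // (Illustrative: an exact udiv of an i32 value by 24 = 8 * 3 becomes
    //  "lshr exact x, 3" followed by a multiply by 0xAAAAAAAB, the
    //  multiplicative inverse of 3 modulo 2^32.)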
5348 if (!matchUnaryPredicate(MRI, RHS, BuildExactUDIVPattern)) 5349 llvm_unreachable("Expected unary predicate match to succeed"); 5350 5351 Register Shift, Factor; 5352 if (Ty.isVector()) { 5353 Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0); 5354 Factor = MIB.buildBuildVector(Ty, Factors).getReg(0); 5355 } else { 5356 Shift = Shifts[0]; 5357 Factor = Factors[0]; 5358 } 5359 5360 Register Res = LHS; 5361 5362 if (UseSRL) 5363 Res = MIB.buildLShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0); 5364 5365 return MIB.buildMul(Ty, Res, Factor); 5366 } 5367 5368 unsigned KnownLeadingZeros = 5369 VT ? VT->getKnownBits(LHS).countMinLeadingZeros() : 0; 5370 5371 bool UseNPQ = false; 5372 SmallVector<Register, 16> PreShifts, PostShifts, MagicFactors, NPQFactors; 5373 auto BuildUDIVPattern = [&](const Constant *C) { 5374 auto *CI = cast<ConstantInt>(C); 5375 const APInt &Divisor = CI->getValue(); 5376 5377 bool SelNPQ = false; 5378 APInt Magic(Divisor.getBitWidth(), 0); 5379 unsigned PreShift = 0, PostShift = 0; 5380 5381 // Magic algorithm doesn't work for division by 1. We need to emit a select 5382 // at the end. 5383 // TODO: Use undef values for divisor of 1. 5384 if (!Divisor.isOne()) { 5385 5386 // UnsignedDivisionByConstantInfo doesn't work correctly if leading zeros 5387 // in the dividend exceeds the leading zeros for the divisor. 5388 UnsignedDivisionByConstantInfo magics = 5389 UnsignedDivisionByConstantInfo::get( 5390 Divisor, std::min(KnownLeadingZeros, Divisor.countl_zero())); 5391 5392 Magic = std::move(magics.Magic); 5393 5394 assert(magics.PreShift < Divisor.getBitWidth() && 5395 "We shouldn't generate an undefined shift!"); 5396 assert(magics.PostShift < Divisor.getBitWidth() && 5397 "We shouldn't generate an undefined shift!"); 5398 assert((!magics.IsAdd || magics.PreShift == 0) && "Unexpected pre-shift"); 5399 PreShift = magics.PreShift; 5400 PostShift = magics.PostShift; 5401 SelNPQ = magics.IsAdd; 5402 } 5403 5404 PreShifts.push_back( 5405 MIB.buildConstant(ScalarShiftAmtTy, PreShift).getReg(0)); 5406 MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magic).getReg(0)); 5407 NPQFactors.push_back( 5408 MIB.buildConstant(ScalarTy, 5409 SelNPQ ? APInt::getOneBitSet(EltBits, EltBits - 1) 5410 : APInt::getZero(EltBits)) 5411 .getReg(0)); 5412 PostShifts.push_back( 5413 MIB.buildConstant(ScalarShiftAmtTy, PostShift).getReg(0)); 5414 UseNPQ |= SelNPQ; 5415 return true; 5416 }; 5417 5418 // Collect the shifts/magic values from each element. 5419 bool Matched = matchUnaryPredicate(MRI, RHS, BuildUDIVPattern); 5420 (void)Matched; 5421 assert(Matched && "Expected unary predicate match to succeed"); 5422 5423 Register PreShift, PostShift, MagicFactor, NPQFactor; 5424 auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI); 5425 if (RHSDef) { 5426 PreShift = MIB.buildBuildVector(ShiftAmtTy, PreShifts).getReg(0); 5427 MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0); 5428 NPQFactor = MIB.buildBuildVector(Ty, NPQFactors).getReg(0); 5429 PostShift = MIB.buildBuildVector(ShiftAmtTy, PostShifts).getReg(0); 5430 } else { 5431 assert(MRI.getType(RHS).isScalar() && 5432 "Non-build_vector operation should have been a scalar"); 5433 PreShift = PreShifts[0]; 5434 MagicFactor = MagicFactors[0]; 5435 PostShift = PostShifts[0]; 5436 } 5437 5438 Register Q = LHS; 5439 Q = MIB.buildLShr(Ty, Q, PreShift).getReg(0); 5440 5441 // Multiply the numerator (operand 0) by the magic value. 
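  // (Illustrative, divisor-specific values: a 32-bit udiv by 7 uses
  //  Magic = 0x24924925 with the NPQ path and a post-shift of 2, i.e.
  //  q = umulh(x, Magic); q = ((x - q) >> 1) + q; q = q >> 2.)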
5442 Q = MIB.buildUMulH(Ty, Q, MagicFactor).getReg(0); 5443 5444 if (UseNPQ) { 5445 Register NPQ = MIB.buildSub(Ty, LHS, Q).getReg(0); 5446 5447 // For vectors we might have a mix of non-NPQ/NPQ paths, so use 5448 // G_UMULH to act as a SRL-by-1 for NPQ, else multiply by zero. 5449 if (Ty.isVector()) 5450 NPQ = MIB.buildUMulH(Ty, NPQ, NPQFactor).getReg(0); 5451 else 5452 NPQ = MIB.buildLShr(Ty, NPQ, MIB.buildConstant(ShiftAmtTy, 1)).getReg(0); 5453 5454 Q = MIB.buildAdd(Ty, NPQ, Q).getReg(0); 5455 } 5456 5457 Q = MIB.buildLShr(Ty, Q, PostShift).getReg(0); 5458 auto One = MIB.buildConstant(Ty, 1); 5459 auto IsOne = MIB.buildICmp( 5460 CmpInst::Predicate::ICMP_EQ, 5461 Ty.isScalar() ? LLT::scalar(1) : Ty.changeElementSize(1), RHS, One); 5462 auto ret = MIB.buildSelect(Ty, IsOne, LHS, Q); 5463 5464 if (Opcode == TargetOpcode::G_UREM) { 5465 auto Prod = MIB.buildMul(Ty, ret, RHS); 5466 return MIB.buildSub(Ty, LHS, Prod); 5467 } 5468 return ret; 5469 } 5470 5471 bool CombinerHelper::matchUDivorURemByConst(MachineInstr &MI) const { 5472 unsigned Opcode = MI.getOpcode(); 5473 assert(Opcode == TargetOpcode::G_UDIV || Opcode == TargetOpcode::G_UREM); 5474 Register Dst = MI.getOperand(0).getReg(); 5475 Register RHS = MI.getOperand(2).getReg(); 5476 LLT DstTy = MRI.getType(Dst); 5477 5478 auto &MF = *MI.getMF(); 5479 AttributeList Attr = MF.getFunction().getAttributes(); 5480 const auto &TLI = getTargetLowering(); 5481 LLVMContext &Ctx = MF.getFunction().getContext(); 5482 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr)) 5483 return false; 5484 5485 // Don't do this for minsize because the instruction sequence is usually 5486 // larger. 5487 if (MF.getFunction().hasMinSize()) 5488 return false; 5489 5490 if (Opcode == TargetOpcode::G_UDIV && 5491 MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5492 return matchUnaryPredicate( 5493 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5494 } 5495 5496 auto *RHSDef = MRI.getVRegDef(RHS); 5497 if (!isConstantOrConstantVector(*RHSDef, MRI)) 5498 return false; 5499 5500 // Don't do this if the types are not going to be legal. 5501 if (LI) { 5502 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}})) 5503 return false; 5504 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMULH, {DstTy}})) 5505 return false; 5506 if (!isLegalOrBeforeLegalizer( 5507 {TargetOpcode::G_ICMP, 5508 {DstTy.isVector() ? 
DstTy.changeElementSize(1) : LLT::scalar(1), 5509 DstTy}})) 5510 return false; 5511 if (Opcode == TargetOpcode::G_UREM && 5512 !isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy, DstTy}})) 5513 return false; 5514 } 5515 5516 return matchUnaryPredicate( 5517 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5518 } 5519 5520 void CombinerHelper::applyUDivorURemByConst(MachineInstr &MI) const { 5521 auto *NewMI = buildUDivorURemUsingMul(MI); 5522 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); 5523 } 5524 5525 bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const { 5526 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); 5527 Register Dst = MI.getOperand(0).getReg(); 5528 Register RHS = MI.getOperand(2).getReg(); 5529 LLT DstTy = MRI.getType(Dst); 5530 auto SizeInBits = DstTy.getScalarSizeInBits(); 5531 LLT WideTy = DstTy.changeElementSize(SizeInBits * 2); 5532 5533 auto &MF = *MI.getMF(); 5534 AttributeList Attr = MF.getFunction().getAttributes(); 5535 const auto &TLI = getTargetLowering(); 5536 LLVMContext &Ctx = MF.getFunction().getContext(); 5537 if (TLI.isIntDivCheap(getApproximateEVTForLLT(DstTy, Ctx), Attr)) 5538 return false; 5539 5540 // Don't do this for minsize because the instruction sequence is usually 5541 // larger. 5542 if (MF.getFunction().hasMinSize()) 5543 return false; 5544 5545 // If the sdiv has an 'exact' flag we can use a simpler lowering. 5546 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5547 return matchUnaryPredicate( 5548 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5549 } 5550 5551 auto *RHSDef = MRI.getVRegDef(RHS); 5552 if (!isConstantOrConstantVector(*RHSDef, MRI)) 5553 return false; 5554 5555 // Don't do this if the types are not going to be legal. 5556 if (LI) { 5557 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}})) 5558 return false; 5559 if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) && 5560 !isLegalOrHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}})) 5561 return false; 5562 } 5563 5564 return matchUnaryPredicate( 5565 MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); }); 5566 } 5567 5568 void CombinerHelper::applySDivByConst(MachineInstr &MI) const { 5569 auto *NewMI = buildSDivUsingMul(MI); 5570 replaceSingleDefInstWithReg(MI, NewMI->getOperand(0).getReg()); 5571 } 5572 5573 MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const { 5574 assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV"); 5575 auto &SDiv = cast<GenericMachineInstr>(MI); 5576 Register Dst = SDiv.getReg(0); 5577 Register LHS = SDiv.getReg(1); 5578 Register RHS = SDiv.getReg(2); 5579 LLT Ty = MRI.getType(Dst); 5580 LLT ScalarTy = Ty.getScalarType(); 5581 const unsigned EltBits = ScalarTy.getScalarSizeInBits(); 5582 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5583 LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType(); 5584 auto &MIB = Builder; 5585 5586 bool UseSRA = false; 5587 SmallVector<Register, 16> ExactShifts, ExactFactors; 5588 5589 auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI)); 5590 bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value(); 5591 5592 auto BuildExactSDIVPattern = [&](const Constant *C) { 5593 // Don't recompute inverses for each splat element. 
5594 if (IsSplat && !ExactFactors.empty()) { 5595 ExactShifts.push_back(ExactShifts[0]); 5596 ExactFactors.push_back(ExactFactors[0]); 5597 return true; 5598 } 5599 5600 auto *CI = cast<ConstantInt>(C); 5601 APInt Divisor = CI->getValue(); 5602 unsigned Shift = Divisor.countr_zero(); 5603 if (Shift) { 5604 Divisor.ashrInPlace(Shift); 5605 UseSRA = true; 5606 } 5607 5608 // Calculate the multiplicative inverse modulo BW. 5609 // 2^W requires W + 1 bits, so we have to extend and then truncate. 5610 APInt Factor = Divisor.multiplicativeInverse(); 5611 ExactShifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0)); 5612 ExactFactors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0)); 5613 return true; 5614 }; 5615 5616 if (MI.getFlag(MachineInstr::MIFlag::IsExact)) { 5617 // Collect all magic values from the build vector. 5618 bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactSDIVPattern); 5619 (void)Matched; 5620 assert(Matched && "Expected unary predicate match to succeed"); 5621 5622 Register Shift, Factor; 5623 if (Ty.isVector()) { 5624 Shift = MIB.buildBuildVector(ShiftAmtTy, ExactShifts).getReg(0); 5625 Factor = MIB.buildBuildVector(Ty, ExactFactors).getReg(0); 5626 } else { 5627 Shift = ExactShifts[0]; 5628 Factor = ExactFactors[0]; 5629 } 5630 5631 Register Res = LHS; 5632 5633 if (UseSRA) 5634 Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0); 5635 5636 return MIB.buildMul(Ty, Res, Factor); 5637 } 5638 5639 SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks; 5640 5641 auto BuildSDIVPattern = [&](const Constant *C) { 5642 auto *CI = cast<ConstantInt>(C); 5643 const APInt &Divisor = CI->getValue(); 5644 5645 SignedDivisionByConstantInfo Magics = 5646 SignedDivisionByConstantInfo::get(Divisor); 5647 int NumeratorFactor = 0; 5648 int ShiftMask = -1; 5649 5650 if (Divisor.isOne() || Divisor.isAllOnes()) { 5651 // If d is +1/-1, we just multiply the numerator by +1/-1. 5652 NumeratorFactor = Divisor.getSExtValue(); 5653 Magics.Magic = 0; 5654 Magics.ShiftAmount = 0; 5655 ShiftMask = 0; 5656 } else if (Divisor.isStrictlyPositive() && Magics.Magic.isNegative()) { 5657 // If d > 0 and m < 0, add the numerator. 5658 NumeratorFactor = 1; 5659 } else if (Divisor.isNegative() && Magics.Magic.isStrictlyPositive()) { 5660 // If d < 0 and m > 0, subtract the numerator. 5661 NumeratorFactor = -1; 5662 } 5663 5664 MagicFactors.push_back(MIB.buildConstant(ScalarTy, Magics.Magic).getReg(0)); 5665 Factors.push_back(MIB.buildConstant(ScalarTy, NumeratorFactor).getReg(0)); 5666 Shifts.push_back( 5667 MIB.buildConstant(ScalarShiftAmtTy, Magics.ShiftAmount).getReg(0)); 5668 ShiftMasks.push_back(MIB.buildConstant(ScalarTy, ShiftMask).getReg(0)); 5669 5670 return true; 5671 }; 5672 5673 // Collect the shifts/magic values from each element. 
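  // (Illustrative: a 32-bit sdiv by 7 produces Magic = 0x92492493,
  //  NumeratorFactor = 1 (d > 0, m < 0), ShiftAmount = 2 and ShiftMask = -1,
  //  i.e. q = smulh(x, Magic) + x; q = ashr(q, 2); q += lshr(q, 31).)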
5674   bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
5675   (void)Matched;
5676   assert(Matched && "Expected unary predicate match to succeed");
5677 
5678   Register MagicFactor, Factor, Shift, ShiftMask;
5679   auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5680   if (RHSDef) {
5681     MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
5682     Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5683     Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5684     ShiftMask = MIB.buildBuildVector(Ty, ShiftMasks).getReg(0);
5685   } else {
5686     assert(MRI.getType(RHS).isScalar() &&
5687            "Non-build_vector operation should have been a scalar");
5688     MagicFactor = MagicFactors[0];
5689     Factor = Factors[0];
5690     Shift = Shifts[0];
5691     ShiftMask = ShiftMasks[0];
5692   }
5693 
5694   Register Q = LHS;
5695   Q = MIB.buildSMulH(Ty, LHS, MagicFactor).getReg(0);
5696 
5697   // (Optionally) Add/subtract the numerator using Factor.
5698   Factor = MIB.buildMul(Ty, LHS, Factor).getReg(0);
5699   Q = MIB.buildAdd(Ty, Q, Factor).getReg(0);
5700 
5701   // Shift right algebraic by shift value.
5702   Q = MIB.buildAShr(Ty, Q, Shift).getReg(0);
5703 
5704   // Extract the sign bit, mask it and add it to the quotient.
5705   auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1);
5706   auto T = MIB.buildLShr(Ty, Q, SignShift);
5707   T = MIB.buildAnd(Ty, T, ShiftMask);
5708   return MIB.buildAdd(Ty, Q, T);
5709 }
5710 
5711 bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {
5712   assert((MI.getOpcode() == TargetOpcode::G_SDIV ||
5713           MI.getOpcode() == TargetOpcode::G_UDIV) &&
5714          "Expected SDIV or UDIV");
5715   auto &Div = cast<GenericMachineInstr>(MI);
5716   Register RHS = Div.getReg(2);
5717   auto MatchPow2 = [&](const Constant *C) {
5718     auto *CI = dyn_cast<ConstantInt>(C);
5719     return CI && (CI->getValue().isPowerOf2() ||
5720                   (IsSigned && CI->getValue().isNegatedPowerOf2()));
5721   };
5722   return matchUnaryPredicate(MRI, RHS, MatchPow2, /*AllowUndefs=*/false);
5723 }
5724 
5725 void CombinerHelper::applySDivByPow2(MachineInstr &MI) const {
5726   assert(MI.getOpcode() == TargetOpcode::G_SDIV && "Expected SDIV");
5727   auto &SDiv = cast<GenericMachineInstr>(MI);
5728   Register Dst = SDiv.getReg(0);
5729   Register LHS = SDiv.getReg(1);
5730   Register RHS = SDiv.getReg(2);
5731   LLT Ty = MRI.getType(Dst);
5732   LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5733   LLT CCVT =
5734       Ty.isVector() ? LLT::vector(Ty.getElementCount(), 1) : LLT::scalar(1);
5735 
5736   // Effectively we want to lower G_SDIV %lhs, %rhs, where %rhs is a power of 2,
5737   // to the following version:
5738   //
5739   // %c1 = G_CTTZ %rhs
5740   // %inexact = G_SUB $bitwidth, %c1
5741   // %sign = G_ASHR %lhs, $(bitwidth - 1)
5742   // %lshr = G_LSHR %sign, %inexact
5743   // %add = G_ADD %lhs, %lshr
5744   // %ashr = G_ASHR %add, %c1
5745   // %ashr = G_SELECT %isoneorallones, %lhs, %ashr
5746   // %zero = G_CONSTANT $0
5747   // %neg = G_NEG %ashr
5748   // %isneg = G_ICMP SLT %rhs, %zero
5749   // %res = G_SELECT %isneg, %neg, %ashr
5750 
5751   unsigned BitWidth = Ty.getScalarSizeInBits();
5752   auto Zero = Builder.buildConstant(Ty, 0);
5753 
5754   auto Bits = Builder.buildConstant(ShiftAmtTy, BitWidth);
5755   auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5756   auto Inexact = Builder.buildSub(ShiftAmtTy, Bits, C1);
5757   // Splat the sign bit into the register.
5758   auto Sign = Builder.buildAShr(
5759       Ty, LHS, Builder.buildConstant(ShiftAmtTy, BitWidth - 1));
5760 
5761   // Add (LHS < 0) ? abs(RHS) - 1 : 0; this rounds the division towards zero.
5762   auto LSrl = Builder.buildLShr(Ty, Sign, Inexact);
5763   auto Add = Builder.buildAdd(Ty, LHS, LSrl);
5764   auto AShr = Builder.buildAShr(Ty, Add, C1);
5765 
5766   // Special case: (sdiv X, 1) -> X
5767   // Special case: (sdiv X, -1) -> 0-X
5768   auto One = Builder.buildConstant(Ty, 1);
5769   auto MinusOne = Builder.buildConstant(Ty, -1);
5770   auto IsOne = Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, One);
5771   auto IsMinusOne =
5772       Builder.buildICmp(CmpInst::Predicate::ICMP_EQ, CCVT, RHS, MinusOne);
5773   auto IsOneOrMinusOne = Builder.buildOr(CCVT, IsOne, IsMinusOne);
5774   AShr = Builder.buildSelect(Ty, IsOneOrMinusOne, LHS, AShr);
5775 
5776   // If divided by a positive value, we're done. Otherwise, the result must be
5777   // negated.
5778   auto Neg = Builder.buildNeg(Ty, AShr);
5779   auto IsNeg = Builder.buildICmp(CmpInst::Predicate::ICMP_SLT, CCVT, RHS, Zero);
5780   Builder.buildSelect(MI.getOperand(0).getReg(), IsNeg, Neg, AShr);
5781   MI.eraseFromParent();
5782 }
5783 
5784 void CombinerHelper::applyUDivByPow2(MachineInstr &MI) const {
5785   assert(MI.getOpcode() == TargetOpcode::G_UDIV && "Expected UDIV");
5786   auto &UDiv = cast<GenericMachineInstr>(MI);
5787   Register Dst = UDiv.getReg(0);
5788   Register LHS = UDiv.getReg(1);
5789   Register RHS = UDiv.getReg(2);
5790   LLT Ty = MRI.getType(Dst);
5791   LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5792 
5793   auto C1 = Builder.buildCTTZ(ShiftAmtTy, RHS);
5794   Builder.buildLShr(MI.getOperand(0).getReg(), LHS, C1);
5795   MI.eraseFromParent();
5796 }
5797 
5798 bool CombinerHelper::matchUMulHToLShr(MachineInstr &MI) const {
5799   assert(MI.getOpcode() == TargetOpcode::G_UMULH);
5800   Register RHS = MI.getOperand(2).getReg();
5801   Register Dst = MI.getOperand(0).getReg();
5802   LLT Ty = MRI.getType(Dst);
5803   LLT RHSTy = MRI.getType(RHS);
5804   LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
5805   auto MatchPow2ExceptOne = [&](const Constant *C) {
5806     if (auto *CI = dyn_cast<ConstantInt>(C))
5807       return CI->getValue().isPowerOf2() && !CI->getValue().isOne();
5808     return false;
5809   };
5810   if (!matchUnaryPredicate(MRI, RHS, MatchPow2ExceptOne, false))
5811     return false;
5812   // We need to check both G_LSHR and G_CTLZ because the combine uses G_CTLZ to
5813   // get log base 2, and it is not always legal on a target.
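  // (Illustrative: for i32, umulh x, 8 equals (x * 8) >> 32, which
  //  applyUMulHToLShr below rewrites as lshr x, 32 - log2(8) = 29.)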
5814 return isLegalOrBeforeLegalizer({TargetOpcode::G_LSHR, {Ty, ShiftAmtTy}}) && 5815 isLegalOrBeforeLegalizer({TargetOpcode::G_CTLZ, {RHSTy, RHSTy}}); 5816 } 5817 5818 void CombinerHelper::applyUMulHToLShr(MachineInstr &MI) const { 5819 Register LHS = MI.getOperand(1).getReg(); 5820 Register RHS = MI.getOperand(2).getReg(); 5821 Register Dst = MI.getOperand(0).getReg(); 5822 LLT Ty = MRI.getType(Dst); 5823 LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty); 5824 unsigned NumEltBits = Ty.getScalarSizeInBits(); 5825 5826 auto LogBase2 = buildLogBase2(RHS, Builder); 5827 auto ShiftAmt = 5828 Builder.buildSub(Ty, Builder.buildConstant(Ty, NumEltBits), LogBase2); 5829 auto Trunc = Builder.buildZExtOrTrunc(ShiftAmtTy, ShiftAmt); 5830 Builder.buildLShr(Dst, LHS, Trunc); 5831 MI.eraseFromParent(); 5832 } 5833 5834 bool CombinerHelper::matchRedundantNegOperands(MachineInstr &MI, 5835 BuildFnTy &MatchInfo) const { 5836 unsigned Opc = MI.getOpcode(); 5837 assert(Opc == TargetOpcode::G_FADD || Opc == TargetOpcode::G_FSUB || 5838 Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || 5839 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA); 5840 5841 Register Dst = MI.getOperand(0).getReg(); 5842 Register X = MI.getOperand(1).getReg(); 5843 Register Y = MI.getOperand(2).getReg(); 5844 LLT Type = MRI.getType(Dst); 5845 5846 // fold (fadd x, fneg(y)) -> (fsub x, y) 5847 // fold (fadd fneg(y), x) -> (fsub x, y) 5848 // G_ADD is commutative so both cases are checked by m_GFAdd 5849 if (mi_match(Dst, MRI, m_GFAdd(m_Reg(X), m_GFNeg(m_Reg(Y)))) && 5850 isLegalOrBeforeLegalizer({TargetOpcode::G_FSUB, {Type}})) { 5851 Opc = TargetOpcode::G_FSUB; 5852 } 5853 /// fold (fsub x, fneg(y)) -> (fadd x, y) 5854 else if (mi_match(Dst, MRI, m_GFSub(m_Reg(X), m_GFNeg(m_Reg(Y)))) && 5855 isLegalOrBeforeLegalizer({TargetOpcode::G_FADD, {Type}})) { 5856 Opc = TargetOpcode::G_FADD; 5857 } 5858 // fold (fmul fneg(x), fneg(y)) -> (fmul x, y) 5859 // fold (fdiv fneg(x), fneg(y)) -> (fdiv x, y) 5860 // fold (fmad fneg(x), fneg(y), z) -> (fmad x, y, z) 5861 // fold (fma fneg(x), fneg(y), z) -> (fma x, y, z) 5862 else if ((Opc == TargetOpcode::G_FMUL || Opc == TargetOpcode::G_FDIV || 5863 Opc == TargetOpcode::G_FMAD || Opc == TargetOpcode::G_FMA) && 5864 mi_match(X, MRI, m_GFNeg(m_Reg(X))) && 5865 mi_match(Y, MRI, m_GFNeg(m_Reg(Y)))) { 5866 // no opcode change 5867 } else 5868 return false; 5869 5870 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5871 Observer.changingInstr(MI); 5872 MI.setDesc(B.getTII().get(Opc)); 5873 MI.getOperand(1).setReg(X); 5874 MI.getOperand(2).setReg(Y); 5875 Observer.changedInstr(MI); 5876 }; 5877 return true; 5878 } 5879 5880 bool CombinerHelper::matchFsubToFneg(MachineInstr &MI, 5881 Register &MatchInfo) const { 5882 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 5883 5884 Register LHS = MI.getOperand(1).getReg(); 5885 MatchInfo = MI.getOperand(2).getReg(); 5886 LLT Ty = MRI.getType(MI.getOperand(0).getReg()); 5887 5888 const auto LHSCst = Ty.isVector() 5889 ? getFConstantSplat(LHS, MRI, /* allowUndef */ true) 5890 : getFConstantVRegValWithLookThrough(LHS, MRI); 5891 if (!LHSCst) 5892 return false; 5893 5894 // -0.0 is always allowed 5895 if (LHSCst->Value.isNegZero()) 5896 return true; 5897 5898 // +0.0 is only allowed if nsz is set. 
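  // (Without nsz, fsub +0.0, +0.0 is +0.0 while fneg +0.0 is -0.0, so the
  //  fold would flip the sign of a zero result.)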
5899 if (LHSCst->Value.isPosZero()) 5900 return MI.getFlag(MachineInstr::FmNsz); 5901 5902 return false; 5903 } 5904 5905 void CombinerHelper::applyFsubToFneg(MachineInstr &MI, 5906 Register &MatchInfo) const { 5907 Register Dst = MI.getOperand(0).getReg(); 5908 Builder.buildFNeg( 5909 Dst, Builder.buildFCanonicalize(MRI.getType(Dst), MatchInfo).getReg(0)); 5910 eraseInst(MI); 5911 } 5912 5913 /// Checks if \p MI is TargetOpcode::G_FMUL and contractable either 5914 /// due to global flags or MachineInstr flags. 5915 static bool isContractableFMul(MachineInstr &MI, bool AllowFusionGlobally) { 5916 if (MI.getOpcode() != TargetOpcode::G_FMUL) 5917 return false; 5918 return AllowFusionGlobally || MI.getFlag(MachineInstr::MIFlag::FmContract); 5919 } 5920 5921 static bool hasMoreUses(const MachineInstr &MI0, const MachineInstr &MI1, 5922 const MachineRegisterInfo &MRI) { 5923 return std::distance(MRI.use_instr_nodbg_begin(MI0.getOperand(0).getReg()), 5924 MRI.use_instr_nodbg_end()) > 5925 std::distance(MRI.use_instr_nodbg_begin(MI1.getOperand(0).getReg()), 5926 MRI.use_instr_nodbg_end()); 5927 } 5928 5929 bool CombinerHelper::canCombineFMadOrFMA(MachineInstr &MI, 5930 bool &AllowFusionGlobally, 5931 bool &HasFMAD, bool &Aggressive, 5932 bool CanReassociate) const { 5933 5934 auto *MF = MI.getMF(); 5935 const auto &TLI = *MF->getSubtarget().getTargetLowering(); 5936 const TargetOptions &Options = MF->getTarget().Options; 5937 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 5938 5939 if (CanReassociate && 5940 !(Options.UnsafeFPMath || MI.getFlag(MachineInstr::MIFlag::FmReassoc))) 5941 return false; 5942 5943 // Floating-point multiply-add with intermediate rounding. 5944 HasFMAD = (!isPreLegalize() && TLI.isFMADLegal(MI, DstType)); 5945 // Floating-point multiply-add without intermediate rounding. 5946 bool HasFMA = TLI.isFMAFasterThanFMulAndFAdd(*MF, DstType) && 5947 isLegalOrBeforeLegalizer({TargetOpcode::G_FMA, {DstType}}); 5948 // No valid opcode, do not combine. 5949 if (!HasFMAD && !HasFMA) 5950 return false; 5951 5952 AllowFusionGlobally = Options.AllowFPOpFusion == FPOpFusion::Fast || 5953 Options.UnsafeFPMath || HasFMAD; 5954 // If the addition is not contractable, do not combine. 5955 if (!AllowFusionGlobally && !MI.getFlag(MachineInstr::MIFlag::FmContract)) 5956 return false; 5957 5958 Aggressive = TLI.enableAggressiveFMAFusion(DstType); 5959 return true; 5960 } 5961 5962 bool CombinerHelper::matchCombineFAddFMulToFMadOrFMA( 5963 MachineInstr &MI, 5964 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 5965 assert(MI.getOpcode() == TargetOpcode::G_FADD); 5966 5967 bool AllowFusionGlobally, HasFMAD, Aggressive; 5968 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 5969 return false; 5970 5971 Register Op1 = MI.getOperand(1).getReg(); 5972 Register Op2 = MI.getOperand(2).getReg(); 5973 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 5974 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 5975 unsigned PreferredFusedOpcode = 5976 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 5977 5978 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 5979 // prefer to fold the multiply with fewer uses. 
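  // (Fusing the multiply that has fewer other uses makes it more likely the
  //  original G_FMUL becomes dead and can be deleted afterwards.)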
5980 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5981 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 5982 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 5983 std::swap(LHS, RHS); 5984 } 5985 5986 // fold (fadd (fmul x, y), z) -> (fma x, y, z) 5987 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 5988 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg))) { 5989 MatchInfo = [=, &MI](MachineIRBuilder &B) { 5990 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 5991 {LHS.MI->getOperand(1).getReg(), 5992 LHS.MI->getOperand(2).getReg(), RHS.Reg}); 5993 }; 5994 return true; 5995 } 5996 5997 // fold (fadd x, (fmul y, z)) -> (fma y, z, x) 5998 if (isContractableFMul(*RHS.MI, AllowFusionGlobally) && 5999 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg))) { 6000 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6001 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6002 {RHS.MI->getOperand(1).getReg(), 6003 RHS.MI->getOperand(2).getReg(), LHS.Reg}); 6004 }; 6005 return true; 6006 } 6007 6008 return false; 6009 } 6010 6011 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMA( 6012 MachineInstr &MI, 6013 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6014 assert(MI.getOpcode() == TargetOpcode::G_FADD); 6015 6016 bool AllowFusionGlobally, HasFMAD, Aggressive; 6017 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6018 return false; 6019 6020 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 6021 Register Op1 = MI.getOperand(1).getReg(); 6022 Register Op2 = MI.getOperand(2).getReg(); 6023 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 6024 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 6025 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 6026 6027 unsigned PreferredFusedOpcode = 6028 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6029 6030 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 6031 // prefer to fold the multiply with fewer uses. 6032 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6033 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 6034 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6035 std::swap(LHS, RHS); 6036 } 6037 6038 // fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) 6039 MachineInstr *FpExtSrc; 6040 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) && 6041 isContractableFMul(*FpExtSrc, AllowFusionGlobally) && 6042 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6043 MRI.getType(FpExtSrc->getOperand(1).getReg()))) { 6044 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6045 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg()); 6046 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg()); 6047 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6048 {FpExtX.getReg(0), FpExtY.getReg(0), RHS.Reg}); 6049 }; 6050 return true; 6051 } 6052 6053 // fold (fadd z, (fpext (fmul x, y))) -> (fma (fpext x), (fpext y), z) 6054 // Note: Commutes FADD operands. 
6055 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FpExtSrc))) && 6056 isContractableFMul(*FpExtSrc, AllowFusionGlobally) && 6057 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6058 MRI.getType(FpExtSrc->getOperand(1).getReg()))) { 6059 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6060 auto FpExtX = B.buildFPExt(DstType, FpExtSrc->getOperand(1).getReg()); 6061 auto FpExtY = B.buildFPExt(DstType, FpExtSrc->getOperand(2).getReg()); 6062 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6063 {FpExtX.getReg(0), FpExtY.getReg(0), LHS.Reg}); 6064 }; 6065 return true; 6066 } 6067 6068 return false; 6069 } 6070 6071 bool CombinerHelper::matchCombineFAddFMAFMulToFMadOrFMA( 6072 MachineInstr &MI, 6073 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6074 assert(MI.getOpcode() == TargetOpcode::G_FADD); 6075 6076 bool AllowFusionGlobally, HasFMAD, Aggressive; 6077 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive, true)) 6078 return false; 6079 6080 Register Op1 = MI.getOperand(1).getReg(); 6081 Register Op2 = MI.getOperand(2).getReg(); 6082 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 6083 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 6084 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6085 6086 unsigned PreferredFusedOpcode = 6087 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6088 6089 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 6090 // prefer to fold the multiply with fewer uses. 6091 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6092 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 6093 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6094 std::swap(LHS, RHS); 6095 } 6096 6097 MachineInstr *FMA = nullptr; 6098 Register Z; 6099 // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y, (fma u, v, z)) 6100 if (LHS.MI->getOpcode() == PreferredFusedOpcode && 6101 (MRI.getVRegDef(LHS.MI->getOperand(3).getReg())->getOpcode() == 6102 TargetOpcode::G_FMUL) && 6103 MRI.hasOneNonDBGUse(LHS.MI->getOperand(0).getReg()) && 6104 MRI.hasOneNonDBGUse(LHS.MI->getOperand(3).getReg())) { 6105 FMA = LHS.MI; 6106 Z = RHS.Reg; 6107 } 6108 // fold (fadd z, (fma x, y, (fmul u, v))) -> (fma x, y, (fma u, v, z)) 6109 else if (RHS.MI->getOpcode() == PreferredFusedOpcode && 6110 (MRI.getVRegDef(RHS.MI->getOperand(3).getReg())->getOpcode() == 6111 TargetOpcode::G_FMUL) && 6112 MRI.hasOneNonDBGUse(RHS.MI->getOperand(0).getReg()) && 6113 MRI.hasOneNonDBGUse(RHS.MI->getOperand(3).getReg())) { 6114 Z = LHS.Reg; 6115 FMA = RHS.MI; 6116 } 6117 6118 if (FMA) { 6119 MachineInstr *FMulMI = MRI.getVRegDef(FMA->getOperand(3).getReg()); 6120 Register X = FMA->getOperand(1).getReg(); 6121 Register Y = FMA->getOperand(2).getReg(); 6122 Register U = FMulMI->getOperand(1).getReg(); 6123 Register V = FMulMI->getOperand(2).getReg(); 6124 6125 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6126 Register InnerFMA = MRI.createGenericVirtualRegister(DstTy); 6127 B.buildInstr(PreferredFusedOpcode, {InnerFMA}, {U, V, Z}); 6128 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6129 {X, Y, InnerFMA}); 6130 }; 6131 return true; 6132 } 6133 6134 return false; 6135 } 6136 6137 bool CombinerHelper::matchCombineFAddFpExtFMulToFMadOrFMAAggressive( 6138 MachineInstr &MI, 6139 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6140 assert(MI.getOpcode() == TargetOpcode::G_FADD); 6141 6142 bool AllowFusionGlobally, HasFMAD, Aggressive; 6143 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, 
Aggressive)) 6144 return false; 6145 6146 if (!Aggressive) 6147 return false; 6148 6149 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 6150 LLT DstType = MRI.getType(MI.getOperand(0).getReg()); 6151 Register Op1 = MI.getOperand(1).getReg(); 6152 Register Op2 = MI.getOperand(2).getReg(); 6153 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 6154 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 6155 6156 unsigned PreferredFusedOpcode = 6157 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6158 6159 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 6160 // prefer to fold the multiply with fewer uses. 6161 if (Aggressive && isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6162 isContractableFMul(*RHS.MI, AllowFusionGlobally)) { 6163 if (hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6164 std::swap(LHS, RHS); 6165 } 6166 6167 // Builds: (fma x, y, (fma (fpext u), (fpext v), z)) 6168 auto buildMatchInfo = [=, &MI](Register U, Register V, Register Z, Register X, 6169 Register Y, MachineIRBuilder &B) { 6170 Register FpExtU = B.buildFPExt(DstType, U).getReg(0); 6171 Register FpExtV = B.buildFPExt(DstType, V).getReg(0); 6172 Register InnerFMA = 6173 B.buildInstr(PreferredFusedOpcode, {DstType}, {FpExtU, FpExtV, Z}) 6174 .getReg(0); 6175 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6176 {X, Y, InnerFMA}); 6177 }; 6178 6179 MachineInstr *FMulMI, *FMAMI; 6180 // fold (fadd (fma x, y, (fpext (fmul u, v))), z) 6181 // -> (fma x, y, (fma (fpext u), (fpext v), z)) 6182 if (LHS.MI->getOpcode() == PreferredFusedOpcode && 6183 mi_match(LHS.MI->getOperand(3).getReg(), MRI, 6184 m_GFPExt(m_MInstr(FMulMI))) && 6185 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6186 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6187 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6188 MatchInfo = [=](MachineIRBuilder &B) { 6189 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6190 FMulMI->getOperand(2).getReg(), RHS.Reg, 6191 LHS.MI->getOperand(1).getReg(), 6192 LHS.MI->getOperand(2).getReg(), B); 6193 }; 6194 return true; 6195 } 6196 6197 // fold (fadd (fpext (fma x, y, (fmul u, v))), z) 6198 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) 6199 // FIXME: This turns two single-precision and one double-precision 6200 // operation into two double-precision operations, which might not be 6201 // interesting for all targets, especially GPUs. 
6202 if (mi_match(LHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) && 6203 FMAMI->getOpcode() == PreferredFusedOpcode) { 6204 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg()); 6205 if (isContractableFMul(*FMulMI, AllowFusionGlobally) && 6206 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6207 MRI.getType(FMAMI->getOperand(0).getReg()))) { 6208 MatchInfo = [=](MachineIRBuilder &B) { 6209 Register X = FMAMI->getOperand(1).getReg(); 6210 Register Y = FMAMI->getOperand(2).getReg(); 6211 X = B.buildFPExt(DstType, X).getReg(0); 6212 Y = B.buildFPExt(DstType, Y).getReg(0); 6213 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6214 FMulMI->getOperand(2).getReg(), RHS.Reg, X, Y, B); 6215 }; 6216 6217 return true; 6218 } 6219 } 6220 6221 // fold (fadd z, (fma x, y, (fpext (fmul u, v))) 6222 // -> (fma x, y, (fma (fpext u), (fpext v), z)) 6223 if (RHS.MI->getOpcode() == PreferredFusedOpcode && 6224 mi_match(RHS.MI->getOperand(3).getReg(), MRI, 6225 m_GFPExt(m_MInstr(FMulMI))) && 6226 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6227 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6228 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6229 MatchInfo = [=](MachineIRBuilder &B) { 6230 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6231 FMulMI->getOperand(2).getReg(), LHS.Reg, 6232 RHS.MI->getOperand(1).getReg(), 6233 RHS.MI->getOperand(2).getReg(), B); 6234 }; 6235 return true; 6236 } 6237 6238 // fold (fadd z, (fpext (fma x, y, (fmul u, v))) 6239 // -> (fma (fpext x), (fpext y), (fma (fpext u), (fpext v), z)) 6240 // FIXME: This turns two single-precision and one double-precision 6241 // operation into two double-precision operations, which might not be 6242 // interesting for all targets, especially GPUs. 6243 if (mi_match(RHS.Reg, MRI, m_GFPExt(m_MInstr(FMAMI))) && 6244 FMAMI->getOpcode() == PreferredFusedOpcode) { 6245 MachineInstr *FMulMI = MRI.getVRegDef(FMAMI->getOperand(3).getReg()); 6246 if (isContractableFMul(*FMulMI, AllowFusionGlobally) && 6247 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstType, 6248 MRI.getType(FMAMI->getOperand(0).getReg()))) { 6249 MatchInfo = [=](MachineIRBuilder &B) { 6250 Register X = FMAMI->getOperand(1).getReg(); 6251 Register Y = FMAMI->getOperand(2).getReg(); 6252 X = B.buildFPExt(DstType, X).getReg(0); 6253 Y = B.buildFPExt(DstType, Y).getReg(0); 6254 buildMatchInfo(FMulMI->getOperand(1).getReg(), 6255 FMulMI->getOperand(2).getReg(), LHS.Reg, X, Y, B); 6256 }; 6257 return true; 6258 } 6259 } 6260 6261 return false; 6262 } 6263 6264 bool CombinerHelper::matchCombineFSubFMulToFMadOrFMA( 6265 MachineInstr &MI, 6266 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6267 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6268 6269 bool AllowFusionGlobally, HasFMAD, Aggressive; 6270 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6271 return false; 6272 6273 Register Op1 = MI.getOperand(1).getReg(); 6274 Register Op2 = MI.getOperand(2).getReg(); 6275 DefinitionAndSourceRegister LHS = {MRI.getVRegDef(Op1), Op1}; 6276 DefinitionAndSourceRegister RHS = {MRI.getVRegDef(Op2), Op2}; 6277 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6278 6279 // If we have two choices trying to fold (fadd (fmul u, v), (fmul x, y)), 6280 // prefer to fold the multiply with fewer uses. 
6281 bool FirstMulHasFewerUses = true; 6282 if (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6283 isContractableFMul(*RHS.MI, AllowFusionGlobally) && 6284 hasMoreUses(*LHS.MI, *RHS.MI, MRI)) 6285 FirstMulHasFewerUses = false; 6286 6287 unsigned PreferredFusedOpcode = 6288 HasFMAD ? TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6289 6290 // fold (fsub (fmul x, y), z) -> (fma x, y, -z) 6291 if (FirstMulHasFewerUses && 6292 (isContractableFMul(*LHS.MI, AllowFusionGlobally) && 6293 (Aggressive || MRI.hasOneNonDBGUse(LHS.Reg)))) { 6294 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6295 Register NegZ = B.buildFNeg(DstTy, RHS.Reg).getReg(0); 6296 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6297 {LHS.MI->getOperand(1).getReg(), 6298 LHS.MI->getOperand(2).getReg(), NegZ}); 6299 }; 6300 return true; 6301 } 6302 // fold (fsub x, (fmul y, z)) -> (fma -y, z, x) 6303 else if ((isContractableFMul(*RHS.MI, AllowFusionGlobally) && 6304 (Aggressive || MRI.hasOneNonDBGUse(RHS.Reg)))) { 6305 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6306 Register NegY = 6307 B.buildFNeg(DstTy, RHS.MI->getOperand(1).getReg()).getReg(0); 6308 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6309 {NegY, RHS.MI->getOperand(2).getReg(), LHS.Reg}); 6310 }; 6311 return true; 6312 } 6313 6314 return false; 6315 } 6316 6317 bool CombinerHelper::matchCombineFSubFNegFMulToFMadOrFMA( 6318 MachineInstr &MI, 6319 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6320 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6321 6322 bool AllowFusionGlobally, HasFMAD, Aggressive; 6323 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6324 return false; 6325 6326 Register LHSReg = MI.getOperand(1).getReg(); 6327 Register RHSReg = MI.getOperand(2).getReg(); 6328 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6329 6330 unsigned PreferredFusedOpcode = 6331 HasFMAD ?
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6332 6333 MachineInstr *FMulMI; 6334 // fold (fsub (fneg (fmul x, y)), z) -> (fma (fneg x), y, (fneg z)) 6335 if (mi_match(LHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) && 6336 (Aggressive || (MRI.hasOneNonDBGUse(LHSReg) && 6337 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) && 6338 isContractableFMul(*FMulMI, AllowFusionGlobally)) { 6339 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6340 Register NegX = 6341 B.buildFNeg(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6342 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0); 6343 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6344 {NegX, FMulMI->getOperand(2).getReg(), NegZ}); 6345 }; 6346 return true; 6347 } 6348 6349 // fold (fsub x, (fneg (fmul y, z))) -> (fma y, z, x) 6350 if (mi_match(RHSReg, MRI, m_GFNeg(m_MInstr(FMulMI))) && 6351 (Aggressive || (MRI.hasOneNonDBGUse(RHSReg) && 6352 MRI.hasOneNonDBGUse(FMulMI->getOperand(0).getReg()))) && 6353 isContractableFMul(*FMulMI, AllowFusionGlobally)) { 6354 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6355 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6356 {FMulMI->getOperand(1).getReg(), 6357 FMulMI->getOperand(2).getReg(), LHSReg}); 6358 }; 6359 return true; 6360 } 6361 6362 return false; 6363 } 6364 6365 bool CombinerHelper::matchCombineFSubFpExtFMulToFMadOrFMA( 6366 MachineInstr &MI, 6367 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6368 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6369 6370 bool AllowFusionGlobally, HasFMAD, Aggressive; 6371 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6372 return false; 6373 6374 Register LHSReg = MI.getOperand(1).getReg(); 6375 Register RHSReg = MI.getOperand(2).getReg(); 6376 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6377 6378 unsigned PreferredFusedOpcode = 6379 HasFMAD ?
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6380 6381 MachineInstr *FMulMI; 6382 // fold (fsub (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), (fneg z)) 6383 if (mi_match(LHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) && 6384 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6385 (Aggressive || MRI.hasOneNonDBGUse(LHSReg))) { 6386 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6387 Register FpExtX = 6388 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6389 Register FpExtY = 6390 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0); 6391 Register NegZ = B.buildFNeg(DstTy, RHSReg).getReg(0); 6392 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6393 {FpExtX, FpExtY, NegZ}); 6394 }; 6395 return true; 6396 } 6397 6398 // fold (fsub x, (fpext (fmul y, z))) -> (fma (fneg (fpext y)), (fpext z), x) 6399 if (mi_match(RHSReg, MRI, m_GFPExt(m_MInstr(FMulMI))) && 6400 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6401 (Aggressive || MRI.hasOneNonDBGUse(RHSReg))) { 6402 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6403 Register FpExtY = 6404 B.buildFPExt(DstTy, FMulMI->getOperand(1).getReg()).getReg(0); 6405 Register NegY = B.buildFNeg(DstTy, FpExtY).getReg(0); 6406 Register FpExtZ = 6407 B.buildFPExt(DstTy, FMulMI->getOperand(2).getReg()).getReg(0); 6408 B.buildInstr(PreferredFusedOpcode, {MI.getOperand(0).getReg()}, 6409 {NegY, FpExtZ, LHSReg}); 6410 }; 6411 return true; 6412 } 6413 6414 return false; 6415 } 6416 6417 bool CombinerHelper::matchCombineFSubFpExtFNegFMulToFMadOrFMA( 6418 MachineInstr &MI, 6419 std::function<void(MachineIRBuilder &)> &MatchInfo) const { 6420 assert(MI.getOpcode() == TargetOpcode::G_FSUB); 6421 6422 bool AllowFusionGlobally, HasFMAD, Aggressive; 6423 if (!canCombineFMadOrFMA(MI, AllowFusionGlobally, HasFMAD, Aggressive)) 6424 return false; 6425 6426 const auto &TLI = *MI.getMF()->getSubtarget().getTargetLowering(); 6427 LLT DstTy = MRI.getType(MI.getOperand(0).getReg()); 6428 Register LHSReg = MI.getOperand(1).getReg(); 6429 Register RHSReg = MI.getOperand(2).getReg(); 6430 6431 unsigned PreferredFusedOpcode = 6432 HasFMAD ? 
TargetOpcode::G_FMAD : TargetOpcode::G_FMA; 6433 6434 auto buildMatchInfo = [=](Register Dst, Register X, Register Y, Register Z, 6435 MachineIRBuilder &B) { 6436 Register FpExtX = B.buildFPExt(DstTy, X).getReg(0); 6437 Register FpExtY = B.buildFPExt(DstTy, Y).getReg(0); 6438 B.buildInstr(PreferredFusedOpcode, {Dst}, {FpExtX, FpExtY, Z}); 6439 }; 6440 6441 MachineInstr *FMulMI; 6442 // fold (fsub (fpext (fneg (fmul x, y))), z) -> 6443 // (fneg (fma (fpext x), (fpext y), z)) 6444 // fold (fsub (fneg (fpext (fmul x, y))), z) -> 6445 // (fneg (fma (fpext x), (fpext y), z)) 6446 if ((mi_match(LHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) || 6447 mi_match(LHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) && 6448 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6449 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy, 6450 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6451 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6452 Register FMAReg = MRI.createGenericVirtualRegister(DstTy); 6453 buildMatchInfo(FMAReg, FMulMI->getOperand(1).getReg(), 6454 FMulMI->getOperand(2).getReg(), RHSReg, B); 6455 B.buildFNeg(MI.getOperand(0).getReg(), FMAReg); 6456 }; 6457 return true; 6458 } 6459 6460 // fold (fsub x, (fpext (fneg (fmul y, z)))) -> (fma (fpext y), (fpext z), x) 6461 // fold (fsub x, (fneg (fpext (fmul y, z)))) -> (fma (fpext y), (fpext z), x) 6462 if ((mi_match(RHSReg, MRI, m_GFPExt(m_GFNeg(m_MInstr(FMulMI)))) || 6463 mi_match(RHSReg, MRI, m_GFNeg(m_GFPExt(m_MInstr(FMulMI))))) && 6464 isContractableFMul(*FMulMI, AllowFusionGlobally) && 6465 TLI.isFPExtFoldable(MI, PreferredFusedOpcode, DstTy, 6466 MRI.getType(FMulMI->getOperand(0).getReg()))) { 6467 MatchInfo = [=, &MI](MachineIRBuilder &B) { 6468 buildMatchInfo(MI.getOperand(0).getReg(), FMulMI->getOperand(1).getReg(), 6469 FMulMI->getOperand(2).getReg(), LHSReg, B); 6470 }; 6471 return true; 6472 } 6473 6474 return false; 6475 } 6476 6477 bool CombinerHelper::matchCombineFMinMaxNaN(MachineInstr &MI, 6478 unsigned &IdxToPropagate) const { 6479 bool PropagateNaN; 6480 switch (MI.getOpcode()) { 6481 default: 6482 return false; 6483 case TargetOpcode::G_FMINNUM: 6484 case TargetOpcode::G_FMAXNUM: 6485 PropagateNaN = false; 6486 break; 6487 case TargetOpcode::G_FMINIMUM: 6488 case TargetOpcode::G_FMAXIMUM: 6489 PropagateNaN = true; 6490 break; 6491 } 6492 6493 auto MatchNaN = [&](unsigned Idx) { 6494 Register MaybeNaNReg = MI.getOperand(Idx).getReg(); 6495 const ConstantFP *MaybeCst = getConstantFPVRegVal(MaybeNaNReg, MRI); 6496 if (!MaybeCst || !MaybeCst->getValueAPF().isNaN()) 6497 return false; 6498 IdxToPropagate = PropagateNaN ? Idx : (Idx == 1 ? 
2 : 1); 6499 return true; 6500 }; 6501 6502 return MatchNaN(1) || MatchNaN(2); 6503 } 6504 6505 bool CombinerHelper::matchAddSubSameReg(MachineInstr &MI, Register &Src) const { 6506 assert(MI.getOpcode() == TargetOpcode::G_ADD && "Expected a G_ADD"); 6507 Register LHS = MI.getOperand(1).getReg(); 6508 Register RHS = MI.getOperand(2).getReg(); 6509 6510 // Helper lambda to check for opportunities for 6511 // A + (B - A) -> B 6512 // (B - A) + A -> B 6513 auto CheckFold = [&](Register MaybeSub, Register MaybeSameReg) { 6514 Register Reg; 6515 return mi_match(MaybeSub, MRI, m_GSub(m_Reg(Src), m_Reg(Reg))) && 6516 Reg == MaybeSameReg; 6517 }; 6518 return CheckFold(LHS, RHS) || CheckFold(RHS, LHS); 6519 } 6520 6521 bool CombinerHelper::matchBuildVectorIdentityFold(MachineInstr &MI, 6522 Register &MatchInfo) const { 6523 // This combine folds the following patterns: 6524 // 6525 // G_BUILD_VECTOR_TRUNC (G_BITCAST(x), G_LSHR(G_BITCAST(x), k)) 6526 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), G_TRUNC(G_LSHR(G_BITCAST(x), k))) 6527 // into 6528 // x 6529 // if 6530 // k == sizeof(VecEltTy)/2 6531 // type(x) == type(dst) 6532 // 6533 // G_BUILD_VECTOR(G_TRUNC(G_BITCAST(x)), undef) 6534 // into 6535 // x 6536 // if 6537 // type(x) == type(dst) 6538 6539 LLT DstVecTy = MRI.getType(MI.getOperand(0).getReg()); 6540 LLT DstEltTy = DstVecTy.getElementType(); 6541 6542 Register Lo, Hi; 6543 6544 if (mi_match( 6545 MI, MRI, 6546 m_GBuildVector(m_GTrunc(m_GBitcast(m_Reg(Lo))), m_GImplicitDef()))) { 6547 MatchInfo = Lo; 6548 return MRI.getType(MatchInfo) == DstVecTy; 6549 } 6550 6551 std::optional<ValueAndVReg> ShiftAmount; 6552 const auto LoPattern = m_GBitcast(m_Reg(Lo)); 6553 const auto HiPattern = m_GLShr(m_GBitcast(m_Reg(Hi)), m_GCst(ShiftAmount)); 6554 if (mi_match( 6555 MI, MRI, 6556 m_any_of(m_GBuildVectorTrunc(LoPattern, HiPattern), 6557 m_GBuildVector(m_GTrunc(LoPattern), m_GTrunc(HiPattern))))) { 6558 if (Lo == Hi && ShiftAmount->Value == DstEltTy.getSizeInBits()) { 6559 MatchInfo = Lo; 6560 return MRI.getType(MatchInfo) == DstVecTy; 6561 } 6562 } 6563 6564 return false; 6565 } 6566 6567 bool CombinerHelper::matchTruncBuildVectorFold(MachineInstr &MI, 6568 Register &MatchInfo) const { 6569 // Replace (G_TRUNC (G_BITCAST (G_BUILD_VECTOR x, y)) with just x 6570 // if type(x) == type(G_TRUNC) 6571 if (!mi_match(MI.getOperand(1).getReg(), MRI, 6572 m_GBitcast(m_GBuildVector(m_Reg(MatchInfo), m_Reg())))) 6573 return false; 6574 6575 return MRI.getType(MatchInfo) == MRI.getType(MI.getOperand(0).getReg()); 6576 } 6577 6578 bool CombinerHelper::matchTruncLshrBuildVectorFold(MachineInstr &MI, 6579 Register &MatchInfo) const { 6580 // Replace (G_TRUNC (G_LSHR (G_BITCAST (G_BUILD_VECTOR x, y)), K)) with 6581 // y if K == size of vector element type 6582 std::optional<ValueAndVReg> ShiftAmt; 6583 if (!mi_match(MI.getOperand(1).getReg(), MRI, 6584 m_GLShr(m_GBitcast(m_GBuildVector(m_Reg(), m_Reg(MatchInfo))), 6585 m_GCst(ShiftAmt)))) 6586 return false; 6587 6588 LLT MatchTy = MRI.getType(MatchInfo); 6589 return ShiftAmt->Value.getZExtValue() == MatchTy.getSizeInBits() && 6590 MatchTy == MRI.getType(MI.getOperand(0).getReg()); 6591 } 6592 6593 unsigned CombinerHelper::getFPMinMaxOpcForSelect( 6594 CmpInst::Predicate Pred, LLT DstTy, 6595 SelectPatternNaNBehaviour VsNaNRetVal) const { 6596 assert(VsNaNRetVal != SelectPatternNaNBehaviour::NOT_APPLICABLE && 6597 "Expected a NaN behaviour?"); 6598 // Choose an opcode based off of legality or the behaviour when one of the 6599 // LHS/RHS may be NaN. 
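  // G_FMINNUM/G_FMAXNUM return the non-NaN operand when exactly one input is
  // NaN, while G_FMINIMUM/G_FMAXIMUM propagate the NaN, so the opcode has to
  // match the NaN behaviour the select requires.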
6600 switch (Pred) { 6601 default: 6602 return 0; 6603 case CmpInst::FCMP_UGT: 6604 case CmpInst::FCMP_UGE: 6605 case CmpInst::FCMP_OGT: 6606 case CmpInst::FCMP_OGE: 6607 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) 6608 return TargetOpcode::G_FMAXNUM; 6609 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) 6610 return TargetOpcode::G_FMAXIMUM; 6611 if (isLegal({TargetOpcode::G_FMAXNUM, {DstTy}})) 6612 return TargetOpcode::G_FMAXNUM; 6613 if (isLegal({TargetOpcode::G_FMAXIMUM, {DstTy}})) 6614 return TargetOpcode::G_FMAXIMUM; 6615 return 0; 6616 case CmpInst::FCMP_ULT: 6617 case CmpInst::FCMP_ULE: 6618 case CmpInst::FCMP_OLT: 6619 case CmpInst::FCMP_OLE: 6620 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_OTHER) 6621 return TargetOpcode::G_FMINNUM; 6622 if (VsNaNRetVal == SelectPatternNaNBehaviour::RETURNS_NAN) 6623 return TargetOpcode::G_FMINIMUM; 6624 if (isLegal({TargetOpcode::G_FMINNUM, {DstTy}})) 6625 return TargetOpcode::G_FMINNUM; 6626 if (!isLegal({TargetOpcode::G_FMINIMUM, {DstTy}})) 6627 return 0; 6628 return TargetOpcode::G_FMINIMUM; 6629 } 6630 } 6631 6632 CombinerHelper::SelectPatternNaNBehaviour 6633 CombinerHelper::computeRetValAgainstNaN(Register LHS, Register RHS, 6634 bool IsOrderedComparison) const { 6635 bool LHSSafe = isKnownNeverNaN(LHS, MRI); 6636 bool RHSSafe = isKnownNeverNaN(RHS, MRI); 6637 // Completely unsafe. 6638 if (!LHSSafe && !RHSSafe) 6639 return SelectPatternNaNBehaviour::NOT_APPLICABLE; 6640 if (LHSSafe && RHSSafe) 6641 return SelectPatternNaNBehaviour::RETURNS_ANY; 6642 // An ordered comparison will return false when given a NaN, so it 6643 // returns the RHS. 6644 if (IsOrderedComparison) 6645 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_NAN 6646 : SelectPatternNaNBehaviour::RETURNS_OTHER; 6647 // An unordered comparison will return true when given a NaN, so it 6648 // returns the LHS. 6649 return LHSSafe ? SelectPatternNaNBehaviour::RETURNS_OTHER 6650 : SelectPatternNaNBehaviour::RETURNS_NAN; 6651 } 6652 6653 bool CombinerHelper::matchFPSelectToMinMax(Register Dst, Register Cond, 6654 Register TrueVal, Register FalseVal, 6655 BuildFnTy &MatchInfo) const { 6656 // Match: select (fcmp cond x, y) x, y 6657 // select (fcmp cond x, y) y, x 6658 // And turn it into fminnum/fmaxnum or fmin/fmax based off of the condition. 6659 LLT DstTy = MRI.getType(Dst); 6660 // Bail out early on pointers, since we'll never want to fold to a min/max. 6661 if (DstTy.isPointer()) 6662 return false; 6663 // Match a floating point compare with a less-than/greater-than predicate. 6664 // TODO: Allow multiple users of the compare if they are all selects. 
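  // E.g. select (fcmp olt x, y) x, y can become fminnum(x, y) when the NaN
  // behaviour computed below permits it.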
6665 CmpInst::Predicate Pred; 6666 Register CmpLHS, CmpRHS; 6667 if (!mi_match(Cond, MRI, 6668 m_OneNonDBGUse( 6669 m_GFCmp(m_Pred(Pred), m_Reg(CmpLHS), m_Reg(CmpRHS)))) || 6670 CmpInst::isEquality(Pred)) 6671 return false; 6672 SelectPatternNaNBehaviour ResWithKnownNaNInfo = 6673 computeRetValAgainstNaN(CmpLHS, CmpRHS, CmpInst::isOrdered(Pred)); 6674 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::NOT_APPLICABLE) 6675 return false; 6676 if (TrueVal == CmpRHS && FalseVal == CmpLHS) { 6677 std::swap(CmpLHS, CmpRHS); 6678 Pred = CmpInst::getSwappedPredicate(Pred); 6679 if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_NAN) 6680 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_OTHER; 6681 else if (ResWithKnownNaNInfo == SelectPatternNaNBehaviour::RETURNS_OTHER) 6682 ResWithKnownNaNInfo = SelectPatternNaNBehaviour::RETURNS_NAN; 6683 } 6684 if (TrueVal != CmpLHS || FalseVal != CmpRHS) 6685 return false; 6686 // Decide what type of max/min this should be based off of the predicate. 6687 unsigned Opc = getFPMinMaxOpcForSelect(Pred, DstTy, ResWithKnownNaNInfo); 6688 if (!Opc || !isLegal({Opc, {DstTy}})) 6689 return false; 6690 // Comparisons between signed zero and zero may have different results... 6691 // unless we have fmaximum/fminimum. In that case, we know -0 < 0. 6692 if (Opc != TargetOpcode::G_FMAXIMUM && Opc != TargetOpcode::G_FMINIMUM) { 6693 // We don't know if a comparison between two 0s will give us a consistent 6694 // result. Be conservative and only proceed if at least one side is 6695 // non-zero. 6696 auto KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpLHS, MRI); 6697 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) { 6698 KnownNonZeroSide = getFConstantVRegValWithLookThrough(CmpRHS, MRI); 6699 if (!KnownNonZeroSide || !KnownNonZeroSide->Value.isNonZero()) 6700 return false; 6701 } 6702 } 6703 MatchInfo = [=](MachineIRBuilder &B) { 6704 B.buildInstr(Opc, {Dst}, {CmpLHS, CmpRHS}); 6705 }; 6706 return true; 6707 } 6708 6709 bool CombinerHelper::matchSimplifySelectToMinMax(MachineInstr &MI, 6710 BuildFnTy &MatchInfo) const { 6711 // TODO: Handle integer cases. 6712 assert(MI.getOpcode() == TargetOpcode::G_SELECT); 6713 // Condition may be fed by a truncated compare. 6714 Register Cond = MI.getOperand(1).getReg(); 6715 Register MaybeTrunc; 6716 if (mi_match(Cond, MRI, m_OneNonDBGUse(m_GTrunc(m_Reg(MaybeTrunc))))) 6717 Cond = MaybeTrunc; 6718 Register Dst = MI.getOperand(0).getReg(); 6719 Register TrueVal = MI.getOperand(2).getReg(); 6720 Register FalseVal = MI.getOperand(3).getReg(); 6721 return matchFPSelectToMinMax(Dst, Cond, TrueVal, FalseVal, MatchInfo); 6722 } 6723 6724 bool CombinerHelper::matchRedundantBinOpInEquality(MachineInstr &MI, 6725 BuildFnTy &MatchInfo) const { 6726 assert(MI.getOpcode() == TargetOpcode::G_ICMP); 6727 // (X + Y) == X --> Y == 0 6728 // (X + Y) != X --> Y != 0 6729 // (X - Y) == X --> Y == 0 6730 // (X - Y) != X --> Y != 0 6731 // (X ^ Y) == X --> Y == 0 6732 // (X ^ Y) != X --> Y != 0 6733 Register Dst = MI.getOperand(0).getReg(); 6734 CmpInst::Predicate Pred; 6735 Register X, Y, OpLHS, OpRHS; 6736 bool MatchedSub = mi_match( 6737 Dst, MRI, 6738 m_c_GICmp(m_Pred(Pred), m_Reg(X), m_GSub(m_Reg(OpLHS), m_Reg(Y)))); 6739 if (MatchedSub && X != OpLHS) 6740 return false; 6741 if (!MatchedSub) { 6742 if (!mi_match(Dst, MRI, 6743 m_c_GICmp(m_Pred(Pred), m_Reg(X), 6744 m_any_of(m_GAdd(m_Reg(OpLHS), m_Reg(OpRHS)), 6745 m_GXor(m_Reg(OpLHS), m_Reg(OpRHS)))))) 6746 return false; 6747 Y = X == OpLHS ? 
OpRHS : X == OpRHS ? OpLHS : Register(); 6748 } 6749 MatchInfo = [=](MachineIRBuilder &B) { 6750 auto Zero = B.buildConstant(MRI.getType(Y), 0); 6751 B.buildICmp(Pred, Dst, Y, Zero); 6752 }; 6753 return CmpInst::isEquality(Pred) && Y.isValid(); 6754 } 6755 6756 /// Return the minimum useless shift amount that results in complete loss of the 6757 /// source value. Return std::nullopt when it cannot determine a value. 6758 static std::optional<unsigned> 6759 getMinUselessShift(KnownBits ValueKB, unsigned Opcode, 6760 std::optional<int64_t> &Result) { 6761 assert((Opcode == TargetOpcode::G_SHL || Opcode == TargetOpcode::G_LSHR || 6762 Opcode == TargetOpcode::G_ASHR) && 6763 "Expect G_SHL, G_LSHR or G_ASHR."); 6764 auto SignificantBits = 0; 6765 switch (Opcode) { 6766 case TargetOpcode::G_SHL: 6767 SignificantBits = ValueKB.countMinTrailingZeros(); 6768 Result = 0; 6769 break; 6770 case TargetOpcode::G_LSHR: 6771 Result = 0; 6772 SignificantBits = ValueKB.countMinLeadingZeros(); 6773 break; 6774 case TargetOpcode::G_ASHR: 6775 if (ValueKB.isNonNegative()) { 6776 SignificantBits = ValueKB.countMinLeadingZeros(); 6777 Result = 0; 6778 } else if (ValueKB.isNegative()) { 6779 SignificantBits = ValueKB.countMinLeadingOnes(); 6780 Result = -1; 6781 } else { 6782 // Cannot determine shift result. 6783 Result = std::nullopt; 6784 } 6785 break; 6786 default: 6787 break; 6788 } 6789 return ValueKB.getBitWidth() - SignificantBits; 6790 } 6791 6792 bool CombinerHelper::matchShiftsTooBig( 6793 MachineInstr &MI, std::optional<int64_t> &MatchInfo) const { 6794 Register ShiftVal = MI.getOperand(1).getReg(); 6795 Register ShiftReg = MI.getOperand(2).getReg(); 6796 LLT ResTy = MRI.getType(MI.getOperand(0).getReg()); 6797 auto IsShiftTooBig = [&](const Constant *C) { 6798 auto *CI = dyn_cast<ConstantInt>(C); 6799 if (!CI) 6800 return false; 6801 if (CI->uge(ResTy.getScalarSizeInBits())) { 6802 MatchInfo = std::nullopt; 6803 return true; 6804 } 6805 auto OptMaxUsefulShift = getMinUselessShift(VT->getKnownBits(ShiftVal), 6806 MI.getOpcode(), MatchInfo); 6807 return OptMaxUsefulShift && CI->uge(*OptMaxUsefulShift); 6808 }; 6809 return matchUnaryPredicate(MRI, ShiftReg, IsShiftTooBig); 6810 } 6811 6812 bool CombinerHelper::matchCommuteConstantToRHS(MachineInstr &MI) const { 6813 unsigned LHSOpndIdx = 1; 6814 unsigned RHSOpndIdx = 2; 6815 switch (MI.getOpcode()) { 6816 case TargetOpcode::G_UADDO: 6817 case TargetOpcode::G_SADDO: 6818 case TargetOpcode::G_UMULO: 6819 case TargetOpcode::G_SMULO: 6820 LHSOpndIdx = 2; 6821 RHSOpndIdx = 3; 6822 break; 6823 default: 6824 break; 6825 } 6826 Register LHS = MI.getOperand(LHSOpndIdx).getReg(); 6827 Register RHS = MI.getOperand(RHSOpndIdx).getReg(); 6828 if (!getIConstantVRegVal(LHS, MRI)) { 6829 // Skip commuting if LHS is not a constant. But, LHS may be a 6830 // G_CONSTANT_FOLD_BARRIER. If so we commute as long as we don't already 6831 // have a constant on the RHS. 6832 if (MRI.getVRegDef(LHS)->getOpcode() != 6833 TargetOpcode::G_CONSTANT_FOLD_BARRIER) 6834 return false; 6835 } 6836 // Commute as long as RHS is not a constant or G_CONSTANT_FOLD_BARRIER. 
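  // Keeping constants on the RHS is the canonical form most other combines
  // expect, so they do not need commuted variants of their patterns.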
6837 return MRI.getVRegDef(RHS)->getOpcode() != 6838 TargetOpcode::G_CONSTANT_FOLD_BARRIER && 6839 !getIConstantVRegVal(RHS, MRI); 6840 } 6841 6842 bool CombinerHelper::matchCommuteFPConstantToRHS(MachineInstr &MI) const { 6843 Register LHS = MI.getOperand(1).getReg(); 6844 Register RHS = MI.getOperand(2).getReg(); 6845 std::optional<FPValueAndVReg> ValAndVReg; 6846 if (!mi_match(LHS, MRI, m_GFCstOrSplat(ValAndVReg))) 6847 return false; 6848 return !mi_match(RHS, MRI, m_GFCstOrSplat(ValAndVReg)); 6849 } 6850 6851 void CombinerHelper::applyCommuteBinOpOperands(MachineInstr &MI) const { 6852 Observer.changingInstr(MI); 6853 unsigned LHSOpndIdx = 1; 6854 unsigned RHSOpndIdx = 2; 6855 switch (MI.getOpcode()) { 6856 case TargetOpcode::G_UADDO: 6857 case TargetOpcode::G_SADDO: 6858 case TargetOpcode::G_UMULO: 6859 case TargetOpcode::G_SMULO: 6860 LHSOpndIdx = 2; 6861 RHSOpndIdx = 3; 6862 break; 6863 default: 6864 break; 6865 } 6866 Register LHSReg = MI.getOperand(LHSOpndIdx).getReg(); 6867 Register RHSReg = MI.getOperand(RHSOpndIdx).getReg(); 6868 MI.getOperand(LHSOpndIdx).setReg(RHSReg); 6869 MI.getOperand(RHSOpndIdx).setReg(LHSReg); 6870 Observer.changedInstr(MI); 6871 } 6872 6873 bool CombinerHelper::isOneOrOneSplat(Register Src, bool AllowUndefs) const { 6874 LLT SrcTy = MRI.getType(Src); 6875 if (SrcTy.isFixedVector()) 6876 return isConstantSplatVector(Src, 1, AllowUndefs); 6877 if (SrcTy.isScalar()) { 6878 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) 6879 return true; 6880 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6881 return IConstant && IConstant->Value == 1; 6882 } 6883 return false; // scalable vector 6884 } 6885 6886 bool CombinerHelper::isZeroOrZeroSplat(Register Src, bool AllowUndefs) const { 6887 LLT SrcTy = MRI.getType(Src); 6888 if (SrcTy.isFixedVector()) 6889 return isConstantSplatVector(Src, 0, AllowUndefs); 6890 if (SrcTy.isScalar()) { 6891 if (AllowUndefs && getOpcodeDef<GImplicitDef>(Src, MRI) != nullptr) 6892 return true; 6893 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6894 return IConstant && IConstant->Value == 0; 6895 } 6896 return false; // scalable vector 6897 } 6898 6899 // Ignores COPYs during conformance checks. 6900 // FIXME scalable vectors. 6901 bool CombinerHelper::isConstantSplatVector(Register Src, int64_t SplatValue, 6902 bool AllowUndefs) const { 6903 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6904 if (!BuildVector) 6905 return false; 6906 unsigned NumSources = BuildVector->getNumSources(); 6907 6908 for (unsigned I = 0; I < NumSources; ++I) { 6909 GImplicitDef *ImplicitDef = 6910 getOpcodeDef<GImplicitDef>(BuildVector->getSourceReg(I), MRI); 6911 if (ImplicitDef && AllowUndefs) 6912 continue; 6913 if (ImplicitDef && !AllowUndefs) 6914 return false; 6915 std::optional<ValueAndVReg> IConstant = 6916 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6917 if (IConstant && IConstant->Value == SplatValue) 6918 continue; 6919 return false; 6920 } 6921 return true; 6922 } 6923 6924 // Ignores COPYs during lookups. 
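// Returns the constant if Src is a G_CONSTANT, or the common value if Src is
// a G_BUILD_VECTOR whose sources all lower to the same constant.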
6925 // FIXME scalable vectors 6926 std::optional<APInt> 6927 CombinerHelper::getConstantOrConstantSplatVector(Register Src) const { 6928 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6929 if (IConstant) 6930 return IConstant->Value; 6931 6932 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6933 if (!BuildVector) 6934 return std::nullopt; 6935 unsigned NumSources = BuildVector->getNumSources(); 6936 6937 std::optional<APInt> Value = std::nullopt; 6938 for (unsigned I = 0; I < NumSources; ++I) { 6939 std::optional<ValueAndVReg> IConstant = 6940 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6941 if (!IConstant) 6942 return std::nullopt; 6943 if (!Value) 6944 Value = IConstant->Value; 6945 else if (*Value != IConstant->Value) 6946 return std::nullopt; 6947 } 6948 return Value; 6949 } 6950 6951 // FIXME G_SPLAT_VECTOR 6952 bool CombinerHelper::isConstantOrConstantVectorI(Register Src) const { 6953 auto IConstant = getIConstantVRegValWithLookThrough(Src, MRI); 6954 if (IConstant) 6955 return true; 6956 6957 GBuildVector *BuildVector = getOpcodeDef<GBuildVector>(Src, MRI); 6958 if (!BuildVector) 6959 return false; 6960 6961 unsigned NumSources = BuildVector->getNumSources(); 6962 for (unsigned I = 0; I < NumSources; ++I) { 6963 std::optional<ValueAndVReg> IConstant = 6964 getIConstantVRegValWithLookThrough(BuildVector->getSourceReg(I), MRI); 6965 if (!IConstant) 6966 return false; 6967 } 6968 return true; 6969 } 6970 6971 // TODO: use knownbits to determine zeros 6972 bool CombinerHelper::tryFoldSelectOfConstants(GSelect *Select, 6973 BuildFnTy &MatchInfo) const { 6974 uint32_t Flags = Select->getFlags(); 6975 Register Dest = Select->getReg(0); 6976 Register Cond = Select->getCondReg(); 6977 Register True = Select->getTrueReg(); 6978 Register False = Select->getFalseReg(); 6979 LLT CondTy = MRI.getType(Select->getCondReg()); 6980 LLT TrueTy = MRI.getType(Select->getTrueReg()); 6981 6982 // We only do this combine for scalar boolean conditions. 6983 if (CondTy != LLT::scalar(1)) 6984 return false; 6985 6986 if (TrueTy.isPointer()) 6987 return false; 6988 6989 // Both are scalars. 
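  // With both select operands known constants, the patterns below rewrite the
  // select into an extension of the (possibly inverted) condition, optionally
  // combined with an add, shift, or or.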
6990 std::optional<ValueAndVReg> TrueOpt = 6991 getIConstantVRegValWithLookThrough(True, MRI); 6992 std::optional<ValueAndVReg> FalseOpt = 6993 getIConstantVRegValWithLookThrough(False, MRI); 6994 6995 if (!TrueOpt || !FalseOpt) 6996 return false; 6997 6998 APInt TrueValue = TrueOpt->Value; 6999 APInt FalseValue = FalseOpt->Value; 7000 7001 // select Cond, 1, 0 --> zext (Cond) 7002 if (TrueValue.isOne() && FalseValue.isZero()) { 7003 MatchInfo = [=](MachineIRBuilder &B) { 7004 B.setInstrAndDebugLoc(*Select); 7005 B.buildZExtOrTrunc(Dest, Cond); 7006 }; 7007 return true; 7008 } 7009 7010 // select Cond, -1, 0 --> sext (Cond) 7011 if (TrueValue.isAllOnes() && FalseValue.isZero()) { 7012 MatchInfo = [=](MachineIRBuilder &B) { 7013 B.setInstrAndDebugLoc(*Select); 7014 B.buildSExtOrTrunc(Dest, Cond); 7015 }; 7016 return true; 7017 } 7018 7019 // select Cond, 0, 1 --> zext (!Cond) 7020 if (TrueValue.isZero() && FalseValue.isOne()) { 7021 MatchInfo = [=](MachineIRBuilder &B) { 7022 B.setInstrAndDebugLoc(*Select); 7023 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7024 B.buildNot(Inner, Cond); 7025 B.buildZExtOrTrunc(Dest, Inner); 7026 }; 7027 return true; 7028 } 7029 7030 // select Cond, 0, -1 --> sext (!Cond) 7031 if (TrueValue.isZero() && FalseValue.isAllOnes()) { 7032 MatchInfo = [=](MachineIRBuilder &B) { 7033 B.setInstrAndDebugLoc(*Select); 7034 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7035 B.buildNot(Inner, Cond); 7036 B.buildSExtOrTrunc(Dest, Inner); 7037 }; 7038 return true; 7039 } 7040 7041 // select Cond, C1, C1-1 --> add (zext Cond), C1-1 7042 if (TrueValue - 1 == FalseValue) { 7043 MatchInfo = [=](MachineIRBuilder &B) { 7044 B.setInstrAndDebugLoc(*Select); 7045 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7046 B.buildZExtOrTrunc(Inner, Cond); 7047 B.buildAdd(Dest, Inner, False); 7048 }; 7049 return true; 7050 } 7051 7052 // select Cond, C1, C1+1 --> add (sext Cond), C1+1 7053 if (TrueValue + 1 == FalseValue) { 7054 MatchInfo = [=](MachineIRBuilder &B) { 7055 B.setInstrAndDebugLoc(*Select); 7056 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7057 B.buildSExtOrTrunc(Inner, Cond); 7058 B.buildAdd(Dest, Inner, False); 7059 }; 7060 return true; 7061 } 7062 7063 // select Cond, Pow2, 0 --> (zext Cond) << log2(Pow2) 7064 if (TrueValue.isPowerOf2() && FalseValue.isZero()) { 7065 MatchInfo = [=](MachineIRBuilder &B) { 7066 B.setInstrAndDebugLoc(*Select); 7067 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7068 B.buildZExtOrTrunc(Inner, Cond); 7069 // The shift amount must be scalar. 7070 LLT ShiftTy = TrueTy.isVector() ? TrueTy.getElementType() : TrueTy; 7071 auto ShAmtC = B.buildConstant(ShiftTy, TrueValue.exactLogBase2()); 7072 B.buildShl(Dest, Inner, ShAmtC, Flags); 7073 }; 7074 return true; 7075 } 7076 7077 // select Cond, 0, Pow2 --> (zext (!Cond)) << log2(Pow2) 7078 if (FalseValue.isPowerOf2() && TrueValue.isZero()) { 7079 MatchInfo = [=](MachineIRBuilder &B) { 7080 B.setInstrAndDebugLoc(*Select); 7081 Register Not = MRI.createGenericVirtualRegister(CondTy); 7082 B.buildNot(Not, Cond); 7083 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7084 B.buildZExtOrTrunc(Inner, Not); 7085 // The shift amount must be scalar. 7086 LLT ShiftTy = TrueTy.isVector() ? 
TrueTy.getElementType() : TrueTy; 7087 auto ShAmtC = B.buildConstant(ShiftTy, FalseValue.exactLogBase2()); 7088 B.buildShl(Dest, Inner, ShAmtC, Flags); 7089 }; 7090 return true; 7091 } 7092 7093 // select Cond, -1, C --> or (sext Cond), C 7094 if (TrueValue.isAllOnes()) { 7095 MatchInfo = [=](MachineIRBuilder &B) { 7096 B.setInstrAndDebugLoc(*Select); 7097 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7098 B.buildSExtOrTrunc(Inner, Cond); 7099 B.buildOr(Dest, Inner, False, Flags); 7100 }; 7101 return true; 7102 } 7103 7104 // select Cond, C, -1 --> or (sext (not Cond)), C 7105 if (FalseValue.isAllOnes()) { 7106 MatchInfo = [=](MachineIRBuilder &B) { 7107 B.setInstrAndDebugLoc(*Select); 7108 Register Not = MRI.createGenericVirtualRegister(CondTy); 7109 B.buildNot(Not, Cond); 7110 Register Inner = MRI.createGenericVirtualRegister(TrueTy); 7111 B.buildSExtOrTrunc(Inner, Not); 7112 B.buildOr(Dest, Inner, True, Flags); 7113 }; 7114 return true; 7115 } 7116 7117 return false; 7118 } 7119 7120 // TODO: use knownbits to determine zeros 7121 bool CombinerHelper::tryFoldBoolSelectToLogic(GSelect *Select, 7122 BuildFnTy &MatchInfo) const { 7123 uint32_t Flags = Select->getFlags(); 7124 Register DstReg = Select->getReg(0); 7125 Register Cond = Select->getCondReg(); 7126 Register True = Select->getTrueReg(); 7127 Register False = Select->getFalseReg(); 7128 LLT CondTy = MRI.getType(Select->getCondReg()); 7129 LLT TrueTy = MRI.getType(Select->getTrueReg()); 7130 7131 // Boolean or fixed vector of booleans. 7132 if (CondTy.isScalableVector() || 7133 (CondTy.isFixedVector() && 7134 CondTy.getElementType().getScalarSizeInBits() != 1) || 7135 CondTy.getScalarSizeInBits() != 1) 7136 return false; 7137 7138 if (CondTy != TrueTy) 7139 return false; 7140 7141 // select Cond, Cond, F --> or Cond, F 7142 // select Cond, 1, F --> or Cond, F 7143 if ((Cond == True) || isOneOrOneSplat(True, /* AllowUndefs */ true)) { 7144 MatchInfo = [=](MachineIRBuilder &B) { 7145 B.setInstrAndDebugLoc(*Select); 7146 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7147 B.buildZExtOrTrunc(Ext, Cond); 7148 auto FreezeFalse = B.buildFreeze(TrueTy, False); 7149 B.buildOr(DstReg, Ext, FreezeFalse, Flags); 7150 }; 7151 return true; 7152 } 7153 7154 // select Cond, T, Cond --> and Cond, T 7155 // select Cond, T, 0 --> and Cond, T 7156 if ((Cond == False) || isZeroOrZeroSplat(False, /* AllowUndefs */ true)) { 7157 MatchInfo = [=](MachineIRBuilder &B) { 7158 B.setInstrAndDebugLoc(*Select); 7159 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7160 B.buildZExtOrTrunc(Ext, Cond); 7161 auto FreezeTrue = B.buildFreeze(TrueTy, True); 7162 B.buildAnd(DstReg, Ext, FreezeTrue); 7163 }; 7164 return true; 7165 } 7166 7167 // select Cond, T, 1 --> or (not Cond), T 7168 if (isOneOrOneSplat(False, /* AllowUndefs */ true)) { 7169 MatchInfo = [=](MachineIRBuilder &B) { 7170 B.setInstrAndDebugLoc(*Select); 7171 // First the not. 7172 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7173 B.buildNot(Inner, Cond); 7174 // Then an ext to match the destination register. 7175 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7176 B.buildZExtOrTrunc(Ext, Inner); 7177 auto FreezeTrue = B.buildFreeze(TrueTy, True); 7178 B.buildOr(DstReg, Ext, FreezeTrue, Flags); 7179 }; 7180 return true; 7181 } 7182 7183 // select Cond, 0, F --> and (not Cond), F 7184 if (isZeroOrZeroSplat(True, /* AllowUndefs */ true)) { 7185 MatchInfo = [=](MachineIRBuilder &B) { 7186 B.setInstrAndDebugLoc(*Select); 7187 // First the not. 
7188 Register Inner = MRI.createGenericVirtualRegister(CondTy); 7189 B.buildNot(Inner, Cond); 7190 // Then an ext to match the destination register. 7191 Register Ext = MRI.createGenericVirtualRegister(TrueTy); 7192 B.buildZExtOrTrunc(Ext, Inner); 7193 auto FreezeFalse = B.buildFreeze(TrueTy, False); 7194 B.buildAnd(DstReg, Ext, FreezeFalse); 7195 }; 7196 return true; 7197 } 7198 7199 return false; 7200 } 7201 7202 bool CombinerHelper::matchSelectIMinMax(const MachineOperand &MO, 7203 BuildFnTy &MatchInfo) const { 7204 GSelect *Select = cast<GSelect>(MRI.getVRegDef(MO.getReg())); 7205 GICmp *Cmp = cast<GICmp>(MRI.getVRegDef(Select->getCondReg())); 7206 7207 Register DstReg = Select->getReg(0); 7208 Register True = Select->getTrueReg(); 7209 Register False = Select->getFalseReg(); 7210 LLT DstTy = MRI.getType(DstReg); 7211 7212 if (DstTy.isPointer()) 7213 return false; 7214 7215 // We want to fold the icmp and replace the select. 7216 if (!MRI.hasOneNonDBGUse(Cmp->getReg(0))) 7217 return false; 7218 7219 CmpInst::Predicate Pred = Cmp->getCond(); 7220 // We need a larger or smaller predicate for 7221 // canonicalization. 7222 if (CmpInst::isEquality(Pred)) 7223 return false; 7224 7225 Register CmpLHS = Cmp->getLHSReg(); 7226 Register CmpRHS = Cmp->getRHSReg(); 7227 7228 // We can swap CmpLHS and CmpRHS for higher hitrate. 7229 if (True == CmpRHS && False == CmpLHS) { 7230 std::swap(CmpLHS, CmpRHS); 7231 Pred = CmpInst::getSwappedPredicate(Pred); 7232 } 7233 7234 // (icmp X, Y) ? X : Y -> integer minmax. 7235 // see matchSelectPattern in ValueTracking. 7236 // Legality between G_SELECT and integer minmax can differ. 7237 if (True != CmpLHS || False != CmpRHS) 7238 return false; 7239 7240 switch (Pred) { 7241 case ICmpInst::ICMP_UGT: 7242 case ICmpInst::ICMP_UGE: { 7243 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMAX, DstTy})) 7244 return false; 7245 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMax(DstReg, True, False); }; 7246 return true; 7247 } 7248 case ICmpInst::ICMP_SGT: 7249 case ICmpInst::ICMP_SGE: { 7250 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMAX, DstTy})) 7251 return false; 7252 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMax(DstReg, True, False); }; 7253 return true; 7254 } 7255 case ICmpInst::ICMP_ULT: 7256 case ICmpInst::ICMP_ULE: { 7257 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_UMIN, DstTy})) 7258 return false; 7259 MatchInfo = [=](MachineIRBuilder &B) { B.buildUMin(DstReg, True, False); }; 7260 return true; 7261 } 7262 case ICmpInst::ICMP_SLT: 7263 case ICmpInst::ICMP_SLE: { 7264 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SMIN, DstTy})) 7265 return false; 7266 MatchInfo = [=](MachineIRBuilder &B) { B.buildSMin(DstReg, True, False); }; 7267 return true; 7268 } 7269 default: 7270 return false; 7271 } 7272 } 7273 7274 // (neg (min/max x, (neg x))) --> (max/min x, (neg x)) 7275 bool CombinerHelper::matchSimplifyNegMinMax(MachineInstr &MI, 7276 BuildFnTy &MatchInfo) const { 7277 assert(MI.getOpcode() == TargetOpcode::G_SUB); 7278 Register DestReg = MI.getOperand(0).getReg(); 7279 LLT DestTy = MRI.getType(DestReg); 7280 7281 Register X; 7282 Register Sub0; 7283 auto NegPattern = m_all_of(m_Neg(m_DeferredReg(X)), m_Reg(Sub0)); 7284 if (mi_match(DestReg, MRI, 7285 m_Neg(m_OneUse(m_any_of(m_GSMin(m_Reg(X), NegPattern), 7286 m_GSMax(m_Reg(X), NegPattern), 7287 m_GUMin(m_Reg(X), NegPattern), 7288 m_GUMax(m_Reg(X), NegPattern)))))) { 7289 MachineInstr *MinMaxMI = MRI.getVRegDef(MI.getOperand(2).getReg()); 7290 unsigned NewOpc = 
getInverseGMinMaxOpcode(MinMaxMI->getOpcode()); 7291 if (isLegal({NewOpc, {DestTy}})) { 7292 MatchInfo = [=](MachineIRBuilder &B) { 7293 B.buildInstr(NewOpc, {DestReg}, {X, Sub0}); 7294 }; 7295 return true; 7296 } 7297 } 7298 7299 return false; 7300 } 7301 7302 bool CombinerHelper::matchSelect(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7303 GSelect *Select = cast<GSelect>(&MI); 7304 7305 if (tryFoldSelectOfConstants(Select, MatchInfo)) 7306 return true; 7307 7308 if (tryFoldBoolSelectToLogic(Select, MatchInfo)) 7309 return true; 7310 7311 return false; 7312 } 7313 7314 /// Fold (icmp Pred1 V1, C1) && (icmp Pred2 V2, C2) 7315 /// or (icmp Pred1 V1, C1) || (icmp Pred2 V2, C2) 7316 /// into a single comparison using range-based reasoning. 7317 /// see InstCombinerImpl::foldAndOrOfICmpsUsingRanges. 7318 bool CombinerHelper::tryFoldAndOrOrICmpsUsingRanges( 7319 GLogicalBinOp *Logic, BuildFnTy &MatchInfo) const { 7320 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpected xor"); 7321 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; 7322 Register DstReg = Logic->getReg(0); 7323 Register LHS = Logic->getLHSReg(); 7324 Register RHS = Logic->getRHSReg(); 7325 unsigned Flags = Logic->getFlags(); 7326 7327 // We need an G_ICMP on the LHS register. 7328 GICmp *Cmp1 = getOpcodeDef<GICmp>(LHS, MRI); 7329 if (!Cmp1) 7330 return false; 7331 7332 // We need an G_ICMP on the RHS register. 7333 GICmp *Cmp2 = getOpcodeDef<GICmp>(RHS, MRI); 7334 if (!Cmp2) 7335 return false; 7336 7337 // We want to fold the icmps. 7338 if (!MRI.hasOneNonDBGUse(Cmp1->getReg(0)) || 7339 !MRI.hasOneNonDBGUse(Cmp2->getReg(0))) 7340 return false; 7341 7342 APInt C1; 7343 APInt C2; 7344 std::optional<ValueAndVReg> MaybeC1 = 7345 getIConstantVRegValWithLookThrough(Cmp1->getRHSReg(), MRI); 7346 if (!MaybeC1) 7347 return false; 7348 C1 = MaybeC1->Value; 7349 7350 std::optional<ValueAndVReg> MaybeC2 = 7351 getIConstantVRegValWithLookThrough(Cmp2->getRHSReg(), MRI); 7352 if (!MaybeC2) 7353 return false; 7354 C2 = MaybeC2->Value; 7355 7356 Register R1 = Cmp1->getLHSReg(); 7357 Register R2 = Cmp2->getLHSReg(); 7358 CmpInst::Predicate Pred1 = Cmp1->getCond(); 7359 CmpInst::Predicate Pred2 = Cmp2->getCond(); 7360 LLT CmpTy = MRI.getType(Cmp1->getReg(0)); 7361 LLT CmpOperandTy = MRI.getType(R1); 7362 7363 if (CmpOperandTy.isPointer()) 7364 return false; 7365 7366 // We build ands, adds, and constants of type CmpOperandTy. 7367 // They must be legal to build. 7368 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_AND, CmpOperandTy}) || 7369 !isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, CmpOperandTy}) || 7370 !isConstantLegalOrBeforeLegalizer(CmpOperandTy)) 7371 return false; 7372 7373 // Look through add of a constant offset on R1, R2, or both operands. This 7374 // allows us to interpret the R + C' < C'' range idiom into a proper range. 
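  // E.g. the idiom (x + 1) u< 3 describes x + 1 in [0, 3); subtracting the
  // offset below rebases that onto x itself, here [-1, 2).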
7375 std::optional<APInt> Offset1; 7376 std::optional<APInt> Offset2; 7377 if (R1 != R2) { 7378 if (GAdd *Add = getOpcodeDef<GAdd>(R1, MRI)) { 7379 std::optional<ValueAndVReg> MaybeOffset1 = 7380 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI); 7381 if (MaybeOffset1) { 7382 R1 = Add->getLHSReg(); 7383 Offset1 = MaybeOffset1->Value; 7384 } 7385 } 7386 if (GAdd *Add = getOpcodeDef<GAdd>(R2, MRI)) { 7387 std::optional<ValueAndVReg> MaybeOffset2 = 7388 getIConstantVRegValWithLookThrough(Add->getRHSReg(), MRI); 7389 if (MaybeOffset2) { 7390 R2 = Add->getLHSReg(); 7391 Offset2 = MaybeOffset2->Value; 7392 } 7393 } 7394 } 7395 7396 if (R1 != R2) 7397 return false; 7398 7399 // We calculate the icmp ranges including maybe offsets. 7400 ConstantRange CR1 = ConstantRange::makeExactICmpRegion( 7401 IsAnd ? ICmpInst::getInversePredicate(Pred1) : Pred1, C1); 7402 if (Offset1) 7403 CR1 = CR1.subtract(*Offset1); 7404 7405 ConstantRange CR2 = ConstantRange::makeExactICmpRegion( 7406 IsAnd ? ICmpInst::getInversePredicate(Pred2) : Pred2, C2); 7407 if (Offset2) 7408 CR2 = CR2.subtract(*Offset2); 7409 7410 bool CreateMask = false; 7411 APInt LowerDiff; 7412 std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2); 7413 if (!CR) { 7414 // We need non-wrapping ranges. 7415 if (CR1.isWrappedSet() || CR2.isWrappedSet()) 7416 return false; 7417 7418 // Check whether we have equal-size ranges that only differ by one bit. 7419 // In that case we can apply a mask to map one range onto the other. 7420 LowerDiff = CR1.getLower() ^ CR2.getLower(); 7421 APInt UpperDiff = (CR1.getUpper() - 1) ^ (CR2.getUpper() - 1); 7422 APInt CR1Size = CR1.getUpper() - CR1.getLower(); 7423 if (!LowerDiff.isPowerOf2() || LowerDiff != UpperDiff || 7424 CR1Size != CR2.getUpper() - CR2.getLower()) 7425 return false; 7426 7427 CR = CR1.getLower().ult(CR2.getLower()) ? CR1 : CR2; 7428 CreateMask = true; 7429 } 7430 7431 if (IsAnd) 7432 CR = CR->inverse(); 7433 7434 CmpInst::Predicate NewPred; 7435 APInt NewC, Offset; 7436 CR->getEquivalentICmp(NewPred, NewC, Offset); 7437 7438 // We take the result type of one of the original icmps, CmpTy, for 7439 // the to be build icmp. The operand type, CmpOperandTy, is used for 7440 // the other instructions and constants to be build. The types of 7441 // the parameters and output are the same for add and and. CmpTy 7442 // and the type of DstReg might differ. That is why we zext or trunc 7443 // the icmp into the destination register. 7444 7445 MatchInfo = [=](MachineIRBuilder &B) { 7446 if (CreateMask && Offset != 0) { 7447 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff); 7448 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask. 7449 auto OffsetC = B.buildConstant(CmpOperandTy, Offset); 7450 auto Add = B.buildAdd(CmpOperandTy, And, OffsetC, Flags); 7451 auto NewCon = B.buildConstant(CmpOperandTy, NewC); 7452 auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon); 7453 B.buildZExtOrTrunc(DstReg, ICmp); 7454 } else if (CreateMask && Offset == 0) { 7455 auto TildeLowerDiff = B.buildConstant(CmpOperandTy, ~LowerDiff); 7456 auto And = B.buildAnd(CmpOperandTy, R1, TildeLowerDiff); // the mask. 
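      // With a zero offset, the masked value can be compared directly against
      // the constant produced by getEquivalentICmp.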
7457 auto NewCon = B.buildConstant(CmpOperandTy, NewC); 7458 auto ICmp = B.buildICmp(NewPred, CmpTy, And, NewCon); 7459 B.buildZExtOrTrunc(DstReg, ICmp); 7460 } else if (!CreateMask && Offset != 0) { 7461 auto OffsetC = B.buildConstant(CmpOperandTy, Offset); 7462 auto Add = B.buildAdd(CmpOperandTy, R1, OffsetC, Flags); 7463 auto NewCon = B.buildConstant(CmpOperandTy, NewC); 7464 auto ICmp = B.buildICmp(NewPred, CmpTy, Add, NewCon); 7465 B.buildZExtOrTrunc(DstReg, ICmp); 7466 } else if (!CreateMask && Offset == 0) { 7467 auto NewCon = B.buildConstant(CmpOperandTy, NewC); 7468 auto ICmp = B.buildICmp(NewPred, CmpTy, R1, NewCon); 7469 B.buildZExtOrTrunc(DstReg, ICmp); 7470 } else { 7471 llvm_unreachable("unexpected configuration of CreateMask and Offset"); 7472 } 7473 }; 7474 return true; 7475 } 7476 7477 bool CombinerHelper::tryFoldLogicOfFCmps(GLogicalBinOp *Logic, 7478 BuildFnTy &MatchInfo) const { 7479 assert(Logic->getOpcode() != TargetOpcode::G_XOR && "unexpecte xor"); 7480 Register DestReg = Logic->getReg(0); 7481 Register LHS = Logic->getLHSReg(); 7482 Register RHS = Logic->getRHSReg(); 7483 bool IsAnd = Logic->getOpcode() == TargetOpcode::G_AND; 7484 7485 // We need a compare on the LHS register. 7486 GFCmp *Cmp1 = getOpcodeDef<GFCmp>(LHS, MRI); 7487 if (!Cmp1) 7488 return false; 7489 7490 // We need a compare on the RHS register. 7491 GFCmp *Cmp2 = getOpcodeDef<GFCmp>(RHS, MRI); 7492 if (!Cmp2) 7493 return false; 7494 7495 LLT CmpTy = MRI.getType(Cmp1->getReg(0)); 7496 LLT CmpOperandTy = MRI.getType(Cmp1->getLHSReg()); 7497 7498 // We build one fcmp, want to fold the fcmps, replace the logic op, 7499 // and the fcmps must have the same shape. 7500 if (!isLegalOrBeforeLegalizer( 7501 {TargetOpcode::G_FCMP, {CmpTy, CmpOperandTy}}) || 7502 !MRI.hasOneNonDBGUse(Logic->getReg(0)) || 7503 !MRI.hasOneNonDBGUse(Cmp1->getReg(0)) || 7504 !MRI.hasOneNonDBGUse(Cmp2->getReg(0)) || 7505 MRI.getType(Cmp1->getLHSReg()) != MRI.getType(Cmp2->getLHSReg())) 7506 return false; 7507 7508 CmpInst::Predicate PredL = Cmp1->getCond(); 7509 CmpInst::Predicate PredR = Cmp2->getCond(); 7510 Register LHS0 = Cmp1->getLHSReg(); 7511 Register LHS1 = Cmp1->getRHSReg(); 7512 Register RHS0 = Cmp2->getLHSReg(); 7513 Register RHS1 = Cmp2->getRHSReg(); 7514 7515 if (LHS0 == RHS1 && LHS1 == RHS0) { 7516 // Swap RHS operands to match LHS. 7517 PredR = CmpInst::getSwappedPredicate(PredR); 7518 std::swap(RHS0, RHS1); 7519 } 7520 7521 if (LHS0 == RHS0 && LHS1 == RHS1) { 7522 // We determine the new predicate. 7523 unsigned CmpCodeL = getFCmpCode(PredL); 7524 unsigned CmpCodeR = getFCmpCode(PredR); 7525 unsigned NewPred = IsAnd ? CmpCodeL & CmpCodeR : CmpCodeL | CmpCodeR; 7526 unsigned Flags = Cmp1->getFlags() | Cmp2->getFlags(); 7527 MatchInfo = [=](MachineIRBuilder &B) { 7528 // The fcmp predicates fill the lower part of the enum. 7529 FCmpInst::Predicate Pred = static_cast<FCmpInst::Predicate>(NewPred); 7530 if (Pred == FCmpInst::FCMP_FALSE && 7531 isConstantLegalOrBeforeLegalizer(CmpTy)) { 7532 auto False = B.buildConstant(CmpTy, 0); 7533 B.buildZExtOrTrunc(DestReg, False); 7534 } else if (Pred == FCmpInst::FCMP_TRUE && 7535 isConstantLegalOrBeforeLegalizer(CmpTy)) { 7536 auto True = 7537 B.buildConstant(CmpTy, getICmpTrueVal(getTargetLowering(), 7538 CmpTy.isVector() /*isVector*/, 7539 true /*isFP*/)); 7540 B.buildZExtOrTrunc(DestReg, True); 7541 } else { // We take the predicate without predicate optimizations. 
7542 auto Cmp = B.buildFCmp(Pred, CmpTy, LHS0, LHS1, Flags); 7543 B.buildZExtOrTrunc(DestReg, Cmp); 7544 } 7545 }; 7546 return true; 7547 } 7548 7549 return false; 7550 } 7551 7552 bool CombinerHelper::matchAnd(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7553 GAnd *And = cast<GAnd>(&MI); 7554 7555 if (tryFoldAndOrOrICmpsUsingRanges(And, MatchInfo)) 7556 return true; 7557 7558 if (tryFoldLogicOfFCmps(And, MatchInfo)) 7559 return true; 7560 7561 return false; 7562 } 7563 7564 bool CombinerHelper::matchOr(MachineInstr &MI, BuildFnTy &MatchInfo) const { 7565 GOr *Or = cast<GOr>(&MI); 7566 7567 if (tryFoldAndOrOrICmpsUsingRanges(Or, MatchInfo)) 7568 return true; 7569 7570 if (tryFoldLogicOfFCmps(Or, MatchInfo)) 7571 return true; 7572 7573 return false; 7574 } 7575 7576 bool CombinerHelper::matchAddOverflow(MachineInstr &MI, 7577 BuildFnTy &MatchInfo) const { 7578 GAddCarryOut *Add = cast<GAddCarryOut>(&MI); 7579 7580 // Addo has no flags 7581 Register Dst = Add->getReg(0); 7582 Register Carry = Add->getReg(1); 7583 Register LHS = Add->getLHSReg(); 7584 Register RHS = Add->getRHSReg(); 7585 bool IsSigned = Add->isSigned(); 7586 LLT DstTy = MRI.getType(Dst); 7587 LLT CarryTy = MRI.getType(Carry); 7588 7589 // Fold addo, if the carry is dead -> add, undef. 7590 if (MRI.use_nodbg_empty(Carry) && 7591 isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}})) { 7592 MatchInfo = [=](MachineIRBuilder &B) { 7593 B.buildAdd(Dst, LHS, RHS); 7594 B.buildUndef(Carry); 7595 }; 7596 return true; 7597 } 7598 7599 // Canonicalize constant to RHS. 7600 if (isConstantOrConstantVectorI(LHS) && !isConstantOrConstantVectorI(RHS)) { 7601 if (IsSigned) { 7602 MatchInfo = [=](MachineIRBuilder &B) { 7603 B.buildSAddo(Dst, Carry, RHS, LHS); 7604 }; 7605 return true; 7606 } 7607 // !IsSigned 7608 MatchInfo = [=](MachineIRBuilder &B) { 7609 B.buildUAddo(Dst, Carry, RHS, LHS); 7610 }; 7611 return true; 7612 } 7613 7614 std::optional<APInt> MaybeLHS = getConstantOrConstantSplatVector(LHS); 7615 std::optional<APInt> MaybeRHS = getConstantOrConstantSplatVector(RHS); 7616 7617 // Fold addo(c1, c2) -> c3, carry. 7618 if (MaybeLHS && MaybeRHS && isConstantLegalOrBeforeLegalizer(DstTy) && 7619 isConstantLegalOrBeforeLegalizer(CarryTy)) { 7620 bool Overflow; 7621 APInt Result = IsSigned ? MaybeLHS->sadd_ov(*MaybeRHS, Overflow) 7622 : MaybeLHS->uadd_ov(*MaybeRHS, Overflow); 7623 MatchInfo = [=](MachineIRBuilder &B) { 7624 B.buildConstant(Dst, Result); 7625 B.buildConstant(Carry, Overflow); 7626 }; 7627 return true; 7628 } 7629 7630 // Fold (addo x, 0) -> x, no carry 7631 if (MaybeRHS && *MaybeRHS == 0 && isConstantLegalOrBeforeLegalizer(CarryTy)) { 7632 MatchInfo = [=](MachineIRBuilder &B) { 7633 B.buildCopy(Dst, LHS); 7634 B.buildConstant(Carry, 0); 7635 }; 7636 return true; 7637 } 7638 7639 // Given 2 constant operands whose sum does not overflow: 7640 // uaddo (X +nuw C0), C1 -> uaddo X, C0 + C1 7641 // saddo (X +nsw C0), C1 -> saddo X, C0 + C1 7642 GAdd *AddLHS = getOpcodeDef<GAdd>(LHS, MRI); 7643 if (MaybeRHS && AddLHS && MRI.hasOneNonDBGUse(Add->getReg(0)) && 7644 ((IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoSWrap)) || 7645 (!IsSigned && AddLHS->getFlag(MachineInstr::MIFlag::NoUWrap)))) { 7646 std::optional<APInt> MaybeAddRHS = 7647 getConstantOrConstantSplatVector(AddLHS->getRHSReg()); 7648 if (MaybeAddRHS) { 7649 bool Overflow; 7650 APInt NewC = IsSigned ? 
MaybeAddRHS->sadd_ov(*MaybeRHS, Overflow) 7651 : MaybeAddRHS->uadd_ov(*MaybeRHS, Overflow); 7652 if (!Overflow && isConstantLegalOrBeforeLegalizer(DstTy)) { 7653 if (IsSigned) { 7654 MatchInfo = [=](MachineIRBuilder &B) { 7655 auto ConstRHS = B.buildConstant(DstTy, NewC); 7656 B.buildSAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS); 7657 }; 7658 return true; 7659 } 7660 // !IsSigned 7661 MatchInfo = [=](MachineIRBuilder &B) { 7662 auto ConstRHS = B.buildConstant(DstTy, NewC); 7663 B.buildUAddo(Dst, Carry, AddLHS->getLHSReg(), ConstRHS); 7664 }; 7665 return true; 7666 } 7667 } 7668 }; 7669 7670 // We try to combine addo to non-overflowing add. 7671 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_ADD, {DstTy}}) || 7672 !isConstantLegalOrBeforeLegalizer(CarryTy)) 7673 return false; 7674 7675 // We try to combine uaddo to non-overflowing add. 7676 if (!IsSigned) { 7677 ConstantRange CRLHS = 7678 ConstantRange::fromKnownBits(VT->getKnownBits(LHS), /*IsSigned=*/false); 7679 ConstantRange CRRHS = 7680 ConstantRange::fromKnownBits(VT->getKnownBits(RHS), /*IsSigned=*/false); 7681 7682 switch (CRLHS.unsignedAddMayOverflow(CRRHS)) { 7683 case ConstantRange::OverflowResult::MayOverflow: 7684 return false; 7685 case ConstantRange::OverflowResult::NeverOverflows: { 7686 MatchInfo = [=](MachineIRBuilder &B) { 7687 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap); 7688 B.buildConstant(Carry, 0); 7689 }; 7690 return true; 7691 } 7692 case ConstantRange::OverflowResult::AlwaysOverflowsLow: 7693 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { 7694 MatchInfo = [=](MachineIRBuilder &B) { 7695 B.buildAdd(Dst, LHS, RHS); 7696 B.buildConstant(Carry, 1); 7697 }; 7698 return true; 7699 } 7700 } 7701 return false; 7702 } 7703 7704 // We try to combine saddo to non-overflowing add. 7705 7706 // If LHS and RHS each have at least two sign bits, then there is no signed 7707 // overflow. 
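  // With at least two sign bits, each operand fits in BitWidth - 1 bits, so
  // their sum cannot leave the signed range of the full width.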
7708 if (VT->computeNumSignBits(RHS) > 1 && VT->computeNumSignBits(LHS) > 1) { 7709 MatchInfo = [=](MachineIRBuilder &B) { 7710 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap); 7711 B.buildConstant(Carry, 0); 7712 }; 7713 return true; 7714 } 7715 7716 ConstantRange CRLHS = 7717 ConstantRange::fromKnownBits(VT->getKnownBits(LHS), /*IsSigned=*/true); 7718 ConstantRange CRRHS = 7719 ConstantRange::fromKnownBits(VT->getKnownBits(RHS), /*IsSigned=*/true); 7720 7721 switch (CRLHS.signedAddMayOverflow(CRRHS)) { 7722 case ConstantRange::OverflowResult::MayOverflow: 7723 return false; 7724 case ConstantRange::OverflowResult::NeverOverflows: { 7725 MatchInfo = [=](MachineIRBuilder &B) { 7726 B.buildAdd(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap); 7727 B.buildConstant(Carry, 0); 7728 }; 7729 return true; 7730 } 7731 case ConstantRange::OverflowResult::AlwaysOverflowsLow: 7732 case ConstantRange::OverflowResult::AlwaysOverflowsHigh: { 7733 MatchInfo = [=](MachineIRBuilder &B) { 7734 B.buildAdd(Dst, LHS, RHS); 7735 B.buildConstant(Carry, 1); 7736 }; 7737 return true; 7738 } 7739 } 7740 7741 return false; 7742 } 7743 7744 void CombinerHelper::applyBuildFnMO(const MachineOperand &MO, 7745 BuildFnTy &MatchInfo) const { 7746 MachineInstr *Root = getDefIgnoringCopies(MO.getReg(), MRI); 7747 MatchInfo(Builder); 7748 Root->eraseFromParent(); 7749 } 7750 7751 bool CombinerHelper::matchFPowIExpansion(MachineInstr &MI, 7752 int64_t Exponent) const { 7753 bool OptForSize = MI.getMF()->getFunction().hasOptSize(); 7754 return getTargetLowering().isBeneficialToExpandPowI(Exponent, OptForSize); 7755 } 7756 7757 void CombinerHelper::applyExpandFPowI(MachineInstr &MI, 7758 int64_t Exponent) const { 7759 auto [Dst, Base] = MI.getFirst2Regs(); 7760 LLT Ty = MRI.getType(Dst); 7761 int64_t ExpVal = Exponent; 7762 7763 if (ExpVal == 0) { 7764 Builder.buildFConstant(Dst, 1.0); 7765 MI.removeFromParent(); 7766 return; 7767 } 7768 7769 if (ExpVal < 0) 7770 ExpVal = -ExpVal; 7771 7772 // We use the simple binary decomposition method from SelectionDAG ExpandPowI 7773 // to generate the multiply sequence. There are more optimal ways to do this 7774 // (for example, powi(x,15) generates one more multiply than it should), but 7775 // this has the benefit of being both really simple and much better than a 7776 // libcall. 7777 std::optional<SrcOp> Res; 7778 SrcOp CurSquare = Base; 7779 while (ExpVal > 0) { 7780 if (ExpVal & 1) { 7781 if (!Res) 7782 Res = CurSquare; 7783 else 7784 Res = Builder.buildFMul(Ty, *Res, CurSquare); 7785 } 7786 7787 CurSquare = Builder.buildFMul(Ty, CurSquare, CurSquare); 7788 ExpVal >>= 1; 7789 } 7790 7791 // If the original exponent was negative, invert the result, producing 7792 // 1/(x*x*x). 
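  // E.g. powi(x, -3): the loop above leaves Res = x*x*x, and the fdiv below
  // yields 1.0 / (x*x*x).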
7793 if (Exponent < 0) 7794 Res = Builder.buildFDiv(Ty, Builder.buildFConstant(Ty, 1.0), *Res, 7795 MI.getFlags()); 7796 7797 Builder.buildCopy(Dst, *Res); 7798 MI.eraseFromParent(); 7799 } 7800 7801 bool CombinerHelper::matchFoldAPlusC1MinusC2(const MachineInstr &MI, 7802 BuildFnTy &MatchInfo) const { 7803 // fold (A+C1)-C2 -> A+(C1-C2) 7804 const GSub *Sub = cast<GSub>(&MI); 7805 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getLHSReg())); 7806 7807 if (!MRI.hasOneNonDBGUse(Add->getReg(0))) 7808 return false; 7809 7810 APInt C2 = getIConstantFromReg(Sub->getRHSReg(), MRI); 7811 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI); 7812 7813 Register Dst = Sub->getReg(0); 7814 LLT DstTy = MRI.getType(Dst); 7815 7816 MatchInfo = [=](MachineIRBuilder &B) { 7817 auto Const = B.buildConstant(DstTy, C1 - C2); 7818 B.buildAdd(Dst, Add->getLHSReg(), Const); 7819 }; 7820 7821 return true; 7822 } 7823 7824 bool CombinerHelper::matchFoldC2MinusAPlusC1(const MachineInstr &MI, 7825 BuildFnTy &MatchInfo) const { 7826 // fold C2-(A+C1) -> (C2-C1)-A 7827 const GSub *Sub = cast<GSub>(&MI); 7828 GAdd *Add = cast<GAdd>(MRI.getVRegDef(Sub->getRHSReg())); 7829 7830 if (!MRI.hasOneNonDBGUse(Add->getReg(0))) 7831 return false; 7832 7833 APInt C2 = getIConstantFromReg(Sub->getLHSReg(), MRI); 7834 APInt C1 = getIConstantFromReg(Add->getRHSReg(), MRI); 7835 7836 Register Dst = Sub->getReg(0); 7837 LLT DstTy = MRI.getType(Dst); 7838 7839 MatchInfo = [=](MachineIRBuilder &B) { 7840 auto Const = B.buildConstant(DstTy, C2 - C1); 7841 B.buildSub(Dst, Const, Add->getLHSReg()); 7842 }; 7843 7844 return true; 7845 } 7846 7847 bool CombinerHelper::matchFoldAMinusC1MinusC2(const MachineInstr &MI, 7848 BuildFnTy &MatchInfo) const { 7849 // fold (A-C1)-C2 -> A-(C1+C2) 7850 const GSub *Sub1 = cast<GSub>(&MI); 7851 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg())); 7852 7853 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0))) 7854 return false; 7855 7856 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI); 7857 APInt C1 = getIConstantFromReg(Sub2->getRHSReg(), MRI); 7858 7859 Register Dst = Sub1->getReg(0); 7860 LLT DstTy = MRI.getType(Dst); 7861 7862 MatchInfo = [=](MachineIRBuilder &B) { 7863 auto Const = B.buildConstant(DstTy, C1 + C2); 7864 B.buildSub(Dst, Sub2->getLHSReg(), Const); 7865 }; 7866 7867 return true; 7868 } 7869 7870 bool CombinerHelper::matchFoldC1Minus2MinusC2(const MachineInstr &MI, 7871 BuildFnTy &MatchInfo) const { 7872 // fold (C1-A)-C2 -> (C1-C2)-A 7873 const GSub *Sub1 = cast<GSub>(&MI); 7874 GSub *Sub2 = cast<GSub>(MRI.getVRegDef(Sub1->getLHSReg())); 7875 7876 if (!MRI.hasOneNonDBGUse(Sub2->getReg(0))) 7877 return false; 7878 7879 APInt C2 = getIConstantFromReg(Sub1->getRHSReg(), MRI); 7880 APInt C1 = getIConstantFromReg(Sub2->getLHSReg(), MRI); 7881 7882 Register Dst = Sub1->getReg(0); 7883 LLT DstTy = MRI.getType(Dst); 7884 7885 MatchInfo = [=](MachineIRBuilder &B) { 7886 auto Const = B.buildConstant(DstTy, C1 - C2); 7887 B.buildSub(Dst, Const, Sub2->getRHSReg()); 7888 }; 7889 7890 return true; 7891 } 7892 7893 bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI, 7894 BuildFnTy &MatchInfo) const { 7895 // fold ((A-C1)+C2) -> (A+(C2-C1)) 7896 const GAdd *Add = cast<GAdd>(&MI); 7897 GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg())); 7898 7899 if (!MRI.hasOneNonDBGUse(Sub->getReg(0))) 7900 return false; 7901 7902 APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI); 7903 APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI); 7904 7905 Register Dst = Add->getReg(0); 7906 LLT 
bool CombinerHelper::matchFoldAMinusC1PlusC2(const MachineInstr &MI,
                                             BuildFnTy &MatchInfo) const {
  // fold ((A-C1)+C2) -> (A+(C2-C1))
  const GAdd *Add = cast<GAdd>(&MI);
  GSub *Sub = cast<GSub>(MRI.getVRegDef(Add->getLHSReg()));

  if (!MRI.hasOneNonDBGUse(Sub->getReg(0)))
    return false;

  APInt C2 = getIConstantFromReg(Add->getRHSReg(), MRI);
  APInt C1 = getIConstantFromReg(Sub->getRHSReg(), MRI);

  Register Dst = Add->getReg(0);
  LLT DstTy = MRI.getType(Dst);

  MatchInfo = [=](MachineIRBuilder &B) {
    auto Const = B.buildConstant(DstTy, C2 - C1);
    B.buildAdd(Dst, Sub->getLHSReg(), Const);
  };

  return true;
}

bool CombinerHelper::matchUnmergeValuesAnyExtBuildVector(
    const MachineInstr &MI, BuildFnTy &MatchInfo) const {
  const GUnmerge *Unmerge = cast<GUnmerge>(&MI);

  if (!MRI.hasOneNonDBGUse(Unmerge->getSourceReg()))
    return false;

  const MachineInstr *Source = MRI.getVRegDef(Unmerge->getSourceReg());

  LLT DstTy = MRI.getType(Unmerge->getReg(0));

  // $bv:_(<8 x s8>) = G_BUILD_VECTOR ....
  // $any:_(<8 x s16>) = G_ANYEXT $bv
  // $uv:_(<4 x s16>), $uv1:_(<4 x s16>) = G_UNMERGE_VALUES $any
  //
  // ->
  //
  // $any:_(s16) = G_ANYEXT $bv[0]
  // $any1:_(s16) = G_ANYEXT $bv[1]
  // $any2:_(s16) = G_ANYEXT $bv[2]
  // $any3:_(s16) = G_ANYEXT $bv[3]
  // $any4:_(s16) = G_ANYEXT $bv[4]
  // $any5:_(s16) = G_ANYEXT $bv[5]
  // $any6:_(s16) = G_ANYEXT $bv[6]
  // $any7:_(s16) = G_ANYEXT $bv[7]
  // $uv:_(<4 x s16>) = G_BUILD_VECTOR $any, $any1, $any2, $any3
  // $uv1:_(<4 x s16>) = G_BUILD_VECTOR $any4, $any5, $any6, $any7

  // We want to unmerge into vectors.
  if (!DstTy.isFixedVector())
    return false;

  const GAnyExt *Any = dyn_cast<GAnyExt>(Source);
  if (!Any)
    return false;

  const MachineInstr *NextSource = MRI.getVRegDef(Any->getSrcReg());

  if (const GBuildVector *BV = dyn_cast<GBuildVector>(NextSource)) {
    // G_UNMERGE_VALUES G_ANYEXT G_BUILD_VECTOR

    if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
      return false;

    // FIXME: check element types?
    if (BV->getNumSources() % Unmerge->getNumDefs() != 0)
      return false;

    LLT BigBvTy = MRI.getType(BV->getReg(0));
    LLT SmallBvTy = DstTy;
    LLT SmallBvElementTy = SmallBvTy.getElementType();

    if (!isLegalOrBeforeLegalizer(
            {TargetOpcode::G_BUILD_VECTOR, {SmallBvTy, SmallBvElementTy}}))
      return false;

    // Check the legality of the scalar G_ANYEXT.
    if (!isLegalOrBeforeLegalizer(
            {TargetOpcode::G_ANYEXT,
             {SmallBvElementTy, BigBvTy.getElementType()}}))
      return false;

    MatchInfo = [=](MachineIRBuilder &B) {
      // Build into each G_UNMERGE_VALUES def a small build vector with
      // anyexts from the source build vector.
      for (unsigned I = 0; I < Unmerge->getNumDefs(); ++I) {
        SmallVector<Register> Ops;
        for (unsigned J = 0; J < SmallBvTy.getNumElements(); ++J) {
          Register SourceArray =
              BV->getSourceReg(I * SmallBvTy.getNumElements() + J);
          auto AnyExt = B.buildAnyExt(SmallBvElementTy, SourceArray);
          Ops.push_back(AnyExt.getReg(0));
        }
        B.buildBuildVector(Unmerge->getOperand(I).getReg(), Ops);
      }
    };
    return true;
  }

  return false;
}
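// The combine below rewrites the mask of a G_SHUFFLE_VECTOR so that lanes
// taken from the second source become undef; the pattern is presumably only
// fired when that second source is itself undef (as the name suggests).
// Illustrative example with assumed types and registers:
//   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %a(<2 x s32>), %b(<2 x s32>),
//                                      shufflemask(0, 1, 3, 2)
// becomes
//   %s:_(<4 x s32>) = G_SHUFFLE_VECTOR %a, %b, shufflemask(0, 1, undef, undef)
// since indices 3 and 2 are >= the two elements of the first source.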
bool CombinerHelper::matchShuffleUndefRHS(MachineInstr &MI,
                                          BuildFnTy &MatchInfo) const {
  bool Changed = false;
  auto &Shuffle = cast<GShuffleVector>(MI);
  ArrayRef<int> OrigMask = Shuffle.getMask();
  SmallVector<int, 16> NewMask;
  const LLT SrcTy = MRI.getType(Shuffle.getSrc1Reg());
  const unsigned NumSrcElems = SrcTy.isVector() ? SrcTy.getNumElements() : 1;
  const unsigned NumDstElts = OrigMask.size();
  for (unsigned I = 0; I != NumDstElts; ++I) {
    int Idx = OrigMask[I];
    if (Idx >= (int)NumSrcElems) {
      Idx = -1;
      Changed = true;
    }
    NewMask.push_back(Idx);
  }

  if (!Changed)
    return false;

  MatchInfo = [&, NewMask = std::move(NewMask)](MachineIRBuilder &B) {
    B.buildShuffleVector(MI.getOperand(0), MI.getOperand(1), MI.getOperand(2),
                         std::move(NewMask));
  };

  return true;
}

static void commuteMask(MutableArrayRef<int> Mask, const unsigned NumElems) {
  const unsigned MaskSize = Mask.size();
  for (unsigned I = 0; I < MaskSize; ++I) {
    int Idx = Mask[I];
    if (Idx < 0)
      continue;

    if (Idx < (int)NumElems)
      Mask[I] = Idx + NumElems;
    else
      Mask[I] = Idx - NumElems;
  }
}

bool CombinerHelper::matchShuffleDisjointMask(MachineInstr &MI,
                                              BuildFnTy &MatchInfo) const {
  auto &Shuffle = cast<GShuffleVector>(MI);
  // If either of the two inputs is already undef, don't check the mask again,
  // to prevent an infinite loop.
  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc1Reg(), MRI))
    return false;

  if (getOpcodeDef(TargetOpcode::G_IMPLICIT_DEF, Shuffle.getSrc2Reg(), MRI))
    return false;

  const LLT DstTy = MRI.getType(Shuffle.getReg(0));
  const LLT Src1Ty = MRI.getType(Shuffle.getSrc1Reg());
  if (!isLegalOrBeforeLegalizer(
          {TargetOpcode::G_SHUFFLE_VECTOR, {DstTy, Src1Ty}}))
    return false;

  ArrayRef<int> Mask = Shuffle.getMask();
  const unsigned NumSrcElems = Src1Ty.isVector() ? Src1Ty.getNumElements() : 1;

  bool TouchesSrc1 = false;
  bool TouchesSrc2 = false;
  const unsigned NumElems = Mask.size();
  for (unsigned Idx = 0; Idx < NumElems; ++Idx) {
    if (Mask[Idx] < 0)
      continue;

    if (Mask[Idx] < (int)NumSrcElems)
      TouchesSrc1 = true;
    else
      TouchesSrc2 = true;
  }

  if (TouchesSrc1 == TouchesSrc2)
    return false;

  Register NewSrc1 = Shuffle.getSrc1Reg();
  SmallVector<int, 16> NewMask(Mask);
  if (TouchesSrc2) {
    NewSrc1 = Shuffle.getSrc2Reg();
    commuteMask(NewMask, NumSrcElems);
  }

  MatchInfo = [=, &Shuffle](MachineIRBuilder &B) {
    auto Undef = B.buildUndef(Src1Ty);
    B.buildShuffleVector(Shuffle.getReg(0), NewSrc1, Undef, NewMask);
  };

  return true;
}
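// The combine below folds away the carry/overflow output of G_SSUBO/G_USUBO
// when value tracking already decides the question. Illustrative example
// (types and registers assumed): if known bits prove that
//   %d:_(s32), %c:_(s1) = G_USUBO %a, %b
// can never borrow, it is rewritten as
//   %d:_(s32) = G_SUB %a, %b     ; carrying the nuw flag
//   %c:_(s1)  = G_CONSTANT i1 0
// while the always-overflows cases keep a plain G_SUB and set the carry to
// the target's "true" boolean value.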
bool CombinerHelper::matchSuboCarryOut(const MachineInstr &MI,
                                       BuildFnTy &MatchInfo) const {
  const GSubCarryOut *Subo = cast<GSubCarryOut>(&MI);

  Register Dst = Subo->getReg(0);
  Register LHS = Subo->getLHSReg();
  Register RHS = Subo->getRHSReg();
  Register Carry = Subo->getCarryOutReg();
  LLT DstTy = MRI.getType(Dst);
  LLT CarryTy = MRI.getType(Carry);

  // Check legality before querying known bits.
  if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SUB, {DstTy}}) ||
      !isConstantLegalOrBeforeLegalizer(CarryTy))
    return false;

  ConstantRange KBLHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(LHS),
                                   /*IsSigned=*/Subo->isSigned());
  ConstantRange KBRHS =
      ConstantRange::fromKnownBits(VT->getKnownBits(RHS),
                                   /*IsSigned=*/Subo->isSigned());

  if (Subo->isSigned()) {
    // G_SSUBO
    switch (KBLHS.signedSubMayOverflow(KBRHS)) {
    case ConstantRange::OverflowResult::MayOverflow:
      return false;
    case ConstantRange::OverflowResult::NeverOverflows: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoSWrap);
        B.buildConstant(Carry, 0);
      };
      return true;
    }
    case ConstantRange::OverflowResult::AlwaysOverflowsLow:
    case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
      MatchInfo = [=](MachineIRBuilder &B) {
        B.buildSub(Dst, LHS, RHS);
        B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                              /*isVector=*/CarryTy.isVector(),
                                              /*isFP=*/false));
      };
      return true;
    }
    }
    return false;
  }

  // G_USUBO
  switch (KBLHS.unsignedSubMayOverflow(KBRHS)) {
  case ConstantRange::OverflowResult::MayOverflow:
    return false;
  case ConstantRange::OverflowResult::NeverOverflows: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS, MachineInstr::MIFlag::NoUWrap);
      B.buildConstant(Carry, 0);
    };
    return true;
  }
  case ConstantRange::OverflowResult::AlwaysOverflowsLow:
  case ConstantRange::OverflowResult::AlwaysOverflowsHigh: {
    MatchInfo = [=](MachineIRBuilder &B) {
      B.buildSub(Dst, LHS, RHS);
      B.buildConstant(Carry, getICmpTrueVal(getTargetLowering(),
                                            /*isVector=*/CarryTy.isVector(),
                                            /*isFP=*/false));
    };
    return true;
  }
  }

  return false;
}
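// A concrete instance of the always-overflows path above, with assumed known
// bits: if value tracking shows the bits of %a above bit 3 are zero (so
// %a <= 15) while bit 5 of %b is known set (so %b >= 32), then
// unsignedSubMayOverflow reports AlwaysOverflowsLow, the G_USUBO data result
// becomes a plain G_SUB, and the carry is materialized as a constant true.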