//===- AggressiveInstCombine.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//  * Truncate instruction
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/BuildLibCalls.h"
#include "llvm/Transforms/Utils/Local.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
STATISTIC(NumGuardedRotates,
          "Number of guarded rotates transformed into funnel shifts");
STATISTIC(NumGuardedFunnelShifts,
          "Number of guarded funnel shifts transformed into funnel shifts");
STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");

static cl::opt<unsigned> MaxInstrsToScan(
    "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
    cl::desc("Max number of instructions to scan for aggressive instcombine."));

/// Match a pattern for a bitwise funnel/rotate operation that partially guards
/// against undefined behavior by branching around the funnel-shift/rotation
/// when the shift amount is 0.
static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
  if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
    return false;

  // As with the one-use checks below, this is not strictly necessary, but we
  // are being cautious to avoid potential perf regressions on targets that
  // do not actually have a funnel/rotate instruction (where the funnel shift
  // would be expanded back into math/shift/logic ops).
  if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
    return false;

  // Match V to funnel shift left/right and capture the source operands and
  // shift amount.
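  //
  // As an illustration only (names are hypothetical, not taken from the
  // tests), a guarded rotate such as the following C source is one common
  // origin of the phi-of-or-of-shifts pattern matched here:
  //
  //   unsigned rotl32(unsigned x, unsigned n) {
  //     return n ? (x << n) | (x >> (32 - n)) : x;
  //   }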
  auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
                             Value *&ShAmt) {
    unsigned Width = V->getType()->getScalarSizeInBits();

    // fshl(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width - ShAmt))
    if (match(V, m_OneUse(m_c_Or(
                     m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
                     m_LShr(m_Value(ShVal1),
                            m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
      return Intrinsic::fshl;
    }

    // fshr(ShVal0, ShVal1, ShAmt)
    //  == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
    if (match(V,
              m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
                                                           m_Value(ShAmt))),
                              m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
      return Intrinsic::fshr;
    }

    return Intrinsic::not_intrinsic;
  };

  // One phi operand must be a funnel/rotate operation, and the other phi
  // operand must be the source value of that funnel/rotate operation:
  // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
  // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
  // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
  PHINode &Phi = cast<PHINode>(I);
  unsigned FunnelOp = 0, GuardOp = 1;
  Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
  Value *ShVal0, *ShVal1, *ShAmt;
  Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
  if (IID == Intrinsic::not_intrinsic ||
      (IID == Intrinsic::fshl && ShVal0 != P1) ||
      (IID == Intrinsic::fshr && ShVal1 != P1)) {
    IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
    if (IID == Intrinsic::not_intrinsic ||
        (IID == Intrinsic::fshl && ShVal0 != P0) ||
        (IID == Intrinsic::fshr && ShVal1 != P0))
      return false;
    assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
           "Pattern must match funnel shift left or right");
    std::swap(FunnelOp, GuardOp);
  }

  // The incoming block with our source operand must be the "guard" block.
  // That must contain a cmp+branch to avoid the funnel/rotate when the shift
  // amount is equal to 0. The other incoming block is the block with the
  // funnel/rotate.
  BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
  BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
  Instruction *TermI = GuardBB->getTerminator();

  // Ensure that the shift values dominate each block.
  if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
    return false;

  ICmpInst::Predicate Pred;
  BasicBlock *PhiBB = Phi.getParent();
  if (!match(TermI, m_Br(m_ICmp(Pred, m_Specific(ShAmt), m_ZeroInt()),
                         m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
    return false;

  if (Pred != CmpInst::ICMP_EQ)
    return false;

  IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());

  if (ShVal0 == ShVal1)
    ++NumGuardedRotates;
  else
    ++NumGuardedFunnelShifts;

  // If this is not a rotate then the guarding phi was blocking poison from the
  // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
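  // (Illustrative note: for phi [ fshl(a, b, n), FunnelBB ], [ a, GuardBB ],
  //  the value 'b' never contributes to the result when n == 0, but the
  //  unconditional fshl would yield poison if 'b' is poison, so 'b' must be
  //  frozen first.)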
  bool IsFshl = IID == Intrinsic::fshl;
  if (ShVal0 != ShVal1) {
    if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
      ShVal1 = Builder.CreateFreeze(ShVal1);
    else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
      ShVal0 = Builder.CreateFreeze(ShVal0);
  }

  // We matched a variation of this IR pattern:
  // GuardBB:
  //   %cmp = icmp eq i32 %ShAmt, 0
  //   br i1 %cmp, label %PhiBB, label %FunnelBB
  // FunnelBB:
  //   %sub = sub i32 32, %ShAmt
  //   %shr = lshr i32 %ShVal1, %sub
  //   %shl = shl i32 %ShVal0, %ShAmt
  //   %fsh = or i32 %shr, %shl
  //   br label %PhiBB
  // PhiBB:
  //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
  // -->
  //   llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
  Function *F = Intrinsic::getDeclaration(Phi.getModule(), IID, Phi.getType());
  Phi.replaceAllUsesWith(Builder.CreateCall(F, {ShVal0, ShVal1, ShAmt}));
  return true;
}

/// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
/// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
/// of 'and' ops, then we also need to capture the fact that we saw an
/// "and X, 1", so that's an extra return value for that case.
struct MaskOps {
  Value *Root = nullptr;
  APInt Mask;
  bool MatchAndChain;
  bool FoundAnd1 = false;

  MaskOps(unsigned BitWidth, bool MatchAnds)
      : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
};

/// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
/// chain of 'and' or 'or' instructions looking for shift ops of a common source
/// value. Examples:
///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
///     returns { X, 0x129 }
///   and (and (X >> 1), 1), (X >> 4)
///     returns { X, 0x12 }
static bool matchAndOrChain(Value *V, MaskOps &MOps) {
  Value *Op0, *Op1;
  if (MOps.MatchAndChain) {
    // Recurse through a chain of 'and' operands. This requires an extra check
    // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
    // in the chain to know that all of the high bits are cleared.
    if (match(V, m_And(m_Value(Op0), m_One()))) {
      MOps.FoundAnd1 = true;
      return matchAndOrChain(Op0, MOps);
    }
    if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  } else {
    // Recurse through a chain of 'or' operands.
    if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
      return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
  }

  // We need a shift-right or a bare value representing a compare of bit 0 of
  // the original source operand.
  Value *Candidate;
  const APInt *BitIndex = nullptr;
  if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
    Candidate = V;

  // Initialize result source operand.
  if (!MOps.Root)
    MOps.Root = Candidate;

  // If the shift constant is out-of-range, this code hasn't been simplified;
  // bail out.
  if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
    return false;

  // Fill in the mask bit derived from the shift constant.
  MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
  return MOps.Root == Candidate;
}

/// Match patterns that correspond to "any-bits-set" and "all-bits-set".
/// These will include a chain of 'or' or 'and'-shifted bits from a
/// common source value:
///   and (or (lshr X, C), ...), 1 --> (X & CMask) != 0
///   and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
/// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
/// that differ only with a final 'not' of the result. We expect that final
/// 'not' to be folded with the compare that we create here (invert predicate).
static bool foldAnyOrAllBitsSet(Instruction &I) {
  // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
  // "and X, 1" instruction must be the final op in the sequence.
  bool MatchAllBitsSet;
  if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
    MatchAllBitsSet = true;
  else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
    MatchAllBitsSet = false;
  else
    return false;

  MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
  if (MatchAllBitsSet) {
    if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
      return false;
  } else {
    if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
      return false;
  }

  // The pattern was found. Create a masked compare that replaces all of the
  // shift and logic ops.
  IRBuilder<> Builder(&I);
  Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
  Value *And = Builder.CreateAnd(MOps.Root, Mask);
  Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
                               : Builder.CreateIsNotNull(And);
  Value *Zext = Builder.CreateZExt(Cmp, I.getType());
  I.replaceAllUsesWith(Zext);
  ++NumAnyOrAllBitsSet;
  return true;
}

// Try to recognize the function below as a popcount intrinsic.
// This is the "best" algorithm from
// http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
// Also used in TargetLowering::expandCTPOP().
//
// int popcount(unsigned int i) {
//   i = i - ((i >> 1) & 0x55555555);
//   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
//   i = ((i + (i >> 4)) & 0x0F0F0F0F);
//   return (i * 0x01010101) >> 24;
// }
static bool tryToRecognizePopCount(Instruction &I) {
  if (I.getOpcode() != Instruction::LShr)
    return false;

  Type *Ty = I.getType();
  if (!Ty->isIntOrIntVectorTy())
    return false;

  unsigned Len = Ty->getScalarSizeInBits();
  // FIXME: fix Len == 8 and other irregular type lengths.
  if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
    return false;

  APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
  APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
  APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
  APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
  APInt MaskShift = APInt(Len, Len - 8);

  Value *Op0 = I.getOperand(0);
  Value *Op1 = I.getOperand(1);
  Value *MulOp0;
  // Matching "(i * 0x01010101...) >> 24".
  if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
      match(Op1, m_SpecificInt(MaskShift))) {
    Value *ShiftOp0;
    // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
    if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
                                    m_Deferred(ShiftOp0)),
                            m_SpecificInt(Mask0F)))) {
      Value *AndOp0;
      // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
      if (match(ShiftOp0,
                m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
                        m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
                              m_SpecificInt(Mask33))))) {
        Value *Root, *SubOp1;
        // Matching "i - ((i >> 1) & 0x55555555...)".
        if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
            match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
                                m_SpecificInt(Mask55)))) {
          LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
          IRBuilder<> Builder(&I);
          Function *Func = Intrinsic::getDeclaration(
              I.getModule(), Intrinsic::ctpop, I.getType());
          I.replaceAllUsesWith(Builder.CreateCall(Func, {Root}));
          ++NumPopCountRecognized;
          return true;
        }
      }
    }
  }

  return false;
}

/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), provided C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible
/// as the fptosi.sat is more defined than the input - all values produce a
/// valid result for the fptosi.sat, whereas some inputs that were out of range
/// of the integer conversion produce poison for the original. The reversed
/// pattern may use fmax and fmin instead. As we cannot directly reverse the
/// transform, and it is not always profitable, we make it conditional on the
/// cost being reported as lower by TTI.
static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
  // Look for smin(smax(fptosi(x), C1), C2), converting to fptosi_sat.
  Value *In;
  const APInt *MinC, *MaxC;
  if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MinC))),
                        m_APInt(MaxC))) &&
      !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MaxC))),
                        m_APInt(MinC))))
    return false;

  // Check that the constants clamp a saturate.
  if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
    return false;

  Type *IntTy = I.getType();
  Type *FpTy = In->getType();
  Type *SatTy =
      IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
  if (auto *VecTy = dyn_cast<VectorType>(IntTy))
    SatTy = VectorType::get(SatTy, VecTy->getElementCount());

  // Get the cost of the intrinsic, and check that against the cost of
  // fptosi+smin+smax.
  InstructionCost SatCost = TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
      TTI::TCK_RecipThroughput);
  SatCost += TTI.getCastInstrCost(Instruction::SExt, SatTy, IntTy,
                                  TTI::CastContextHint::None,
                                  TTI::TCK_RecipThroughput);

  InstructionCost MinMaxCost = TTI.getCastInstrCost(
      Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);

  if (SatCost >= MinMaxCost)
    return false;

  IRBuilder<> Builder(&I);
  Function *Fn = Intrinsic::getDeclaration(I.getModule(), Intrinsic::fptosi_sat,
                                           {SatTy, FpTy});
  Value *Sat = Builder.CreateCall(Fn, In);
  I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
  return true;
}

/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
/// pessimistic codegen that has to account for setting errno and can enable
/// vectorization.
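///
/// A minimal illustrative sketch (not taken from the tests): assuming the
/// nnan fast-math flag is present,
///   %r = call nnan double @sqrt(double %x)
/// can be rewritten as
///   %r = call nnan double @llvm.sqrt.f64(double %x)
/// because no NaN result (and hence no errno write) needs to be modeled.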
static bool foldSqrt(Instruction &I, TargetTransformInfo &TTI,
                     TargetLibraryInfo &TLI, AssumptionCache &AC,
                     DominatorTree &DT) {
  // Match a call to the sqrt mathlib function.
  auto *Call = dyn_cast<CallInst>(&I);
  if (!Call)
    return false;

  Module *M = Call->getModule();
  LibFunc Func;
  if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
    return false;

  if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
    return false;

  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
  // (because NNAN is set or the operand is known not to be less than -0.0),
  // and (3) we would not end up lowering to a libcall anyway (which could
  // change the value of errno), then:
  // (a) errno won't be set, and
  // (b) it is safe to convert this to an intrinsic call.
  Type *Ty = Call->getType();
  Value *Arg = Call->getArgOperand(0);
  if (TTI.haveFastSqrt(Ty) &&
      (Call->hasNoNaNs() ||
       cannotBeOrderedLessThanZero(Arg, M->getDataLayout(), &TLI, 0, &AC, &I,
                                   &DT))) {
    IRBuilder<> Builder(&I);
    IRBuilderBase::FastMathFlagGuard Guard(Builder);
    Builder.setFastMathFlags(Call->getFastMathFlags());

    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
    Value *NewSqrt = Builder.CreateCall(Sqrt, Arg, "sqrt");
    I.replaceAllUsesWith(NewSqrt);

    // Explicitly erase the old call because a call with side effects is not
    // trivially dead.
    I.eraseFromParent();
    return true;
  }

  return false;
}

// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
                        uint64_t Shift, uint64_t InputBits) {
  unsigned Length = Table.getNumElements();
  if (Length < InputBits || Length > InputBits * 2)
    return false;

  APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
  unsigned Matched = 0;

  for (unsigned i = 0; i < Length; i++) {
    uint64_t Element = Table.getElementAsInteger(i);
    if (Element >= InputBits)
      continue;

    // Check if \p Element matches a concrete answer. It could fail for some
    // elements that are never accessed, so we keep iterating over each element
    // from the table. The number of matched elements should be equal to the
    // number of potential right answers, which is exactly \p InputBits.
    if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
      Matched++;
  }

  return Matched == InputBits;
}

// Try to recognize a table-based ctz implementation.
// E.g., this C function (see the LLVM tests for more cases):
// int f(unsigned x) {
//   static const char table[32] =
//     {0, 1, 28, 2, 29, 14, 24, 3, 30,
//      22, 20, 15, 25, 17, 4, 8, 31, 27,
//      13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
//   return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
// }
// This can be lowered to the `cttz` intrinsic.
// There is also a special case when the element is 0.
//
// Here are some examples of LLVM IR for a 64-bit target:
//
// CASE 1:
//   %sub = sub i32 0, %x
//   %and = and i32 %sub, %x
//   %mul = mul i32 %and, 125613361
//   %shr = lshr i32 %mul, 27
//   %idxprom = zext i32 %shr to i64
//   %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table,
//               i64 0, i64 %idxprom
//   %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// CASE 2:
//   %sub = sub i32 0, %x
//   %and = and i32 %sub, %x
//   %mul = mul i32 %and, 72416175
//   %shr = lshr i32 %mul, 26
//   %idxprom = zext i32 %shr to i64
//   %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
//               i64 0, i64 %idxprom
//   %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
//
// CASE 3:
//   %sub = sub i32 0, %x
//   %and = and i32 %sub, %x
//   %mul = mul i32 %and, 81224991
//   %shr = lshr i32 %mul, 27
//   %idxprom = zext i32 %shr to i64
//   %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
//               i64 0, i64 %idxprom
//   %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
//
// CASE 4:
//   %sub = sub i64 0, %x
//   %and = and i64 %sub, %x
//   %mul = mul i64 %and, 283881067100198605
//   %shr = lshr i64 %mul, 58
//   %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table,
//               i64 0, i64 %shr
//   %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
//
// All of this can be lowered to the @llvm.cttz.i32/i64 intrinsic.
static bool tryToRecognizeTableBasedCttz(Instruction &I) {
  LoadInst *LI = dyn_cast<LoadInst>(&I);
  if (!LI)
    return false;

  Type *AccessType = LI->getType();
  if (!AccessType->isIntegerTy())
    return false;

  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
  if (!GEP || !GEP->isInBounds() || GEP->getNumIndices() != 2)
    return false;

  if (!GEP->getSourceElementType()->isArrayTy())
    return false;

  uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
  if (ArraySize != 32 && ArraySize != 64)
    return false;

  GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
  if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
    return false;

  ConstantDataArray *ConstData =
      dyn_cast<ConstantDataArray>(GVTable->getInitializer());
  if (!ConstData)
    return false;

  if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
    return false;

  Value *Idx2 = std::next(GEP->idx_begin())->get();
  Value *X1;
  uint64_t MulConst, ShiftConst;
  // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
  // probably fail for other (e.g. 32-bit) targets.
  if (!match(Idx2, m_ZExtOrSelf(
                       m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
                                    m_ConstantInt(MulConst)),
                              m_ConstantInt(ShiftConst)))))
    return false;

  unsigned InputBits = X1->getType()->getScalarSizeInBits();
  if (InputBits != 32 && InputBits != 64)
    return false;

  // The shift should extract the top 5..7 bits.
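  // (For example: a 32-bit input expects a shift of 27 or 26, and a 64-bit
  //  input expects 58 or 57.)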
  if (InputBits - Log2_32(InputBits) != ShiftConst &&
      InputBits - Log2_32(InputBits) - 1 != ShiftConst)
    return false;

  if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
    return false;

  auto ZeroTableElem = ConstData->getElementAsInteger(0);
  bool DefinedForZero = ZeroTableElem == InputBits;

  IRBuilder<> B(LI);
  ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
  Type *XType = X1->getType();
  auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
  Value *ZExtOrTrunc = nullptr;

  if (DefinedForZero) {
    ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
  } else {
    // If the value in element 0 isn't the same as InputBits, we still want to
    // produce the value from the table.
    auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
    auto Select =
        B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);

    // NOTE: If table[0] is 0 but cttz(0) is defined by the target, it should
    // be handled as `cttz(x) & (typeSize - 1)`.

    ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
  }

  LI->replaceAllUsesWith(ZExtOrTrunc);

  return true;
}

/// This is used by foldLoadsRecursive() to capture the root load node, which
/// has the form or(load, load), and to recursively build the wide load. It
/// also captures the shift amount, the zero-extend type, and the load size.
struct LoadOps {
  LoadInst *Root = nullptr;
  LoadInst *RootInsert = nullptr;
  bool FoundRoot = false;
  uint64_t LoadSize = 0;
  const APInt *Shift = nullptr;
  Type *ZextType;
  AAMDNodes AATags;
};

// Recursively identify and merge consecutive loads of the form
//   (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
//   (ZExt(L1) << shift1) | ZExt(L2)             -> ZExt(L3)
static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
                               AliasAnalysis &AA) {
  const APInt *ShAmt2 = nullptr;
  Value *X;
  Instruction *L1, *L2;

  // Go to the last node with loads.
  if (match(V, m_OneUse(m_c_Or(
                   m_Value(X),
                   m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
                                  m_APInt(ShAmt2)))))) ||
      match(V, m_OneUse(m_Or(m_Value(X),
                             m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
    if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
      // Avoid a partial chain merge.
      return false;
  } else
    return false;

  // Check if the pattern has loads.
  LoadInst *LI1 = LOps.Root;
  const APInt *ShAmt1 = LOps.Shift;
  if (LOps.FoundRoot == false &&
      (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
       match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
                               m_APInt(ShAmt1)))))) {
    LI1 = dyn_cast<LoadInst>(L1);
  }
  LoadInst *LI2 = dyn_cast<LoadInst>(L2);

  // Bail out if the loads are identical, not simple (atomic/volatile), or in
  // different address spaces.
  if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
      LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
    return false;

  // Check that the loads come from the same basic block.
  if (LI1->getParent() != LI2->getParent())
    return false;

  // Query the data layout for endianness.
  bool IsBigEndian = DL.isBigEndian();

  // Check if the loads are consecutive and of the same size.
  Value *Load1Ptr = LI1->getPointerOperand();
  APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
  Load1Ptr =
      Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
                                                  /* AllowNonInbounds */ true);

  Value *Load2Ptr = LI2->getPointerOperand();
  APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
  Load2Ptr =
      Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
                                                  /* AllowNonInbounds */ true);

  // Verify that both loads have the same base pointer and the same load size.
  uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
  uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
  if (Load1Ptr != Load2Ptr || LoadSize1 != LoadSize2)
    return false;

  // Only support load sizes of at least 8 bits that are powers of 2.
  if (LoadSize1 < 8 || !isPowerOf2_64(LoadSize1))
    return false;

  // Use alias analysis to check for stores between the loads.
  LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
  MemoryLocation Loc;
  if (!Start->comesBefore(End)) {
    std::swap(Start, End);
    Loc = MemoryLocation::get(End);
    if (LOps.FoundRoot)
      Loc = Loc.getWithNewSize(LOps.LoadSize);
  } else
    Loc = MemoryLocation::get(End);
  unsigned NumScanned = 0;
  for (Instruction &Inst :
       make_range(Start->getIterator(), End->getIterator())) {
    if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
      return false;
    if (++NumScanned > MaxInstrsToScan)
      return false;
  }

  // Make sure the load with the lower offset is LI1.
  bool Reverse = false;
  if (Offset2.slt(Offset1)) {
    std::swap(LI1, LI2);
    std::swap(ShAmt1, ShAmt2);
    std::swap(Offset1, Offset2);
    std::swap(Load1Ptr, Load2Ptr);
    std::swap(LoadSize1, LoadSize2);
    Reverse = true;
  }

  // For big-endian targets, swap the shifts.
  if (IsBigEndian)
    std::swap(ShAmt1, ShAmt2);

  // Extract the shift amounts.
  uint64_t Shift1 = 0, Shift2 = 0;
  if (ShAmt1)
    Shift1 = ShAmt1->getZExtValue();
  if (ShAmt2)
    Shift2 = ShAmt2->getZExtValue();

  // The first load is always LI1. This is where we put the new load.
  // Use the merged load size available from LI1 for forward loads.
  if (LOps.FoundRoot) {
    if (!Reverse)
      LoadSize1 = LOps.LoadSize;
    else
      LoadSize2 = LOps.LoadSize;
  }

  // Verify that the shift amounts line up with the load offsets, i.e. that the
  // loads are consecutive.
  uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
  uint64_t PrevSize =
      DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
  if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
    return false;

  // Update LOps.
  AAMDNodes AATags1 = LOps.AATags;
  AAMDNodes AATags2 = LI2->getAAMetadata();
  if (LOps.FoundRoot == false) {
    LOps.FoundRoot = true;
    AATags1 = LI1->getAAMetadata();
  }
  LOps.LoadSize = LoadSize1 + LoadSize2;
  LOps.RootInsert = Start;

  // Concatenate the AATags of the merged loads.
  LOps.AATags = AATags1.concat(AATags2);

  LOps.Root = LI1;
  LOps.Shift = ShAmt1;
  LOps.ZextType = X->getType();
  return true;
}

// For a given instruction in a BB, evaluate all loads in the chain that form a
// pattern which suggests that the loads can be combined. The one and only use
// of the loads is to form a wider load.
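//
// A minimal little-endian sketch (names and offsets are illustrative, not
// taken from the tests):
//   %b0 = load i8, ptr %p
//   %p1 = getelementptr inbounds i8, ptr %p, i64 1
//   %b1 = load i8, ptr %p1
//   %z0 = zext i8 %b0 to i16
//   %z1 = zext i8 %b1 to i16
//   %s1 = shl i16 %z1, 8
//   %or = or i16 %z0, %s1
// becomes a single `load i16, ptr %p` (subject to the TTI legality and
// misaligned-access checks below).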
static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
                                 TargetTransformInfo &TTI, AliasAnalysis &AA,
                                 const DominatorTree &DT) {
  // Only consider load chains of scalar values.
  if (isa<VectorType>(I.getType()))
    return false;

  LoadOps LOps;
  if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
    return false;

  IRBuilder<> Builder(&I);
  LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;

  IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
  // TTI-based check whether we want to proceed with the wider load.
  bool Allowed = TTI.isTypeLegal(WiderType);
  if (!Allowed)
    return false;

  unsigned AS = LI1->getPointerAddressSpace();
  unsigned Fast = 0;
  Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
                                               AS, LI1->getAlign(), &Fast);
  if (!Allowed || !Fast)
    return false;

  // Get the index and pointer for the new GEP.
  Value *Load1Ptr = LI1->getPointerOperand();
  Builder.SetInsertPoint(LOps.RootInsert);
  if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
    APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
    Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
        DL, Offset1, /* AllowNonInbounds */ true);
    Load1Ptr = Builder.CreateGEP(Builder.getInt8Ty(), Load1Ptr,
                                 Builder.getInt32(Offset1.getZExtValue()));
  }
  // Generate the wider load.
  NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
                                      LI1->isVolatile(), "");
  NewLoad->takeName(LI1);
  // Set the new load's AATags metadata.
  if (LOps.AATags)
    NewLoad->setAAMetadata(LOps.AATags);

  Value *NewOp = NewLoad;
  // Check if a zero extend is needed.
  if (LOps.ZextType)
    NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);

  // Check if a shift is needed. We shift by LI1's shift amount, if it is not
  // zero.
  if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp,
                              ConstantInt::get(I.getContext(), *LOps.Shift));
  I.replaceAllUsesWith(NewOp);

  return true;
}

// Calculate the GEP stride and the accumulated constant ModOffset. Return the
// Stride and ModOffset pair.
static std::pair<APInt, APInt>
getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  std::optional<APInt> Stride;
  APInt ModOffset(BW, 0);
  // Return a minimum GEP stride: the greatest common divisor of consecutive
  // GEP index scales (cf. Bézout's identity).
  while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
    MapVector<Value *, APInt> VarOffsets;
    if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
      break;

    for (auto [V, Scale] : VarOffsets) {
      // Only keep a power-of-two factor for non-inbounds GEPs.
      if (!GEP->isInBounds())
        Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());

      if (!Stride)
        Stride = Scale;
      else
        Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
    }

    PtrOp = GEP->getPointerOperand();
  }

  // Check whether the pointer reaches back to a global variable via at least
  // one GEP. Even if it doesn't, the caller can still reason about alignment.
  if (!isa<GlobalVariable>(PtrOp) || !Stride)
    return {APInt(BW, 1), APInt(BW, 0)};

  // Because GEP indices may be signed (and hence negative), normalize the
  // offset to the non-negative remainder of division by the minimum GEP
  // stride.
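  // (For example, with hypothetical values Stride = 4 and ModOffset = -1,
  //  srem yields -1, which the adjustment below turns into 3.)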
  ModOffset = ModOffset.srem(*Stride);
  if (ModOffset.isNegative())
    ModOffset += *Stride;

  return {*Stride, ModOffset};
}

/// If C is a constant patterned array and all valid loaded results for the
/// given alignment are equal to the same constant, fold the load to that
/// constant.
static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
  auto *LI = dyn_cast<LoadInst>(&I);
  if (!LI || LI->isVolatile())
    return false;

  // We can only fold the load if it is from a constant global with a
  // definitive initializer. Skip expensive logic if this is not the case.
  auto *PtrOp = LI->getPointerOperand();
  auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
    return false;

  // Bail for large initializers in excess of 4K to avoid too many scans.
  Constant *C = GV->getInitializer();
  uint64_t GVSize = DL.getTypeAllocSize(C->getType());
  if (!GVSize || 4096 < GVSize)
    return false;

  Type *LoadTy = LI->getType();
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);

  // Any possible offset is a multiple of the GEP stride, and any valid offset
  // is a multiple of the load alignment, so checking only multiples of the
  // bigger of the two is sufficient to establish that all results are equal.
  if (auto LA = LI->getAlign();
      LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
    ConstOffset = APInt(BW, 0);
    Stride = APInt(BW, LA.value());
  }

  Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
  if (!Ca)
    return false;

  unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
  for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
    if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
      return false;

  I.replaceAllUsesWith(Ca);

  return true;
}

/// This is the entry point for folds that could be implemented in regular
/// InstCombine, but they are separated because they are not expected to
/// occur frequently and/or have more than a constant-length pattern match.
static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
                                TargetTransformInfo &TTI,
                                TargetLibraryInfo &TLI, AliasAnalysis &AA,
                                AssumptionCache &AC) {
  bool MadeChange = false;
  for (BasicBlock &BB : F) {
    // Ignore unreachable basic blocks.
    if (!DT.isReachableFromEntry(&BB))
      continue;

    const DataLayout &DL = F.getParent()->getDataLayout();

    // Walk the block backwards for efficiency. We're matching a chain of
    // use->defs, so we're more likely to succeed by starting from the bottom.
    // Also, we want to avoid matching partial patterns.
    // TODO: It would be more efficient if we removed dead instructions
    // iteratively in this loop rather than waiting until the end.
    for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
      MadeChange |= foldAnyOrAllBitsSet(I);
      MadeChange |= foldGuardedFunnelShift(I, DT);
      MadeChange |= tryToRecognizePopCount(I);
      MadeChange |= tryToFPToSat(I, TTI);
      MadeChange |= tryToRecognizeTableBasedCttz(I);
      MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
      MadeChange |= foldPatternedLoads(I, DL);
      // NOTE: This fold may erase the instruction `I`, so it must remain the
      // last call in this sequence to avoid introducing bugs.
      MadeChange |= foldSqrt(I, TTI, TLI, AC, DT);
    }
  }

  // We're done with transforms, so remove dead instructions.
  if (MadeChange)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);

  return MadeChange;
}

/// This is the entry point for all transforms. Pass manager differences are
/// handled in the callers of this function.
static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
                    TargetLibraryInfo &TLI, DominatorTree &DT,
                    AliasAnalysis &AA) {
  bool MadeChange = false;
  const DataLayout &DL = F.getParent()->getDataLayout();
  TruncInstCombine TIC(AC, TLI, DL, DT);
  MadeChange |= TIC.run(F);
  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC);
  return MadeChange;
}

PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  auto &AC = AM.getResult<AssumptionAnalysis>(F);
  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto &AA = AM.getResult<AAManager>(F);
  if (!runImpl(F, AC, TTI, TLI, DT, AA)) {
    // No changes, all analyses are preserved.
    return PreservedAnalyses::all();
  }
  // Mark all the analyses that instcombine updates as preserved.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}