1 //===- AggressiveInstCombine.cpp ------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the aggressive expression pattern combiner classes. 10 // Currently, it handles expression patterns for: 11 // * Truncate instruction 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h" 16 #include "AggressiveInstCombineInternal.h" 17 #include "llvm/ADT/Statistic.h" 18 #include "llvm/Analysis/AliasAnalysis.h" 19 #include "llvm/Analysis/AssumptionCache.h" 20 #include "llvm/Analysis/BasicAliasAnalysis.h" 21 #include "llvm/Analysis/ConstantFolding.h" 22 #include "llvm/Analysis/DomTreeUpdater.h" 23 #include "llvm/Analysis/GlobalsModRef.h" 24 #include "llvm/Analysis/TargetLibraryInfo.h" 25 #include "llvm/Analysis/TargetTransformInfo.h" 26 #include "llvm/Analysis/ValueTracking.h" 27 #include "llvm/IR/DataLayout.h" 28 #include "llvm/IR/Dominators.h" 29 #include "llvm/IR/Function.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/PatternMatch.h" 32 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 33 #include "llvm/Transforms/Utils/BuildLibCalls.h" 34 #include "llvm/Transforms/Utils/Local.h" 35 36 using namespace llvm; 37 using namespace PatternMatch; 38 39 #define DEBUG_TYPE "aggressive-instcombine" 40 41 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded"); 42 STATISTIC(NumGuardedRotates, 43 "Number of guarded rotates transformed into funnel shifts"); 44 STATISTIC(NumGuardedFunnelShifts, 45 "Number of guarded funnel shifts transformed into funnel shifts"); 46 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized"); 47 48 static cl::opt<unsigned> MaxInstrsToScan( 49 "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden, 50 cl::desc("Max number of instructions to scan for aggressive instcombine.")); 51 52 static cl::opt<unsigned> StrNCmpInlineThreshold( 53 "strncmp-inline-threshold", cl::init(3), cl::Hidden, 54 cl::desc("The maximum length of a constant string for a builtin string cmp " 55 "call eligible for inlining. The default value is 3.")); 56 57 static cl::opt<unsigned> 58 MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden, 59 cl::desc("The maximum length of a constant string to " 60 "inline a memchr call.")); 61 62 /// Match a pattern for a bitwise funnel/rotate operation that partially guards 63 /// against undefined behavior by branching around the funnel-shift/rotation 64 /// when the shift amount is 0. 65 static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) { 66 if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2) 67 return false; 68 69 // As with the one-use checks below, this is not strictly necessary, but we 70 // are being cautious to avoid potential perf regressions on targets that 71 // do not actually have a funnel/rotate instruction (where the funnel shift 72 // would be expanded back into math/shift/logic ops). 73 if (!isPowerOf2_32(I.getType()->getScalarSizeInBits())) 74 return false; 75 76 // Match V to funnel shift left/right and capture the source operands and 77 // shift amount. 
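  // For reference only (an illustrative sketch, not code from this file): the
  // guarded pattern typically comes from C source such as
  //   unsigned rotl32(unsigned x, unsigned n) {
  //     return n == 0 ? x : (x << n) | (x >> (32 - n));
  //   }
  // which, when lowered with a branch rather than a select, produces the
  // cmp/branch/phi shape recognized here.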
78 auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1, 79 Value *&ShAmt) { 80 unsigned Width = V->getType()->getScalarSizeInBits(); 81 82 // fshl(ShVal0, ShVal1, ShAmt) 83 // == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt)) 84 if (match(V, m_OneUse(m_c_Or( 85 m_Shl(m_Value(ShVal0), m_Value(ShAmt)), 86 m_LShr(m_Value(ShVal1), 87 m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) { 88 return Intrinsic::fshl; 89 } 90 91 // fshr(ShVal0, ShVal1, ShAmt) 92 // == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt)) 93 if (match(V, 94 m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width), 95 m_Value(ShAmt))), 96 m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) { 97 return Intrinsic::fshr; 98 } 99 100 return Intrinsic::not_intrinsic; 101 }; 102 103 // One phi operand must be a funnel/rotate operation, and the other phi 104 // operand must be the source value of that funnel/rotate operation: 105 // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ] 106 // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ] 107 // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ] 108 PHINode &Phi = cast<PHINode>(I); 109 unsigned FunnelOp = 0, GuardOp = 1; 110 Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1); 111 Value *ShVal0, *ShVal1, *ShAmt; 112 Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt); 113 if (IID == Intrinsic::not_intrinsic || 114 (IID == Intrinsic::fshl && ShVal0 != P1) || 115 (IID == Intrinsic::fshr && ShVal1 != P1)) { 116 IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt); 117 if (IID == Intrinsic::not_intrinsic || 118 (IID == Intrinsic::fshl && ShVal0 != P0) || 119 (IID == Intrinsic::fshr && ShVal1 != P0)) 120 return false; 121 assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) && 122 "Pattern must match funnel shift left or right"); 123 std::swap(FunnelOp, GuardOp); 124 } 125 126 // The incoming block with our source operand must be the "guard" block. 127 // That must contain a cmp+branch to avoid the funnel/rotate when the shift 128 // amount is equal to 0. The other incoming block is the block with the 129 // funnel/rotate. 130 BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp); 131 BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp); 132 Instruction *TermI = GuardBB->getTerminator(); 133 134 // Ensure that the shift values dominate each block. 135 if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI)) 136 return false; 137 138 BasicBlock *PhiBB = Phi.getParent(); 139 if (!match(TermI, m_Br(m_SpecificICmp(CmpInst::ICMP_EQ, m_Specific(ShAmt), 140 m_ZeroInt()), 141 m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB)))) 142 return false; 143 144 IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt()); 145 146 if (ShVal0 == ShVal1) 147 ++NumGuardedRotates; 148 else 149 ++NumGuardedFunnelShifts; 150 151 // If this is not a rotate then the select was blocking poison from the 152 // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it. 
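  // Illustrative note: the original phi only reads the non-selected shift
  // value on the path where ShAmt != 0, while the unconditional funnel shift
  // reads both operands on every path, which is why a possibly-poison operand
  // has to be frozen first.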
153 bool IsFshl = IID == Intrinsic::fshl; 154 if (ShVal0 != ShVal1) { 155 if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1)) 156 ShVal1 = Builder.CreateFreeze(ShVal1); 157 else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0)) 158 ShVal0 = Builder.CreateFreeze(ShVal0); 159 } 160 161 // We matched a variation of this IR pattern: 162 // GuardBB: 163 // %cmp = icmp eq i32 %ShAmt, 0 164 // br i1 %cmp, label %PhiBB, label %FunnelBB 165 // FunnelBB: 166 // %sub = sub i32 32, %ShAmt 167 // %shr = lshr i32 %ShVal1, %sub 168 // %shl = shl i32 %ShVal0, %ShAmt 169 // %fsh = or i32 %shr, %shl 170 // br label %PhiBB 171 // PhiBB: 172 // %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ] 173 // --> 174 // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt) 175 Phi.replaceAllUsesWith( 176 Builder.CreateIntrinsic(IID, Phi.getType(), {ShVal0, ShVal1, ShAmt})); 177 return true; 178 } 179 180 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and 181 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain 182 /// of 'and' ops, then we also need to capture the fact that we saw an 183 /// "and X, 1", so that's an extra return value for that case. 184 namespace { 185 struct MaskOps { 186 Value *Root = nullptr; 187 APInt Mask; 188 bool MatchAndChain; 189 bool FoundAnd1 = false; 190 191 MaskOps(unsigned BitWidth, bool MatchAnds) 192 : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {} 193 }; 194 } // namespace 195 196 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a 197 /// chain of 'and' or 'or' instructions looking for shift ops of a common source 198 /// value. Examples: 199 /// or (or (or X, (X >> 3)), (X >> 5)), (X >> 8) 200 /// returns { X, 0x129 } 201 /// and (and (X >> 1), 1), (X >> 4) 202 /// returns { X, 0x12 } 203 static bool matchAndOrChain(Value *V, MaskOps &MOps) { 204 Value *Op0, *Op1; 205 if (MOps.MatchAndChain) { 206 // Recurse through a chain of 'and' operands. This requires an extra check 207 // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere 208 // in the chain to know that all of the high bits are cleared. 209 if (match(V, m_And(m_Value(Op0), m_One()))) { 210 MOps.FoundAnd1 = true; 211 return matchAndOrChain(Op0, MOps); 212 } 213 if (match(V, m_And(m_Value(Op0), m_Value(Op1)))) 214 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); 215 } else { 216 // Recurse through a chain of 'or' operands. 217 if (match(V, m_Or(m_Value(Op0), m_Value(Op1)))) 218 return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps); 219 } 220 221 // We need a shift-right or a bare value representing a compare of bit 0 of 222 // the original source operand. 223 Value *Candidate; 224 const APInt *BitIndex = nullptr; 225 if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex)))) 226 Candidate = V; 227 228 // Initialize result source operand. 229 if (!MOps.Root) 230 MOps.Root = Candidate; 231 232 // The shift constant is out-of-range? This code hasn't been simplified. 233 if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth())) 234 return false; 235 236 // Fill in the mask bit derived from the shift constant. 237 MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0); 238 return MOps.Root == Candidate; 239 } 240 241 /// Match patterns that correspond to "any-bits-set" and "all-bits-set". 
242 /// These will include a chain of 'or' or 'and'-shifted bits from a 243 /// common source value: 244 /// and (or (lshr X, C), ...), 1 --> (X & CMask) != 0 245 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask 246 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns 247 /// that differ only with a final 'not' of the result. We expect that final 248 /// 'not' to be folded with the compare that we create here (invert predicate). 249 static bool foldAnyOrAllBitsSet(Instruction &I) { 250 // The 'any-bits-set' ('or' chain) pattern is simpler to match because the 251 // final "and X, 1" instruction must be the final op in the sequence. 252 bool MatchAllBitsSet; 253 if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value()))) 254 MatchAllBitsSet = true; 255 else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One()))) 256 MatchAllBitsSet = false; 257 else 258 return false; 259 260 MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet); 261 if (MatchAllBitsSet) { 262 if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1) 263 return false; 264 } else { 265 if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps)) 266 return false; 267 } 268 269 // The pattern was found. Create a masked compare that replaces all of the 270 // shift and logic ops. 271 IRBuilder<> Builder(&I); 272 Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask); 273 Value *And = Builder.CreateAnd(MOps.Root, Mask); 274 Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask) 275 : Builder.CreateIsNotNull(And); 276 Value *Zext = Builder.CreateZExt(Cmp, I.getType()); 277 I.replaceAllUsesWith(Zext); 278 ++NumAnyOrAllBitsSet; 279 return true; 280 } 281 282 // Try to recognize below function as popcount intrinsic. 283 // This is the "best" algorithm from 284 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel 285 // Also used in TargetLowering::expandCTPOP(). 286 // 287 // int popcount(unsigned int i) { 288 // i = i - ((i >> 1) & 0x55555555); 289 // i = (i & 0x33333333) + ((i >> 2) & 0x33333333); 290 // i = ((i + (i >> 4)) & 0x0F0F0F0F); 291 // return (i * 0x01010101) >> 24; 292 // } 293 static bool tryToRecognizePopCount(Instruction &I) { 294 if (I.getOpcode() != Instruction::LShr) 295 return false; 296 297 Type *Ty = I.getType(); 298 if (!Ty->isIntOrIntVectorTy()) 299 return false; 300 301 unsigned Len = Ty->getScalarSizeInBits(); 302 // FIXME: fix Len == 8 and other irregular type lengths. 303 if (!(Len <= 128 && Len > 8 && Len % 8 == 0)) 304 return false; 305 306 APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55)); 307 APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33)); 308 APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F)); 309 APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01)); 310 APInt MaskShift = APInt(Len, Len - 8); 311 312 Value *Op0 = I.getOperand(0); 313 Value *Op1 = I.getOperand(1); 314 Value *MulOp0; 315 // Matching "(i * 0x01010101...) >> 24". 316 if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) && 317 match(Op1, m_SpecificInt(MaskShift))) { 318 Value *ShiftOp0; 319 // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)". 320 if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)), 321 m_Deferred(ShiftOp0)), 322 m_SpecificInt(Mask0F)))) { 323 Value *AndOp0; 324 // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)". 
      if (match(ShiftOp0,
                m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
                        m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
                              m_SpecificInt(Mask33))))) {
        Value *Root, *SubOp1;
        // Matching "i - ((i >> 1) & 0x55555555...)".
        const APInt *AndMask;
        if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
            match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
                                m_APInt(AndMask)))) {
          auto CheckAndMask = [&]() {
            if (*AndMask == Mask55)
              return true;

            // Exact match failed, see if any bits are known to be 0 where we
            // expect a 1 in the mask.
            if (!AndMask->isSubsetOf(Mask55))
              return false;

            APInt NeededMask = Mask55 & ~*AndMask;
            return MaskedValueIsZero(cast<Instruction>(SubOp1)->getOperand(0),
                                     NeededMask,
                                     SimplifyQuery(I.getDataLayout()));
          };

          if (CheckAndMask()) {
            LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
            IRBuilder<> Builder(&I);
            I.replaceAllUsesWith(
                Builder.CreateIntrinsic(Intrinsic::ctpop, I.getType(), {Root}));
            ++NumPopCountRecognized;
            return true;
          }
        }
      }
    }
  }

  return false;
}

/// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), provided C1 and
/// C2 saturate the value of the fp conversion. The transform is not reversible
/// because fptosi.sat is more defined than the input: all inputs produce a
/// valid value for fptosi.sat, whereas inputs that were out of range of the
/// integer conversion produce poison in the original. The reversed pattern may
/// use fmax and fmin instead. As we cannot directly reverse the transform, and
/// it is not always profitable, we make it conditional on the cost being
/// reported as lower by TTI.
static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
  // Look for smax(smin(fptosi(x), MinC), MaxC) or the equivalent
  // smin(smax(fptosi(x), MaxC), MinC) form, converting to fptosi_sat.
  Value *In;
  const APInt *MinC, *MaxC;
  if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MinC))),
                        m_APInt(MaxC))) &&
      !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
                                        m_APInt(MaxC))),
                        m_APInt(MinC))))
    return false;

  // Check that the constants clamp a saturate.
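  // Worked example (illustrative, not from a test): clamping to the i8 range
  // with MinC = 127 and MaxC = -128 gives MinC + 1 == 128 (a power of two) and
  // -MaxC == MinC + 1, so the SatTy computed below is i8 and the min/max pair
  // performs exactly the saturation of llvm.fptosi.sat.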
  if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
    return false;

  Type *IntTy = I.getType();
  Type *FpTy = In->getType();
  Type *SatTy =
      IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
  if (auto *VecTy = dyn_cast<VectorType>(IntTy))
    SatTy = VectorType::get(SatTy, VecTy->getElementCount());

  // Get the cost of the intrinsic, and check that against the cost of
  // fptosi+smin+smax.
  InstructionCost SatCost = TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
      TTI::TCK_RecipThroughput);
  SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
                                  TTI::CastContextHint::None,
                                  TTI::TCK_RecipThroughput);

  InstructionCost MinMaxCost = TTI.getCastInstrCost(
      Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);
  MinMaxCost += TTI.getIntrinsicInstrCost(
      IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
      TTI::TCK_RecipThroughput);

  if (SatCost >= MinMaxCost)
    return false;

  IRBuilder<> Builder(&I);
  Value *Sat =
      Builder.CreateIntrinsic(Intrinsic::fptosi_sat, {SatTy, FpTy}, In);
  I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
  return true;
}

/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
/// pessimistic codegen that has to account for setting errno and can enable
/// vectorization.
static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
                     TargetLibraryInfo &TLI, AssumptionCache &AC,
                     DominatorTree &DT) {
  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
  // (because NNAN or the operand arg must not be less than -0.0), and (3) we
  // would not end up lowering to a libcall anyway (which could change the
  // value of errno), then:
  // (a) errno won't be set, and
  // (b) it is safe to convert this to an intrinsic call.
  Type *Ty = Call->getType();
  Value *Arg = Call->getArgOperand(0);
  if (TTI.haveFastSqrt(Ty) &&
      (Call->hasNoNaNs() ||
       cannotBeOrderedLessThanZero(
           Arg, SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
    IRBuilder<> Builder(Call);
    Value *NewSqrt =
        Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt");
    Call->replaceAllUsesWith(NewSqrt);

    // Explicitly erase the old call because a call with side effects is not
    // trivially dead.
    Call->eraseFromParent();
    return true;
  }

  return false;
}

// Check if this array of constants represents a cttz table.
// Iterate over the elements from \p Table by trying to find/match all
// the numbers from 0 to \p InputBits that should represent cttz results.
static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
                        uint64_t Shift, uint64_t InputBits) {
  unsigned Length = Table.getNumElements();
  if (Length < InputBits || Length > InputBits * 2)
    return false;

  APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
  unsigned Matched = 0;

  for (unsigned i = 0; i < Length; i++) {
    uint64_t Element = Table.getElementAsInteger(i);
    if (Element >= InputBits)
      continue;

    // Check if \p Element matches a concrete answer.
It could fail for some 476 // elements that are never accessed, so we keep iterating over each element 477 // from the table. The number of matched elements should be equal to the 478 // number of potential right answers which is \p InputBits actually. 479 if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i) 480 Matched++; 481 } 482 483 return Matched == InputBits; 484 } 485 486 // Try to recognize table-based ctz implementation. 487 // E.g., an example in C (for more cases please see the llvm/tests): 488 // int f(unsigned x) { 489 // static const char table[32] = 490 // {0, 1, 28, 2, 29, 14, 24, 3, 30, 491 // 22, 20, 15, 25, 17, 4, 8, 31, 27, 492 // 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9}; 493 // return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27]; 494 // } 495 // this can be lowered to `cttz` instruction. 496 // There is also a special case when the element is 0. 497 // 498 // Here are some examples or LLVM IR for a 64-bit target: 499 // 500 // CASE 1: 501 // %sub = sub i32 0, %x 502 // %and = and i32 %sub, %x 503 // %mul = mul i32 %and, 125613361 504 // %shr = lshr i32 %mul, 27 505 // %idxprom = zext i32 %shr to i64 506 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0, 507 // i64 %idxprom 508 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 509 // 510 // CASE 2: 511 // %sub = sub i32 0, %x 512 // %and = and i32 %sub, %x 513 // %mul = mul i32 %and, 72416175 514 // %shr = lshr i32 %mul, 26 515 // %idxprom = zext i32 %shr to i64 516 // %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table, 517 // i64 0, i64 %idxprom 518 // %0 = load i16, i16* %arrayidx, align 2, !tbaa !8 519 // 520 // CASE 3: 521 // %sub = sub i32 0, %x 522 // %and = and i32 %sub, %x 523 // %mul = mul i32 %and, 81224991 524 // %shr = lshr i32 %mul, 27 525 // %idxprom = zext i32 %shr to i64 526 // %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table, 527 // i64 0, i64 %idxprom 528 // %0 = load i32, i32* %arrayidx, align 4, !tbaa !8 529 // 530 // CASE 4: 531 // %sub = sub i64 0, %x 532 // %and = and i64 %sub, %x 533 // %mul = mul i64 %and, 283881067100198605 534 // %shr = lshr i64 %mul, 58 535 // %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0, 536 // i64 %shr 537 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8 538 // 539 // All this can be lowered to @llvm.cttz.i32/64 intrinsic. 
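// Informal note (not required by the matcher): the multipliers above are
// de Bruijn-style constants. In CASE 1, 125613361 is 0x077CB531, the constant
// from the C example: (x & -x) isolates the lowest set bit, the multiply
// places a distinct 5-bit code for each bit position into the top bits, and
// the lshr by 27 extracts that code as the table index.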
540 static bool tryToRecognizeTableBasedCttz(Instruction &I) { 541 LoadInst *LI = dyn_cast<LoadInst>(&I); 542 if (!LI) 543 return false; 544 545 Type *AccessType = LI->getType(); 546 if (!AccessType->isIntegerTy()) 547 return false; 548 549 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand()); 550 if (!GEP || !GEP->hasNoUnsignedSignedWrap() || GEP->getNumIndices() != 2) 551 return false; 552 553 if (!GEP->getSourceElementType()->isArrayTy()) 554 return false; 555 556 uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements(); 557 if (ArraySize != 32 && ArraySize != 64) 558 return false; 559 560 GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand()); 561 if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant()) 562 return false; 563 564 ConstantDataArray *ConstData = 565 dyn_cast<ConstantDataArray>(GVTable->getInitializer()); 566 if (!ConstData) 567 return false; 568 569 if (!match(GEP->idx_begin()->get(), m_ZeroInt())) 570 return false; 571 572 Value *Idx2 = std::next(GEP->idx_begin())->get(); 573 Value *X1; 574 uint64_t MulConst, ShiftConst; 575 // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will 576 // probably fail for other (e.g. 32-bit) targets. 577 if (!match(Idx2, m_ZExtOrSelf( 578 m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)), 579 m_ConstantInt(MulConst)), 580 m_ConstantInt(ShiftConst))))) 581 return false; 582 583 unsigned InputBits = X1->getType()->getScalarSizeInBits(); 584 if (InputBits != 32 && InputBits != 64) 585 return false; 586 587 // Shift should extract top 5..7 bits. 588 if (InputBits - Log2_32(InputBits) != ShiftConst && 589 InputBits - Log2_32(InputBits) - 1 != ShiftConst) 590 return false; 591 592 if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits)) 593 return false; 594 595 auto ZeroTableElem = ConstData->getElementAsInteger(0); 596 bool DefinedForZero = ZeroTableElem == InputBits; 597 598 IRBuilder<> B(LI); 599 ConstantInt *BoolConst = B.getInt1(!DefinedForZero); 600 Type *XType = X1->getType(); 601 auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst}); 602 Value *ZExtOrTrunc = nullptr; 603 604 if (DefinedForZero) { 605 ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType); 606 } else { 607 // If the value in elem 0 isn't the same as InputBits, we still want to 608 // produce the value from the table. 609 auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0)); 610 auto Select = 611 B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz); 612 613 // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target 614 // it should be handled as: `cttz(x) & (typeSize - 1)`. 615 616 ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType); 617 } 618 619 LI->replaceAllUsesWith(ZExtOrTrunc); 620 621 return true; 622 } 623 624 /// This is used by foldLoadsRecursive() to capture a Root Load node which is 625 /// of type or(load, load) and recursively build the wide load. Also capture the 626 /// shift amount, zero extend type and loadSize. 
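/// For example (illustrative IR, not taken from the tests), on a little-endian
/// target the chain
///   %b0 = load i8, ptr %p
///   %p1 = getelementptr i8, ptr %p, i64 1
///   %b1 = load i8, ptr %p1
///   %z0 = zext i8 %b0 to i16
///   %z1 = zext i8 %b1 to i16
///   %hi = shl i16 %z1, 8
///   %or = or i16 %z0, %hi
/// can be replaced by a single "load i16, ptr %p".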
627 struct LoadOps { 628 LoadInst *Root = nullptr; 629 LoadInst *RootInsert = nullptr; 630 bool FoundRoot = false; 631 uint64_t LoadSize = 0; 632 const APInt *Shift = nullptr; 633 Type *ZextType; 634 AAMDNodes AATags; 635 }; 636 637 // Identify and Merge consecutive loads recursively which is of the form 638 // (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1 639 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3) 640 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL, 641 AliasAnalysis &AA) { 642 const APInt *ShAmt2 = nullptr; 643 Value *X; 644 Instruction *L1, *L2; 645 646 // Go to the last node with loads. 647 if (match(V, m_OneUse(m_c_Or( 648 m_Value(X), 649 m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))), 650 m_APInt(ShAmt2)))))) || 651 match(V, m_OneUse(m_Or(m_Value(X), 652 m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) { 653 if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot) 654 // Avoid Partial chain merge. 655 return false; 656 } else 657 return false; 658 659 // Check if the pattern has loads 660 LoadInst *LI1 = LOps.Root; 661 const APInt *ShAmt1 = LOps.Shift; 662 if (LOps.FoundRoot == false && 663 (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) || 664 match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))), 665 m_APInt(ShAmt1)))))) { 666 LI1 = dyn_cast<LoadInst>(L1); 667 } 668 LoadInst *LI2 = dyn_cast<LoadInst>(L2); 669 670 // Check if loads are same, atomic, volatile and having same address space. 671 if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() || 672 LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace()) 673 return false; 674 675 // Check if Loads come from same BB. 676 if (LI1->getParent() != LI2->getParent()) 677 return false; 678 679 // Find the data layout 680 bool IsBigEndian = DL.isBigEndian(); 681 682 // Check if loads are consecutive and same size. 683 Value *Load1Ptr = LI1->getPointerOperand(); 684 APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0); 685 Load1Ptr = 686 Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1, 687 /* AllowNonInbounds */ true); 688 689 Value *Load2Ptr = LI2->getPointerOperand(); 690 APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0); 691 Load2Ptr = 692 Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2, 693 /* AllowNonInbounds */ true); 694 695 // Verify if both loads have same base pointers 696 uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits(); 697 uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits(); 698 if (Load1Ptr != Load2Ptr) 699 return false; 700 701 // Make sure that there are no padding bits. 702 if (!DL.typeSizeEqualsStoreSize(LI1->getType()) || 703 !DL.typeSizeEqualsStoreSize(LI2->getType())) 704 return false; 705 706 // Alias Analysis to check for stores b/w the loads. 707 LoadInst *Start = LOps.FoundRoot ? 
LOps.RootInsert : LI1, *End = LI2; 708 MemoryLocation Loc; 709 if (!Start->comesBefore(End)) { 710 std::swap(Start, End); 711 Loc = MemoryLocation::get(End); 712 if (LOps.FoundRoot) 713 Loc = Loc.getWithNewSize(LOps.LoadSize); 714 } else 715 Loc = MemoryLocation::get(End); 716 unsigned NumScanned = 0; 717 for (Instruction &Inst : 718 make_range(Start->getIterator(), End->getIterator())) { 719 if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc))) 720 return false; 721 722 if (++NumScanned > MaxInstrsToScan) 723 return false; 724 } 725 726 // Make sure Load with lower Offset is at LI1 727 bool Reverse = false; 728 if (Offset2.slt(Offset1)) { 729 std::swap(LI1, LI2); 730 std::swap(ShAmt1, ShAmt2); 731 std::swap(Offset1, Offset2); 732 std::swap(Load1Ptr, Load2Ptr); 733 std::swap(LoadSize1, LoadSize2); 734 Reverse = true; 735 } 736 737 // Big endian swap the shifts 738 if (IsBigEndian) 739 std::swap(ShAmt1, ShAmt2); 740 741 // Find Shifts values. 742 uint64_t Shift1 = 0, Shift2 = 0; 743 if (ShAmt1) 744 Shift1 = ShAmt1->getZExtValue(); 745 if (ShAmt2) 746 Shift2 = ShAmt2->getZExtValue(); 747 748 // First load is always LI1. This is where we put the new load. 749 // Use the merged load size available from LI1 for forward loads. 750 if (LOps.FoundRoot) { 751 if (!Reverse) 752 LoadSize1 = LOps.LoadSize; 753 else 754 LoadSize2 = LOps.LoadSize; 755 } 756 757 // Verify if shift amount and load index aligns and verifies that loads 758 // are consecutive. 759 uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1; 760 uint64_t PrevSize = 761 DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1)); 762 if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize) 763 return false; 764 765 // Update LOps 766 AAMDNodes AATags1 = LOps.AATags; 767 AAMDNodes AATags2 = LI2->getAAMetadata(); 768 if (LOps.FoundRoot == false) { 769 LOps.FoundRoot = true; 770 AATags1 = LI1->getAAMetadata(); 771 } 772 LOps.LoadSize = LoadSize1 + LoadSize2; 773 LOps.RootInsert = Start; 774 775 // Concatenate the AATags of the Merged Loads. 776 LOps.AATags = AATags1.concat(AATags2); 777 778 LOps.Root = LI1; 779 LOps.Shift = ShAmt1; 780 LOps.ZextType = X->getType(); 781 return true; 782 } 783 784 // For a given BB instruction, evaluate all loads in the chain that form a 785 // pattern which suggests that the loads can be combined. The one and only use 786 // of the loads is to form a wider load. 787 static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL, 788 TargetTransformInfo &TTI, AliasAnalysis &AA, 789 const DominatorTree &DT) { 790 // Only consider load chains of scalar values. 791 if (isa<VectorType>(I.getType())) 792 return false; 793 794 LoadOps LOps; 795 if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot) 796 return false; 797 798 IRBuilder<> Builder(&I); 799 LoadInst *NewLoad = nullptr, *LI1 = LOps.Root; 800 801 IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize); 802 // TTI based checks if we want to proceed with wider load 803 bool Allowed = TTI.isTypeLegal(WiderType); 804 if (!Allowed) 805 return false; 806 807 unsigned AS = LI1->getPointerAddressSpace(); 808 unsigned Fast = 0; 809 Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize, 810 AS, LI1->getAlign(), &Fast); 811 if (!Allowed || !Fast) 812 return false; 813 814 // Get the Index and Ptr for the new GEP. 
  Value *Load1Ptr = LI1->getPointerOperand();
  Builder.SetInsertPoint(LOps.RootInsert);
  if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
    APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
    Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
        DL, Offset1, /* AllowNonInbounds */ true);
    Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1));
  }
  // Generate wider load.
  NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
                                      LI1->isVolatile(), "");
  NewLoad->takeName(LI1);
  // Set the New Load AATags Metadata.
  if (LOps.AATags)
    NewLoad->setAAMetadata(LOps.AATags);

  Value *NewOp = NewLoad;
  // Check if zero extend needed.
  if (LOps.ZextType)
    NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);

  // Check if a shift is needed. If so, shift by the first load's original
  // shift amount.
  if (LOps.Shift)
    NewOp = Builder.CreateShl(NewOp,
                              ConstantInt::get(I.getContext(), *LOps.Shift));
  I.replaceAllUsesWith(NewOp);

  return true;
}

/// Combine away instructions, provided the result is still equivalent when
/// compared against 0, i.e. provided they do not change whether any bits are
/// set.
static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
    return nullptr;

  Value *A;

  // Look deeper into the chain of or's, combining away shl (so long as they
  // are nuw or nsw).
  Value *Op0 = I->getOperand(0);
  if (match(Op0, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
                             m_NUWShl(m_Value(A), m_Value()))))
    Op0 = A;
  else if (auto *NOp = optimizeShiftInOrChain(Op0, Builder))
    Op0 = NOp;

  Value *Op1 = I->getOperand(1);
  if (match(Op1, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
                             m_NUWShl(m_Value(A), m_Value()))))
    Op1 = A;
  else if (auto *NOp = optimizeShiftInOrChain(Op1, Builder))
    Op1 = NOp;

  if (Op0 != I->getOperand(0) || Op1 != I->getOperand(1))
    return Builder.CreateOr(Op0, Op1);
  return nullptr;
}

static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
                            TargetTransformInfo &TTI, AliasAnalysis &AA,
                            const DominatorTree &DT) {
  CmpPredicate Pred;
  Value *Op0;
  if (!match(&I, m_ICmp(Pred, m_Value(Op0), m_Zero())) ||
      !ICmpInst::isEquality(Pred))
    return false;

  // If the chain of or's matches a load, combine to that before attempting to
  // remove shifts.
  if (auto OpI = dyn_cast<Instruction>(Op0))
    if (OpI->getOpcode() == Instruction::Or)
      if (foldConsecutiveLoads(*OpI, DL, TTI, AA, DT))
        return true;

  IRBuilder<> Builder(&I);
  // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
  if (auto *Res = optimizeShiftInOrChain(Op0, Builder)) {
    I.replaceAllUsesWith(Builder.CreateICmp(Pred, Res, I.getOperand(1)));
    return true;
  }

  return false;
}

// Calculate the GEP stride and the accumulated constant ModOffset. Return
// Stride and ModOffset.
static std::pair<APInt, APInt>
getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
  unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
  std::optional<APInt> Stride;
  APInt ModOffset(BW, 0);
  // Return a minimum GEP stride, the greatest common divisor of consecutive
  // GEP index scales (c.f. Bézout's identity).
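  // Worked example (hypothetical IR): for a pointer built as
  //   %p = getelementptr inbounds [64 x i32], ptr @g, i64 0, i64 %i
  //   %q = getelementptr inbounds i8, ptr %p, i64 2
  // the variable index scale is 4 and the accumulated constant offset is 2,
  // so this returns Stride = 4 and ModOffset = 2: only bytes at offsets
  // 2, 6, 10, ... of @g can be addressed.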
910 while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) { 911 SmallMapVector<Value *, APInt, 4> VarOffsets; 912 if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset)) 913 break; 914 915 for (auto [V, Scale] : VarOffsets) { 916 // Only keep a power of two factor for non-inbounds 917 if (!GEP->hasNoUnsignedSignedWrap()) 918 Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero()); 919 920 if (!Stride) 921 Stride = Scale; 922 else 923 Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale); 924 } 925 926 PtrOp = GEP->getPointerOperand(); 927 } 928 929 // Check whether pointer arrives back at Global Variable via at least one GEP. 930 // Even if it doesn't, we can check by alignment. 931 if (!isa<GlobalVariable>(PtrOp) || !Stride) 932 return {APInt(BW, 1), APInt(BW, 0)}; 933 934 // In consideration of signed GEP indices, non-negligible offset become 935 // remainder of division by minimum GEP stride. 936 ModOffset = ModOffset.srem(*Stride); 937 if (ModOffset.isNegative()) 938 ModOffset += *Stride; 939 940 return {*Stride, ModOffset}; 941 } 942 943 /// If C is a constant patterned array and all valid loaded results for given 944 /// alignment are same to a constant, return that constant. 945 static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) { 946 auto *LI = dyn_cast<LoadInst>(&I); 947 if (!LI || LI->isVolatile()) 948 return false; 949 950 // We can only fold the load if it is from a constant global with definitive 951 // initializer. Skip expensive logic if this is not the case. 952 auto *PtrOp = LI->getPointerOperand(); 953 auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp)); 954 if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) 955 return false; 956 957 // Bail for large initializers in excess of 4K to avoid too many scans. 958 Constant *C = GV->getInitializer(); 959 uint64_t GVSize = DL.getTypeAllocSize(C->getType()); 960 if (!GVSize || 4096 < GVSize) 961 return false; 962 963 Type *LoadTy = LI->getType(); 964 unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType()); 965 auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL); 966 967 // Any possible offset could be multiple of GEP stride. And any valid 968 // offset is multiple of load alignment, so checking only multiples of bigger 969 // one is sufficient to say results' equality. 970 if (auto LA = LI->getAlign(); 971 LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) { 972 ConstOffset = APInt(BW, 0); 973 Stride = APInt(BW, LA.value()); 974 } 975 976 Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL); 977 if (!Ca) 978 return false; 979 980 unsigned E = GVSize - DL.getTypeStoreSize(LoadTy); 981 for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride) 982 if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL)) 983 return false; 984 985 I.replaceAllUsesWith(Ca); 986 987 return true; 988 } 989 990 namespace { 991 class StrNCmpInliner { 992 public: 993 StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU, 994 const DataLayout &DL) 995 : CI(CI), Func(Func), DTU(DTU), DL(DL) {} 996 997 bool optimizeStrNCmp(); 998 999 private: 1000 void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped); 1001 1002 CallInst *CI; 1003 LibFunc Func; 1004 DomTreeUpdater *DTU; 1005 const DataLayout &DL; 1006 }; 1007 1008 } // namespace 1009 1010 /// First we normalize calls to strncmp/strcmp to the form of 1011 /// compare(s1, s2, N), which means comparing first N bytes of s1 and s2 1012 /// (without considering '\0'). 
1013 /// 1014 /// Examples: 1015 /// 1016 /// \code 1017 /// strncmp(s, "a", 3) -> compare(s, "a", 2) 1018 /// strncmp(s, "abc", 3) -> compare(s, "abc", 3) 1019 /// strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2) 1020 /// strcmp(s, "a") -> compare(s, "a", 2) 1021 /// 1022 /// char s2[] = {'a'} 1023 /// strncmp(s, s2, 3) -> compare(s, s2, 3) 1024 /// 1025 /// char s2[] = {'a', 'b', 'c', 'd'} 1026 /// strncmp(s, s2, 3) -> compare(s, s2, 3) 1027 /// \endcode 1028 /// 1029 /// We only handle cases where N and exactly one of s1 and s2 are constant. 1030 /// Cases that s1 and s2 are both constant are already handled by the 1031 /// instcombine pass. 1032 /// 1033 /// We do not handle cases where N > StrNCmpInlineThreshold. 1034 /// 1035 /// We also do not handles cases where N < 2, which are already 1036 /// handled by the instcombine pass. 1037 /// 1038 bool StrNCmpInliner::optimizeStrNCmp() { 1039 if (StrNCmpInlineThreshold < 2) 1040 return false; 1041 1042 if (!isOnlyUsedInZeroComparison(CI)) 1043 return false; 1044 1045 Value *Str1P = CI->getArgOperand(0); 1046 Value *Str2P = CI->getArgOperand(1); 1047 // Should be handled elsewhere. 1048 if (Str1P == Str2P) 1049 return false; 1050 1051 StringRef Str1, Str2; 1052 bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false); 1053 bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false); 1054 if (HasStr1 == HasStr2) 1055 return false; 1056 1057 // Note that '\0' and characters after it are not trimmed. 1058 StringRef Str = HasStr1 ? Str1 : Str2; 1059 Value *StrP = HasStr1 ? Str2P : Str1P; 1060 1061 size_t Idx = Str.find('\0'); 1062 uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1; 1063 if (Func == LibFunc_strncmp) { 1064 if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2))) 1065 N = std::min(N, ConstInt->getZExtValue()); 1066 else 1067 return false; 1068 } 1069 // Now N means how many bytes we need to compare at most. 1070 if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold) 1071 return false; 1072 1073 // Cases where StrP has two or more dereferenceable bytes might be better 1074 // optimized elsewhere. 1075 bool CanBeNull = false, CanBeFreed = false; 1076 if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1) 1077 return false; 1078 inlineCompare(StrP, Str, N, HasStr1); 1079 return true; 1080 } 1081 1082 /// Convert 1083 /// 1084 /// \code 1085 /// ret = compare(s1, s2, N) 1086 /// \endcode 1087 /// 1088 /// into 1089 /// 1090 /// \code 1091 /// ret = (int)s1[0] - (int)s2[0] 1092 /// if (ret != 0) 1093 /// goto NE 1094 /// ... 1095 /// ret = (int)s1[N-2] - (int)s2[N-2] 1096 /// if (ret != 0) 1097 /// goto NE 1098 /// ret = (int)s1[N-1] - (int)s2[N-1] 1099 /// NE: 1100 /// \endcode 1101 /// 1102 /// CFG before and after the transformation: 1103 /// 1104 /// (before) 1105 /// BBCI 1106 /// 1107 /// (after) 1108 /// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail 1109 /// | ^ 1110 /// E | 1111 /// | | 1112 /// BBSubs[1] (sub,icmp) --NE-----+ 1113 /// ... | 1114 /// BBSubs[N-1] (sub) ---------+ 1115 /// 1116 void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N, 1117 bool Swapped) { 1118 auto &Ctx = CI->getContext(); 1119 IRBuilder<> B(Ctx); 1120 // We want these instructions to be recognized as inlined instructions for the 1121 // compare call, but we don't have a source location for the definition of 1122 // that function, since we're generating that code now. 
Because the generated 1123 // code is a viable point for a memory access error, we make the pragmatic 1124 // choice here to directly use CI's location so that we have useful 1125 // attribution for the generated code. 1126 B.SetCurrentDebugLocation(CI->getDebugLoc()); 1127 1128 BasicBlock *BBCI = CI->getParent(); 1129 BasicBlock *BBTail = 1130 SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail"); 1131 1132 SmallVector<BasicBlock *> BBSubs; 1133 for (uint64_t I = 0; I < N; ++I) 1134 BBSubs.push_back( 1135 BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail)); 1136 BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail); 1137 1138 cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]); 1139 1140 B.SetInsertPoint(BBNE); 1141 PHINode *Phi = B.CreatePHI(CI->getType(), N); 1142 B.CreateBr(BBTail); 1143 1144 Value *Base = LHS; 1145 for (uint64_t i = 0; i < N; ++i) { 1146 B.SetInsertPoint(BBSubs[i]); 1147 Value *VL = 1148 B.CreateZExt(B.CreateLoad(B.getInt8Ty(), 1149 B.CreateInBoundsPtrAdd(Base, B.getInt64(i))), 1150 CI->getType()); 1151 Value *VR = 1152 ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i])); 1153 Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR); 1154 if (i < N - 1) 1155 B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)), 1156 BBNE, BBSubs[i + 1]); 1157 else 1158 B.CreateBr(BBNE); 1159 1160 Phi->addIncoming(Sub, BBSubs[i]); 1161 } 1162 1163 CI->replaceAllUsesWith(Phi); 1164 CI->eraseFromParent(); 1165 1166 if (DTU) { 1167 SmallVector<DominatorTree::UpdateType, 8> Updates; 1168 Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]}); 1169 for (uint64_t i = 0; i < N; ++i) { 1170 if (i < N - 1) 1171 Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]}); 1172 Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE}); 1173 } 1174 Updates.push_back({DominatorTree::Insert, BBNE, BBTail}); 1175 Updates.push_back({DominatorTree::Delete, BBCI, BBTail}); 1176 DTU->applyUpdates(Updates); 1177 } 1178 } 1179 1180 /// Convert memchr with a small constant string into a switch 1181 static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU, 1182 const DataLayout &DL) { 1183 if (isa<Constant>(Call->getArgOperand(1))) 1184 return false; 1185 1186 StringRef Str; 1187 Value *Base = Call->getArgOperand(0); 1188 if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false)) 1189 return false; 1190 1191 uint64_t N = Str.size(); 1192 if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) { 1193 uint64_t Val = ConstInt->getZExtValue(); 1194 // Ignore the case that n is larger than the size of string. 
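    // (If n is larger than the constant string we found, the scan would cover
    // bytes we do not know, so conservatively bail out.)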
1195 if (Val > N) 1196 return false; 1197 N = Val; 1198 } else 1199 return false; 1200 1201 if (N > MemChrInlineThreshold) 1202 return false; 1203 1204 BasicBlock *BB = Call->getParent(); 1205 BasicBlock *BBNext = SplitBlock(BB, Call, DTU); 1206 IRBuilder<> IRB(BB); 1207 IRB.SetCurrentDebugLocation(Call->getDebugLoc()); 1208 IntegerType *ByteTy = IRB.getInt8Ty(); 1209 BB->getTerminator()->eraseFromParent(); 1210 SwitchInst *SI = IRB.CreateSwitch( 1211 IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N); 1212 Type *IndexTy = DL.getIndexType(Call->getType()); 1213 SmallVector<DominatorTree::UpdateType, 8> Updates; 1214 1215 BasicBlock *BBSuccess = BasicBlock::Create( 1216 Call->getContext(), "memchr.success", BB->getParent(), BBNext); 1217 IRB.SetInsertPoint(BBSuccess); 1218 PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx"); 1219 Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI); 1220 IRB.CreateBr(BBNext); 1221 if (DTU) 1222 Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext}); 1223 1224 SmallPtrSet<ConstantInt *, 4> Cases; 1225 for (uint64_t I = 0; I < N; ++I) { 1226 ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]); 1227 if (!Cases.insert(CaseVal).second) 1228 continue; 1229 1230 BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case", 1231 BB->getParent(), BBSuccess); 1232 SI->addCase(CaseVal, BBCase); 1233 IRB.SetInsertPoint(BBCase); 1234 IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase); 1235 IRB.CreateBr(BBSuccess); 1236 if (DTU) { 1237 Updates.push_back({DominatorTree::Insert, BB, BBCase}); 1238 Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess}); 1239 } 1240 } 1241 1242 PHINode *PHI = 1243 PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin()); 1244 PHI->addIncoming(Constant::getNullValue(Call->getType()), BB); 1245 PHI->addIncoming(FirstOccursLocation, BBSuccess); 1246 1247 Call->replaceAllUsesWith(PHI); 1248 Call->eraseFromParent(); 1249 1250 if (DTU) 1251 DTU->applyUpdates(Updates); 1252 1253 return true; 1254 } 1255 1256 static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI, 1257 TargetLibraryInfo &TLI, AssumptionCache &AC, 1258 DominatorTree &DT, const DataLayout &DL, 1259 bool &MadeCFGChange) { 1260 1261 auto *CI = dyn_cast<CallInst>(&I); 1262 if (!CI || CI->isNoBuiltin()) 1263 return false; 1264 1265 Function *CalledFunc = CI->getCalledFunction(); 1266 if (!CalledFunc) 1267 return false; 1268 1269 LibFunc LF; 1270 if (!TLI.getLibFunc(*CalledFunc, LF) || 1271 !isLibFuncEmittable(CI->getModule(), &TLI, LF)) 1272 return false; 1273 1274 DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy); 1275 1276 switch (LF) { 1277 case LibFunc_sqrt: 1278 case LibFunc_sqrtf: 1279 case LibFunc_sqrtl: 1280 return foldSqrt(CI, LF, TTI, TLI, AC, DT); 1281 case LibFunc_strcmp: 1282 case LibFunc_strncmp: 1283 if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) { 1284 MadeCFGChange = true; 1285 return true; 1286 } 1287 break; 1288 case LibFunc_memchr: 1289 if (foldMemChr(CI, &DTU, DL)) { 1290 MadeCFGChange = true; 1291 return true; 1292 } 1293 break; 1294 default:; 1295 } 1296 return false; 1297 } 1298 1299 /// This is the entry point for folds that could be implemented in regular 1300 /// InstCombine, but they are separated because they are not expected to 1301 /// occur frequently and/or have more than a constant-length pattern match. 
1302 static bool foldUnusualPatterns(Function &F, DominatorTree &DT, 1303 TargetTransformInfo &TTI, 1304 TargetLibraryInfo &TLI, AliasAnalysis &AA, 1305 AssumptionCache &AC, bool &MadeCFGChange) { 1306 bool MadeChange = false; 1307 for (BasicBlock &BB : F) { 1308 // Ignore unreachable basic blocks. 1309 if (!DT.isReachableFromEntry(&BB)) 1310 continue; 1311 1312 const DataLayout &DL = F.getDataLayout(); 1313 1314 // Walk the block backwards for efficiency. We're matching a chain of 1315 // use->defs, so we're more likely to succeed by starting from the bottom. 1316 // Also, we want to avoid matching partial patterns. 1317 // TODO: It would be more efficient if we removed dead instructions 1318 // iteratively in this loop rather than waiting until the end. 1319 for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) { 1320 MadeChange |= foldAnyOrAllBitsSet(I); 1321 MadeChange |= foldGuardedFunnelShift(I, DT); 1322 MadeChange |= tryToRecognizePopCount(I); 1323 MadeChange |= tryToFPToSat(I, TTI); 1324 MadeChange |= tryToRecognizeTableBasedCttz(I); 1325 MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT); 1326 MadeChange |= foldPatternedLoads(I, DL); 1327 MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT); 1328 // NOTE: This function introduces erasing of the instruction `I`, so it 1329 // needs to be called at the end of this sequence, otherwise we may make 1330 // bugs. 1331 MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange); 1332 } 1333 } 1334 1335 // We're done with transforms, so remove dead instructions. 1336 if (MadeChange) 1337 for (BasicBlock &BB : F) 1338 SimplifyInstructionsInBlock(&BB); 1339 1340 return MadeChange; 1341 } 1342 1343 /// This is the entry point for all transforms. Pass manager differences are 1344 /// handled in the callers of this function. 1345 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI, 1346 TargetLibraryInfo &TLI, DominatorTree &DT, 1347 AliasAnalysis &AA, bool &MadeCFGChange) { 1348 bool MadeChange = false; 1349 const DataLayout &DL = F.getDataLayout(); 1350 TruncInstCombine TIC(AC, TLI, DL, DT); 1351 MadeChange |= TIC.run(F); 1352 MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange); 1353 return MadeChange; 1354 } 1355 1356 PreservedAnalyses AggressiveInstCombinePass::run(Function &F, 1357 FunctionAnalysisManager &AM) { 1358 auto &AC = AM.getResult<AssumptionAnalysis>(F); 1359 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); 1360 auto &DT = AM.getResult<DominatorTreeAnalysis>(F); 1361 auto &TTI = AM.getResult<TargetIRAnalysis>(F); 1362 auto &AA = AM.getResult<AAManager>(F); 1363 bool MadeCFGChange = false; 1364 if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) { 1365 // No changes, all analyses are preserved. 1366 return PreservedAnalyses::all(); 1367 } 1368 // Mark all the analyses that instcombine updates as preserved. 1369 PreservedAnalyses PA; 1370 if (MadeCFGChange) 1371 PA.preserve<DominatorTreeAnalysis>(); 1372 else 1373 PA.preserveSet<CFGAnalyses>(); 1374 return PA; 1375 } 1376