xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===- AggressiveInstCombine.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
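// This file implements the aggressive expression pattern combiner classes.
// Currently, it handles expression patterns for:
//  * Truncate instruction
//  * Guarded funnel shifts/rotates, any/all-bits-set masks, popcount and
//    table-based cttz idioms, fptosi saturation clamps, sqrt libcalls,
//    strncmp/memchr inlining and chains of consecutive narrow loads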
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Transforms/AggressiveInstCombine/AggressiveInstCombine.h"
16 #include "AggressiveInstCombineInternal.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/AliasAnalysis.h"
19 #include "llvm/Analysis/AssumptionCache.h"
20 #include "llvm/Analysis/BasicAliasAnalysis.h"
21 #include "llvm/Analysis/ConstantFolding.h"
22 #include "llvm/Analysis/DomTreeUpdater.h"
23 #include "llvm/Analysis/GlobalsModRef.h"
24 #include "llvm/Analysis/TargetLibraryInfo.h"
25 #include "llvm/Analysis/TargetTransformInfo.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/IR/DataLayout.h"
28 #include "llvm/IR/Dominators.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/PatternMatch.h"
32 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
33 #include "llvm/Transforms/Utils/BuildLibCalls.h"
34 #include "llvm/Transforms/Utils/Local.h"
35 
36 using namespace llvm;
37 using namespace PatternMatch;
38 
39 #define DEBUG_TYPE "aggressive-instcombine"
40 
41 STATISTIC(NumAnyOrAllBitsSet, "Number of any/all-bits-set patterns folded");
42 STATISTIC(NumGuardedRotates,
43           "Number of guarded rotates transformed into funnel shifts");
44 STATISTIC(NumGuardedFunnelShifts,
45           "Number of guarded funnel shifts transformed into funnel shifts");
46 STATISTIC(NumPopCountRecognized, "Number of popcount idioms recognized");
47 
48 static cl::opt<unsigned> MaxInstrsToScan(
49     "aggressive-instcombine-max-scan-instrs", cl::init(64), cl::Hidden,
50     cl::desc("Max number of instructions to scan for aggressive instcombine."));
51 
52 static cl::opt<unsigned> StrNCmpInlineThreshold(
53     "strncmp-inline-threshold", cl::init(3), cl::Hidden,
54     cl::desc("The maximum length of a constant string for a builtin string cmp "
55              "call eligible for inlining. The default value is 3."));
56 
57 static cl::opt<unsigned>
58     MemChrInlineThreshold("memchr-inline-threshold", cl::init(3), cl::Hidden,
59                           cl::desc("The maximum length of a constant string to "
60                                    "inline a memchr call."));
61 
62 /// Match a pattern for a bitwise funnel/rotate operation that partially guards
63 /// against undefined behavior by branching around the funnel-shift/rotation
64 /// when the shift amount is 0.
65 static bool foldGuardedFunnelShift(Instruction &I, const DominatorTree &DT) {
66   if (I.getOpcode() != Instruction::PHI || I.getNumOperands() != 2)
67     return false;
68 
69   // As with the one-use checks below, this is not strictly necessary, but we
70   // are being cautious to avoid potential perf regressions on targets that
71   // do not actually have a funnel/rotate instruction (where the funnel shift
72   // would be expanded back into math/shift/logic ops).
73   if (!isPowerOf2_32(I.getType()->getScalarSizeInBits()))
74     return false;
75 
76   // Match V to funnel shift left/right and capture the source operands and
77   // shift amount.
78   auto matchFunnelShift = [](Value *V, Value *&ShVal0, Value *&ShVal1,
79                              Value *&ShAmt) {
80     unsigned Width = V->getType()->getScalarSizeInBits();
81 
82     // fshl(ShVal0, ShVal1, ShAmt)
83     //  == (ShVal0 << ShAmt) | (ShVal1 >> (Width -ShAmt))
84     if (match(V, m_OneUse(m_c_Or(
85                      m_Shl(m_Value(ShVal0), m_Value(ShAmt)),
86                      m_LShr(m_Value(ShVal1),
87                             m_Sub(m_SpecificInt(Width), m_Deferred(ShAmt))))))) {
88       return Intrinsic::fshl;
89     }
90 
91     // fshr(ShVal0, ShVal1, ShAmt)
92     //  == (ShVal0 >> ShAmt) | (ShVal1 << (Width - ShAmt))
93     if (match(V,
94               m_OneUse(m_c_Or(m_Shl(m_Value(ShVal0), m_Sub(m_SpecificInt(Width),
95                                                            m_Value(ShAmt))),
96                               m_LShr(m_Value(ShVal1), m_Deferred(ShAmt)))))) {
97       return Intrinsic::fshr;
98     }
99 
100     return Intrinsic::not_intrinsic;
101   };
102 
103   // One phi operand must be a funnel/rotate operation, and the other phi
104   // operand must be the source value of that funnel/rotate operation:
105   // phi [ rotate(RotSrc, ShAmt), FunnelBB ], [ RotSrc, GuardBB ]
106   // phi [ fshl(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal0, GuardBB ]
107   // phi [ fshr(ShVal0, ShVal1, ShAmt), FunnelBB ], [ ShVal1, GuardBB ]
108   PHINode &Phi = cast<PHINode>(I);
109   unsigned FunnelOp = 0, GuardOp = 1;
110   Value *P0 = Phi.getOperand(0), *P1 = Phi.getOperand(1);
111   Value *ShVal0, *ShVal1, *ShAmt;
112   Intrinsic::ID IID = matchFunnelShift(P0, ShVal0, ShVal1, ShAmt);
113   if (IID == Intrinsic::not_intrinsic ||
114       (IID == Intrinsic::fshl && ShVal0 != P1) ||
115       (IID == Intrinsic::fshr && ShVal1 != P1)) {
116     IID = matchFunnelShift(P1, ShVal0, ShVal1, ShAmt);
117     if (IID == Intrinsic::not_intrinsic ||
118         (IID == Intrinsic::fshl && ShVal0 != P0) ||
119         (IID == Intrinsic::fshr && ShVal1 != P0))
120       return false;
121     assert((IID == Intrinsic::fshl || IID == Intrinsic::fshr) &&
122            "Pattern must match funnel shift left or right");
123     std::swap(FunnelOp, GuardOp);
124   }
125 
126   // The incoming block with our source operand must be the "guard" block.
127   // That must contain a cmp+branch to avoid the funnel/rotate when the shift
128   // amount is equal to 0. The other incoming block is the block with the
129   // funnel/rotate.
130   BasicBlock *GuardBB = Phi.getIncomingBlock(GuardOp);
131   BasicBlock *FunnelBB = Phi.getIncomingBlock(FunnelOp);
132   Instruction *TermI = GuardBB->getTerminator();
133 
134   // Ensure that the shift values dominate each block.
135   if (!DT.dominates(ShVal0, TermI) || !DT.dominates(ShVal1, TermI))
136     return false;
137 
138   BasicBlock *PhiBB = Phi.getParent();
139   if (!match(TermI, m_Br(m_SpecificICmp(CmpInst::ICMP_EQ, m_Specific(ShAmt),
140                                         m_ZeroInt()),
141                          m_SpecificBB(PhiBB), m_SpecificBB(FunnelBB))))
142     return false;
143 
144   IRBuilder<> Builder(PhiBB, PhiBB->getFirstInsertionPt());
145 
146   if (ShVal0 == ShVal1)
147     ++NumGuardedRotates;
148   else
149     ++NumGuardedFunnelShifts;
150 
151   // If this is not a rotate then the select was blocking poison from the
152   // 'shift-by-zero' non-TVal, but a funnel shift won't - so freeze it.
153   bool IsFshl = IID == Intrinsic::fshl;
154   if (ShVal0 != ShVal1) {
155     if (IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal1))
156       ShVal1 = Builder.CreateFreeze(ShVal1);
157     else if (!IsFshl && !llvm::isGuaranteedNotToBePoison(ShVal0))
158       ShVal0 = Builder.CreateFreeze(ShVal0);
159   }
160 
161   // We matched a variation of this IR pattern:
162   // GuardBB:
163   //   %cmp = icmp eq i32 %ShAmt, 0
164   //   br i1 %cmp, label %PhiBB, label %FunnelBB
165   // FunnelBB:
166   //   %sub = sub i32 32, %ShAmt
167   //   %shr = lshr i32 %ShVal1, %sub
168   //   %shl = shl i32 %ShVal0, %ShAmt
169   //   %fsh = or i32 %shr, %shl
170   //   br label %PhiBB
171   // PhiBB:
172   //   %cond = phi i32 [ %fsh, %FunnelBB ], [ %ShVal0, %GuardBB ]
173   // -->
174   // llvm.fshl.i32(i32 %ShVal0, i32 %ShVal1, i32 %ShAmt)
175   Phi.replaceAllUsesWith(
176       Builder.CreateIntrinsic(IID, Phi.getType(), {ShVal0, ShVal1, ShAmt}));
177   return true;
178 }
179 
180 /// This is used by foldAnyOrAllBitsSet() to capture a source value (Root) and
181 /// the bit indexes (Mask) needed by a masked compare. If we're matching a chain
182 /// of 'and' ops, then we also need to capture the fact that we saw an
183 /// "and X, 1", so that's an extra return value for that case.
184 namespace {
185 struct MaskOps {
186   Value *Root = nullptr;
187   APInt Mask;
188   bool MatchAndChain;
189   bool FoundAnd1 = false;
190 
191   MaskOps(unsigned BitWidth, bool MatchAnds)
192       : Mask(APInt::getZero(BitWidth)), MatchAndChain(MatchAnds) {}
193 };
194 } // namespace
195 
196 /// This is a recursive helper for foldAnyOrAllBitsSet() that walks through a
197 /// chain of 'and' or 'or' instructions looking for shift ops of a common source
198 /// value. Examples:
199 ///   or (or (or X, (X >> 3)), (X >> 5)), (X >> 8)
200 /// returns { X, 0x129 }
201 ///   and (and (X >> 1), 1), (X >> 4)
202 /// returns { X, 0x12 }
203 static bool matchAndOrChain(Value *V, MaskOps &MOps) {
204   Value *Op0, *Op1;
205   if (MOps.MatchAndChain) {
206     // Recurse through a chain of 'and' operands. This requires an extra check
207     // vs. the 'or' matcher: we must find an "and X, 1" instruction somewhere
208     // in the chain to know that all of the high bits are cleared.
209     if (match(V, m_And(m_Value(Op0), m_One()))) {
210       MOps.FoundAnd1 = true;
211       return matchAndOrChain(Op0, MOps);
212     }
213     if (match(V, m_And(m_Value(Op0), m_Value(Op1))))
214       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
215   } else {
216     // Recurse through a chain of 'or' operands.
217     if (match(V, m_Or(m_Value(Op0), m_Value(Op1))))
218       return matchAndOrChain(Op0, MOps) && matchAndOrChain(Op1, MOps);
219   }
220 
221   // We need a shift-right or a bare value representing a compare of bit 0 of
222   // the original source operand.
223   Value *Candidate;
224   const APInt *BitIndex = nullptr;
225   if (!match(V, m_LShr(m_Value(Candidate), m_APInt(BitIndex))))
226     Candidate = V;
227 
228   // Initialize result source operand.
229   if (!MOps.Root)
230     MOps.Root = Candidate;
231 
232   // The shift constant is out-of-range? This code hasn't been simplified.
233   if (BitIndex && BitIndex->uge(MOps.Mask.getBitWidth()))
234     return false;
235 
236   // Fill in the mask bit derived from the shift constant.
237   MOps.Mask.setBit(BitIndex ? BitIndex->getZExtValue() : 0);
238   return MOps.Root == Candidate;
239 }
240 
241 /// Match patterns that correspond to "any-bits-set" and "all-bits-set".
242 /// These will include a chain of 'or' or 'and'-shifted bits from a
243 /// common source value:
244 /// and (or  (lshr X, C), ...), 1 --> (X & CMask) != 0
245 /// and (and (lshr X, C), ...), 1 --> (X & CMask) == CMask
246 /// Note: "any-bits-clear" and "all-bits-clear" are variations of these patterns
247 /// that differ only with a final 'not' of the result. We expect that final
248 /// 'not' to be folded with the compare that we create here (invert predicate).
249 static bool foldAnyOrAllBitsSet(Instruction &I) {
250   // The 'any-bits-set' ('or' chain) pattern is simpler to match because the
251   // final "and X, 1" instruction must be the final op in the sequence.
252   bool MatchAllBitsSet;
253   if (match(&I, m_c_And(m_OneUse(m_And(m_Value(), m_Value())), m_Value())))
254     MatchAllBitsSet = true;
255   else if (match(&I, m_And(m_OneUse(m_Or(m_Value(), m_Value())), m_One())))
256     MatchAllBitsSet = false;
257   else
258     return false;
259 
260   MaskOps MOps(I.getType()->getScalarSizeInBits(), MatchAllBitsSet);
261   if (MatchAllBitsSet) {
262     if (!matchAndOrChain(cast<BinaryOperator>(&I), MOps) || !MOps.FoundAnd1)
263       return false;
264   } else {
265     if (!matchAndOrChain(cast<BinaryOperator>(&I)->getOperand(0), MOps))
266       return false;
267   }
268 
269   // The pattern was found. Create a masked compare that replaces all of the
270   // shift and logic ops.
271   IRBuilder<> Builder(&I);
272   Constant *Mask = ConstantInt::get(I.getType(), MOps.Mask);
273   Value *And = Builder.CreateAnd(MOps.Root, Mask);
274   Value *Cmp = MatchAllBitsSet ? Builder.CreateICmpEQ(And, Mask)
275                                : Builder.CreateIsNotNull(And);
276   Value *Zext = Builder.CreateZExt(Cmp, I.getType());
277   I.replaceAllUsesWith(Zext);
278   ++NumAnyOrAllBitsSet;
279   return true;
280 }
281 
282 // Try to recognize below function as popcount intrinsic.
283 // This is the "best" algorithm from
284 // http://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel
285 // Also used in TargetLowering::expandCTPOP().
286 //
287 // int popcount(unsigned int i) {
288 //   i = i - ((i >> 1) & 0x55555555);
289 //   i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
290 //   i = ((i + (i >> 4)) & 0x0F0F0F0F);
291 //   return (i * 0x01010101) >> 24;
292 // }
293 static bool tryToRecognizePopCount(Instruction &I) {
294   if (I.getOpcode() != Instruction::LShr)
295     return false;
296 
297   Type *Ty = I.getType();
298   if (!Ty->isIntOrIntVectorTy())
299     return false;
300 
301   unsigned Len = Ty->getScalarSizeInBits();
302   // FIXME: fix Len == 8 and other irregular type lengths.
303   if (!(Len <= 128 && Len > 8 && Len % 8 == 0))
304     return false;
305 
306   APInt Mask55 = APInt::getSplat(Len, APInt(8, 0x55));
307   APInt Mask33 = APInt::getSplat(Len, APInt(8, 0x33));
308   APInt Mask0F = APInt::getSplat(Len, APInt(8, 0x0F));
309   APInt Mask01 = APInt::getSplat(Len, APInt(8, 0x01));
310   APInt MaskShift = APInt(Len, Len - 8);
311 
312   Value *Op0 = I.getOperand(0);
313   Value *Op1 = I.getOperand(1);
314   Value *MulOp0;
315   // Matching "(i * 0x01010101...) >> 24".
316   if ((match(Op0, m_Mul(m_Value(MulOp0), m_SpecificInt(Mask01)))) &&
317       match(Op1, m_SpecificInt(MaskShift))) {
318     Value *ShiftOp0;
319     // Matching "((i + (i >> 4)) & 0x0F0F0F0F...)".
320     if (match(MulOp0, m_And(m_c_Add(m_LShr(m_Value(ShiftOp0), m_SpecificInt(4)),
321                                     m_Deferred(ShiftOp0)),
322                             m_SpecificInt(Mask0F)))) {
323       Value *AndOp0;
324       // Matching "(i & 0x33333333...) + ((i >> 2) & 0x33333333...)".
325       if (match(ShiftOp0,
326                 m_c_Add(m_And(m_Value(AndOp0), m_SpecificInt(Mask33)),
327                         m_And(m_LShr(m_Deferred(AndOp0), m_SpecificInt(2)),
328                               m_SpecificInt(Mask33))))) {
329         Value *Root, *SubOp1;
330         // Matching "i - ((i >> 1) & 0x55555555...)".
331         const APInt *AndMask;
332         if (match(AndOp0, m_Sub(m_Value(Root), m_Value(SubOp1))) &&
333             match(SubOp1, m_And(m_LShr(m_Specific(Root), m_SpecificInt(1)),
334                                 m_APInt(AndMask)))) {
335           auto CheckAndMask = [&]() {
336             if (*AndMask == Mask55)
337               return true;
338 
339             // Exact match failed, see if any bits are known to be 0 where we
340             // expect a 1 in the mask.
341             if (!AndMask->isSubsetOf(Mask55))
342               return false;
343 
344             APInt NeededMask = Mask55 & ~*AndMask;
345             return MaskedValueIsZero(cast<Instruction>(SubOp1)->getOperand(0),
346                                      NeededMask,
347                                      SimplifyQuery(I.getDataLayout()));
348           };
349 
350           if (CheckAndMask()) {
351             LLVM_DEBUG(dbgs() << "Recognized popcount intrinsic\n");
352             IRBuilder<> Builder(&I);
353             I.replaceAllUsesWith(
354                 Builder.CreateIntrinsic(Intrinsic::ctpop, I.getType(), {Root}));
355             ++NumPopCountRecognized;
356             return true;
357           }
358         }
359       }
360     }
361   }
362 
363   return false;
364 }
365 
366 /// Fold smin(smax(fptosi(x), C1), C2) to llvm.fptosi.sat(x), providing C1 and
367 /// C2 saturate the value of the fp conversion. The transform is not reversable
368 /// as the fptosi.sat is more defined than the input - all values produce a
369 /// valid value for the fptosi.sat, where as some produce poison for original
370 /// that were out of range of the integer conversion. The reversed pattern may
371 /// use fmax and fmin instead. As we cannot directly reverse the transform, and
372 /// it is not always profitable, we make it conditional on the cost being
373 /// reported as lower by TTI.
374 static bool tryToFPToSat(Instruction &I, TargetTransformInfo &TTI) {
375   // Look for min(max(fptosi, converting to fptosi_sat.
376   Value *In;
377   const APInt *MinC, *MaxC;
378   if (!match(&I, m_SMax(m_OneUse(m_SMin(m_OneUse(m_FPToSI(m_Value(In))),
379                                         m_APInt(MinC))),
380                         m_APInt(MaxC))) &&
381       !match(&I, m_SMin(m_OneUse(m_SMax(m_OneUse(m_FPToSI(m_Value(In))),
382                                         m_APInt(MaxC))),
383                         m_APInt(MinC))))
384     return false;
385 
386   // Check that the constants clamp a saturate.
387   if (!(*MinC + 1).isPowerOf2() || -*MaxC != *MinC + 1)
388     return false;
389 
390   Type *IntTy = I.getType();
391   Type *FpTy = In->getType();
392   Type *SatTy =
393       IntegerType::get(IntTy->getContext(), (*MinC + 1).exactLogBase2() + 1);
394   if (auto *VecTy = dyn_cast<VectorType>(IntTy))
395     SatTy = VectorType::get(SatTy, VecTy->getElementCount());
396 
397   // Get the cost of the intrinsic, and check that against the cost of
398   // fptosi+smin+smax
399   InstructionCost SatCost = TTI.getIntrinsicInstrCost(
400       IntrinsicCostAttributes(Intrinsic::fptosi_sat, SatTy, {In}, {FpTy}),
401       TTI::TCK_RecipThroughput);
402   SatCost += TTI.getCastInstrCost(Instruction::SExt, IntTy, SatTy,
403                                   TTI::CastContextHint::None,
404                                   TTI::TCK_RecipThroughput);
405 
406   InstructionCost MinMaxCost = TTI.getCastInstrCost(
407       Instruction::FPToSI, IntTy, FpTy, TTI::CastContextHint::None,
408       TTI::TCK_RecipThroughput);
409   MinMaxCost += TTI.getIntrinsicInstrCost(
410       IntrinsicCostAttributes(Intrinsic::smin, IntTy, {IntTy}),
411       TTI::TCK_RecipThroughput);
412   MinMaxCost += TTI.getIntrinsicInstrCost(
413       IntrinsicCostAttributes(Intrinsic::smax, IntTy, {IntTy}),
414       TTI::TCK_RecipThroughput);
415 
416   if (SatCost >= MinMaxCost)
417     return false;
418 
419   IRBuilder<> Builder(&I);
420   Value *Sat =
421       Builder.CreateIntrinsic(Intrinsic::fptosi_sat, {SatTy, FpTy}, In);
422   I.replaceAllUsesWith(Builder.CreateSExt(Sat, IntTy));
423   return true;
424 }
425 
426 /// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
427 /// pessimistic codegen that has to account for setting errno and can enable
428 /// vectorization.
429 static bool foldSqrt(CallInst *Call, LibFunc Func, TargetTransformInfo &TTI,
430                      TargetLibraryInfo &TLI, AssumptionCache &AC,
431                      DominatorTree &DT) {
432   // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created
433   // (because NNAN or the operand arg must not be less than -0.0) and (2) we
434   // would not end up lowering to a libcall anyway (which could change the value
435   // of errno), then:
436   // (1) errno won't be set.
437   // (2) it is safe to convert this to an intrinsic call.
438   Type *Ty = Call->getType();
439   Value *Arg = Call->getArgOperand(0);
440   if (TTI.haveFastSqrt(Ty) &&
441       (Call->hasNoNaNs() ||
442        cannotBeOrderedLessThanZero(
443            Arg, SimplifyQuery(Call->getDataLayout(), &TLI, &DT, &AC, Call)))) {
444     IRBuilder<> Builder(Call);
445     Value *NewSqrt =
446         Builder.CreateIntrinsic(Intrinsic::sqrt, Ty, Arg, Call, "sqrt");
447     Call->replaceAllUsesWith(NewSqrt);
448 
449     // Explicitly erase the old call because a call with side effects is not
450     // trivially dead.
451     Call->eraseFromParent();
452     return true;
453   }
454 
455   return false;
456 }
457 
458 // Check if this array of constants represents a cttz table.
459 // Iterate over the elements from \p Table by trying to find/match all
460 // the numbers from 0 to \p InputBits that should represent cttz results.
461 static bool isCTTZTable(const ConstantDataArray &Table, uint64_t Mul,
462                         uint64_t Shift, uint64_t InputBits) {
463   unsigned Length = Table.getNumElements();
464   if (Length < InputBits || Length > InputBits * 2)
465     return false;
466 
467   APInt Mask = APInt::getBitsSetFrom(InputBits, Shift);
468   unsigned Matched = 0;
469 
470   for (unsigned i = 0; i < Length; i++) {
471     uint64_t Element = Table.getElementAsInteger(i);
472     if (Element >= InputBits)
473       continue;
474 
475     // Check if \p Element matches a concrete answer. It could fail for some
476     // elements that are never accessed, so we keep iterating over each element
477     // from the table. The number of matched elements should be equal to the
478     // number of potential right answers which is \p InputBits actually.
479     if ((((Mul << Element) & Mask.getZExtValue()) >> Shift) == i)
480       Matched++;
481   }
482 
483   return Matched == InputBits;
484 }
485 
486 // Try to recognize table-based ctz implementation.
487 // E.g., an example in C (for more cases please see the llvm/tests):
488 // int f(unsigned x) {
489 //    static const char table[32] =
490 //      {0, 1, 28, 2, 29, 14, 24, 3, 30,
491 //       22, 20, 15, 25, 17, 4, 8, 31, 27,
492 //       13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9};
493 //    return table[((unsigned)((x & -x) * 0x077CB531U)) >> 27];
494 // }
495 // this can be lowered to `cttz` instruction.
496 // There is also a special case when the element is 0.
497 //
498 // Here are some examples or LLVM IR for a 64-bit target:
499 //
500 // CASE 1:
501 // %sub = sub i32 0, %x
502 // %and = and i32 %sub, %x
503 // %mul = mul i32 %and, 125613361
504 // %shr = lshr i32 %mul, 27
505 // %idxprom = zext i32 %shr to i64
506 // %arrayidx = getelementptr inbounds [32 x i8], [32 x i8]* @ctz1.table, i64 0,
507 //     i64 %idxprom
508 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
509 //
510 // CASE 2:
511 // %sub = sub i32 0, %x
512 // %and = and i32 %sub, %x
513 // %mul = mul i32 %and, 72416175
514 // %shr = lshr i32 %mul, 26
515 // %idxprom = zext i32 %shr to i64
516 // %arrayidx = getelementptr inbounds [64 x i16], [64 x i16]* @ctz2.table,
517 //     i64 0, i64 %idxprom
518 // %0 = load i16, i16* %arrayidx, align 2, !tbaa !8
519 //
520 // CASE 3:
521 // %sub = sub i32 0, %x
522 // %and = and i32 %sub, %x
523 // %mul = mul i32 %and, 81224991
524 // %shr = lshr i32 %mul, 27
525 // %idxprom = zext i32 %shr to i64
526 // %arrayidx = getelementptr inbounds [32 x i32], [32 x i32]* @ctz3.table,
527 //     i64 0, i64 %idxprom
528 // %0 = load i32, i32* %arrayidx, align 4, !tbaa !8
529 //
530 // CASE 4:
531 // %sub = sub i64 0, %x
532 // %and = and i64 %sub, %x
533 // %mul = mul i64 %and, 283881067100198605
534 // %shr = lshr i64 %mul, 58
535 // %arrayidx = getelementptr inbounds [64 x i8], [64 x i8]* @table, i64 0,
536 //     i64 %shr
537 // %0 = load i8, i8* %arrayidx, align 1, !tbaa !8
538 //
539 // All this can be lowered to @llvm.cttz.i32/64 intrinsic.
540 static bool tryToRecognizeTableBasedCttz(Instruction &I) {
541   LoadInst *LI = dyn_cast<LoadInst>(&I);
542   if (!LI)
543     return false;
544 
545   Type *AccessType = LI->getType();
546   if (!AccessType->isIntegerTy())
547     return false;
548 
549   GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getPointerOperand());
550   if (!GEP || !GEP->hasNoUnsignedSignedWrap() || GEP->getNumIndices() != 2)
551     return false;
552 
553   if (!GEP->getSourceElementType()->isArrayTy())
554     return false;
555 
556   uint64_t ArraySize = GEP->getSourceElementType()->getArrayNumElements();
557   if (ArraySize != 32 && ArraySize != 64)
558     return false;
559 
560   GlobalVariable *GVTable = dyn_cast<GlobalVariable>(GEP->getPointerOperand());
561   if (!GVTable || !GVTable->hasInitializer() || !GVTable->isConstant())
562     return false;
563 
564   ConstantDataArray *ConstData =
565       dyn_cast<ConstantDataArray>(GVTable->getInitializer());
566   if (!ConstData)
567     return false;
568 
569   if (!match(GEP->idx_begin()->get(), m_ZeroInt()))
570     return false;
571 
572   Value *Idx2 = std::next(GEP->idx_begin())->get();
573   Value *X1;
574   uint64_t MulConst, ShiftConst;
575   // FIXME: 64-bit targets have `i64` type for the GEP index, so this match will
576   // probably fail for other (e.g. 32-bit) targets.
577   if (!match(Idx2, m_ZExtOrSelf(
578                        m_LShr(m_Mul(m_c_And(m_Neg(m_Value(X1)), m_Deferred(X1)),
579                                     m_ConstantInt(MulConst)),
580                               m_ConstantInt(ShiftConst)))))
581     return false;
582 
583   unsigned InputBits = X1->getType()->getScalarSizeInBits();
584   if (InputBits != 32 && InputBits != 64)
585     return false;
586 
587   // Shift should extract top 5..7 bits.
588   if (InputBits - Log2_32(InputBits) != ShiftConst &&
589       InputBits - Log2_32(InputBits) - 1 != ShiftConst)
590     return false;
591 
592   if (!isCTTZTable(*ConstData, MulConst, ShiftConst, InputBits))
593     return false;
594 
595   auto ZeroTableElem = ConstData->getElementAsInteger(0);
596   bool DefinedForZero = ZeroTableElem == InputBits;
597 
598   IRBuilder<> B(LI);
599   ConstantInt *BoolConst = B.getInt1(!DefinedForZero);
600   Type *XType = X1->getType();
601   auto Cttz = B.CreateIntrinsic(Intrinsic::cttz, {XType}, {X1, BoolConst});
602   Value *ZExtOrTrunc = nullptr;
603 
604   if (DefinedForZero) {
605     ZExtOrTrunc = B.CreateZExtOrTrunc(Cttz, AccessType);
606   } else {
607     // If the value in elem 0 isn't the same as InputBits, we still want to
608     // produce the value from the table.
609     auto Cmp = B.CreateICmpEQ(X1, ConstantInt::get(XType, 0));
610     auto Select =
611         B.CreateSelect(Cmp, ConstantInt::get(XType, ZeroTableElem), Cttz);
612 
613     // NOTE: If the table[0] is 0, but the cttz(0) is defined by the Target
614     // it should be handled as: `cttz(x) & (typeSize - 1)`.
615 
616     ZExtOrTrunc = B.CreateZExtOrTrunc(Select, AccessType);
617   }
618 
619   LI->replaceAllUsesWith(ZExtOrTrunc);
620 
621   return true;
622 }
623 
624 /// This is used by foldLoadsRecursive() to capture a Root Load node which is
625 /// of type or(load, load) and recursively build the wide load. Also capture the
626 /// shift amount, zero extend type and loadSize.
627 struct LoadOps {
628   LoadInst *Root = nullptr;
629   LoadInst *RootInsert = nullptr;
630   bool FoundRoot = false;
631   uint64_t LoadSize = 0;
632   const APInt *Shift = nullptr;
633   Type *ZextType;
634   AAMDNodes AATags;
635 };
636 
637 // Identify and Merge consecutive loads recursively which is of the form
638 // (ZExt(L1) << shift1) | (ZExt(L2) << shift2) -> ZExt(L3) << shift1
639 // (ZExt(L1) << shift1) | ZExt(L2) -> ZExt(L3)
640 static bool foldLoadsRecursive(Value *V, LoadOps &LOps, const DataLayout &DL,
641                                AliasAnalysis &AA) {
642   const APInt *ShAmt2 = nullptr;
643   Value *X;
644   Instruction *L1, *L2;
645 
646   // Go to the last node with loads.
647   if (match(V, m_OneUse(m_c_Or(
648                    m_Value(X),
649                    m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))),
650                                   m_APInt(ShAmt2)))))) ||
651       match(V, m_OneUse(m_Or(m_Value(X),
652                              m_OneUse(m_ZExt(m_OneUse(m_Instruction(L2)))))))) {
653     if (!foldLoadsRecursive(X, LOps, DL, AA) && LOps.FoundRoot)
654       // Avoid Partial chain merge.
655       return false;
656   } else
657     return false;
658 
659   // Check if the pattern has loads
660   LoadInst *LI1 = LOps.Root;
661   const APInt *ShAmt1 = LOps.Shift;
662   if (LOps.FoundRoot == false &&
663       (match(X, m_OneUse(m_ZExt(m_Instruction(L1)))) ||
664        match(X, m_OneUse(m_Shl(m_OneUse(m_ZExt(m_OneUse(m_Instruction(L1)))),
665                                m_APInt(ShAmt1)))))) {
666     LI1 = dyn_cast<LoadInst>(L1);
667   }
668   LoadInst *LI2 = dyn_cast<LoadInst>(L2);
669 
670   // Check if loads are same, atomic, volatile and having same address space.
671   if (LI1 == LI2 || !LI1 || !LI2 || !LI1->isSimple() || !LI2->isSimple() ||
672       LI1->getPointerAddressSpace() != LI2->getPointerAddressSpace())
673     return false;
674 
675   // Check if Loads come from same BB.
676   if (LI1->getParent() != LI2->getParent())
677     return false;
678 
679   // Find the data layout
680   bool IsBigEndian = DL.isBigEndian();
681 
682   // Check if loads are consecutive and same size.
683   Value *Load1Ptr = LI1->getPointerOperand();
684   APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
685   Load1Ptr =
686       Load1Ptr->stripAndAccumulateConstantOffsets(DL, Offset1,
687                                                   /* AllowNonInbounds */ true);
688 
689   Value *Load2Ptr = LI2->getPointerOperand();
690   APInt Offset2(DL.getIndexTypeSizeInBits(Load2Ptr->getType()), 0);
691   Load2Ptr =
692       Load2Ptr->stripAndAccumulateConstantOffsets(DL, Offset2,
693                                                   /* AllowNonInbounds */ true);
694 
695   // Verify if both loads have same base pointers
696   uint64_t LoadSize1 = LI1->getType()->getPrimitiveSizeInBits();
697   uint64_t LoadSize2 = LI2->getType()->getPrimitiveSizeInBits();
698   if (Load1Ptr != Load2Ptr)
699     return false;
700 
701   // Make sure that there are no padding bits.
702   if (!DL.typeSizeEqualsStoreSize(LI1->getType()) ||
703       !DL.typeSizeEqualsStoreSize(LI2->getType()))
704     return false;
705 
706   // Alias Analysis to check for stores b/w the loads.
707   LoadInst *Start = LOps.FoundRoot ? LOps.RootInsert : LI1, *End = LI2;
708   MemoryLocation Loc;
709   if (!Start->comesBefore(End)) {
710     std::swap(Start, End);
711     Loc = MemoryLocation::get(End);
712     if (LOps.FoundRoot)
713       Loc = Loc.getWithNewSize(LOps.LoadSize);
714   } else
715     Loc = MemoryLocation::get(End);
716   unsigned NumScanned = 0;
717   for (Instruction &Inst :
718        make_range(Start->getIterator(), End->getIterator())) {
719     if (Inst.mayWriteToMemory() && isModSet(AA.getModRefInfo(&Inst, Loc)))
720       return false;
721 
722     if (++NumScanned > MaxInstrsToScan)
723       return false;
724   }
725 
726   // Make sure Load with lower Offset is at LI1
727   bool Reverse = false;
728   if (Offset2.slt(Offset1)) {
729     std::swap(LI1, LI2);
730     std::swap(ShAmt1, ShAmt2);
731     std::swap(Offset1, Offset2);
732     std::swap(Load1Ptr, Load2Ptr);
733     std::swap(LoadSize1, LoadSize2);
734     Reverse = true;
735   }
736 
737   // Big endian swap the shifts
738   if (IsBigEndian)
739     std::swap(ShAmt1, ShAmt2);
740 
741   // Find Shifts values.
742   uint64_t Shift1 = 0, Shift2 = 0;
743   if (ShAmt1)
744     Shift1 = ShAmt1->getZExtValue();
745   if (ShAmt2)
746     Shift2 = ShAmt2->getZExtValue();
747 
748   // First load is always LI1. This is where we put the new load.
749   // Use the merged load size available from LI1 for forward loads.
750   if (LOps.FoundRoot) {
751     if (!Reverse)
752       LoadSize1 = LOps.LoadSize;
753     else
754       LoadSize2 = LOps.LoadSize;
755   }
756 
757   // Verify if shift amount and load index aligns and verifies that loads
758   // are consecutive.
759   uint64_t ShiftDiff = IsBigEndian ? LoadSize2 : LoadSize1;
760   uint64_t PrevSize =
761       DL.getTypeStoreSize(IntegerType::get(LI1->getContext(), LoadSize1));
762   if ((Shift2 - Shift1) != ShiftDiff || (Offset2 - Offset1) != PrevSize)
763     return false;
764 
765   // Update LOps
766   AAMDNodes AATags1 = LOps.AATags;
767   AAMDNodes AATags2 = LI2->getAAMetadata();
768   if (LOps.FoundRoot == false) {
769     LOps.FoundRoot = true;
770     AATags1 = LI1->getAAMetadata();
771   }
772   LOps.LoadSize = LoadSize1 + LoadSize2;
773   LOps.RootInsert = Start;
774 
775   // Concatenate the AATags of the Merged Loads.
776   LOps.AATags = AATags1.concat(AATags2);
777 
778   LOps.Root = LI1;
779   LOps.Shift = ShAmt1;
780   LOps.ZextType = X->getType();
781   return true;
782 }
783 
784 // For a given BB instruction, evaluate all loads in the chain that form a
785 // pattern which suggests that the loads can be combined. The one and only use
786 // of the loads is to form a wider load.
787 static bool foldConsecutiveLoads(Instruction &I, const DataLayout &DL,
788                                  TargetTransformInfo &TTI, AliasAnalysis &AA,
789                                  const DominatorTree &DT) {
790   // Only consider load chains of scalar values.
791   if (isa<VectorType>(I.getType()))
792     return false;
793 
794   LoadOps LOps;
795   if (!foldLoadsRecursive(&I, LOps, DL, AA) || !LOps.FoundRoot)
796     return false;
797 
798   IRBuilder<> Builder(&I);
799   LoadInst *NewLoad = nullptr, *LI1 = LOps.Root;
800 
801   IntegerType *WiderType = IntegerType::get(I.getContext(), LOps.LoadSize);
802   // TTI based checks if we want to proceed with wider load
803   bool Allowed = TTI.isTypeLegal(WiderType);
804   if (!Allowed)
805     return false;
806 
807   unsigned AS = LI1->getPointerAddressSpace();
808   unsigned Fast = 0;
809   Allowed = TTI.allowsMisalignedMemoryAccesses(I.getContext(), LOps.LoadSize,
810                                                AS, LI1->getAlign(), &Fast);
811   if (!Allowed || !Fast)
812     return false;
813 
814   // Get the Index and Ptr for the new GEP.
815   Value *Load1Ptr = LI1->getPointerOperand();
816   Builder.SetInsertPoint(LOps.RootInsert);
817   if (!DT.dominates(Load1Ptr, LOps.RootInsert)) {
818     APInt Offset1(DL.getIndexTypeSizeInBits(Load1Ptr->getType()), 0);
819     Load1Ptr = Load1Ptr->stripAndAccumulateConstantOffsets(
820         DL, Offset1, /* AllowNonInbounds */ true);
821     Load1Ptr = Builder.CreatePtrAdd(Load1Ptr, Builder.getInt(Offset1));
822   }
823   // Generate wider load.
824   NewLoad = Builder.CreateAlignedLoad(WiderType, Load1Ptr, LI1->getAlign(),
825                                       LI1->isVolatile(), "");
826   NewLoad->takeName(LI1);
827   // Set the New Load AATags Metadata.
828   if (LOps.AATags)
829     NewLoad->setAAMetadata(LOps.AATags);
830 
831   Value *NewOp = NewLoad;
832   // Check if zero extend needed.
833   if (LOps.ZextType)
834     NewOp = Builder.CreateZExt(NewOp, LOps.ZextType);
835 
836   // Check if shift needed. We need to shift with the amount of load1
837   // shift if not zero.
838   if (LOps.Shift)
839     NewOp = Builder.CreateShl(NewOp, ConstantInt::get(I.getContext(), *LOps.Shift));
840   I.replaceAllUsesWith(NewOp);
841 
842   return true;
843 }
844 
845 /// Combine away instructions providing they are still equivalent when compared
846 /// against 0. i.e do they have any bits set.
847 static Value *optimizeShiftInOrChain(Value *V, IRBuilder<> &Builder) {
848   auto *I = dyn_cast<Instruction>(V);
849   if (!I || I->getOpcode() != Instruction::Or || !I->hasOneUse())
850     return nullptr;
851 
852   Value *A;
853 
854   // Look deeper into the chain of or's, combining away shl (so long as they are
855   // nuw or nsw).
856   Value *Op0 = I->getOperand(0);
857   if (match(Op0, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
858                              m_NUWShl(m_Value(A), m_Value()))))
859     Op0 = A;
860   else if (auto *NOp = optimizeShiftInOrChain(Op0, Builder))
861     Op0 = NOp;
862 
863   Value *Op1 = I->getOperand(1);
864   if (match(Op1, m_CombineOr(m_NSWShl(m_Value(A), m_Value()),
865                              m_NUWShl(m_Value(A), m_Value()))))
866     Op1 = A;
867   else if (auto *NOp = optimizeShiftInOrChain(Op1, Builder))
868     Op1 = NOp;
869 
870   if (Op0 != I->getOperand(0) || Op1 != I->getOperand(1))
871     return Builder.CreateOr(Op0, Op1);
872   return nullptr;
873 }
874 
875 static bool foldICmpOrChain(Instruction &I, const DataLayout &DL,
876                             TargetTransformInfo &TTI, AliasAnalysis &AA,
877                             const DominatorTree &DT) {
878   CmpPredicate Pred;
879   Value *Op0;
880   if (!match(&I, m_ICmp(Pred, m_Value(Op0), m_Zero())) ||
881       !ICmpInst::isEquality(Pred))
882     return false;
883 
884   // If the chain or or's matches a load, combine to that before attempting to
885   // remove shifts.
886   if (auto OpI = dyn_cast<Instruction>(Op0))
887     if (OpI->getOpcode() == Instruction::Or)
888       if (foldConsecutiveLoads(*OpI, DL, TTI, AA, DT))
889         return true;
890 
891   IRBuilder<> Builder(&I);
892   // icmp eq/ne or(shl(a), b), 0 -> icmp eq/ne or(a, b), 0
893   if (auto *Res = optimizeShiftInOrChain(Op0, Builder)) {
894     I.replaceAllUsesWith(Builder.CreateICmp(Pred, Res, I.getOperand(1)));
895     return true;
896   }
897 
898   return false;
899 }
900 
901 // Calculate GEP Stride and accumulated const ModOffset. Return Stride and
902 // ModOffset
903 static std::pair<APInt, APInt>
904 getStrideAndModOffsetOfGEP(Value *PtrOp, const DataLayout &DL) {
905   unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
906   std::optional<APInt> Stride;
907   APInt ModOffset(BW, 0);
908   // Return a minimum gep stride, greatest common divisor of consective gep
909   // index scales(c.f. Bézout's identity).
910   while (auto *GEP = dyn_cast<GEPOperator>(PtrOp)) {
911     SmallMapVector<Value *, APInt, 4> VarOffsets;
912     if (!GEP->collectOffset(DL, BW, VarOffsets, ModOffset))
913       break;
914 
915     for (auto [V, Scale] : VarOffsets) {
916       // Only keep a power of two factor for non-inbounds
917       if (!GEP->hasNoUnsignedSignedWrap())
918         Scale = APInt::getOneBitSet(Scale.getBitWidth(), Scale.countr_zero());
919 
920       if (!Stride)
921         Stride = Scale;
922       else
923         Stride = APIntOps::GreatestCommonDivisor(*Stride, Scale);
924     }
925 
926     PtrOp = GEP->getPointerOperand();
927   }
928 
929   // Check whether pointer arrives back at Global Variable via at least one GEP.
930   // Even if it doesn't, we can check by alignment.
931   if (!isa<GlobalVariable>(PtrOp) || !Stride)
932     return {APInt(BW, 1), APInt(BW, 0)};
933 
934   // In consideration of signed GEP indices, non-negligible offset become
935   // remainder of division by minimum GEP stride.
936   ModOffset = ModOffset.srem(*Stride);
937   if (ModOffset.isNegative())
938     ModOffset += *Stride;
939 
940   return {*Stride, ModOffset};
941 }
942 
943 /// If C is a constant patterned array and all valid loaded results for given
944 /// alignment are same to a constant, return that constant.
945 static bool foldPatternedLoads(Instruction &I, const DataLayout &DL) {
946   auto *LI = dyn_cast<LoadInst>(&I);
947   if (!LI || LI->isVolatile())
948     return false;
949 
950   // We can only fold the load if it is from a constant global with definitive
951   // initializer. Skip expensive logic if this is not the case.
952   auto *PtrOp = LI->getPointerOperand();
953   auto *GV = dyn_cast<GlobalVariable>(getUnderlyingObject(PtrOp));
954   if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer())
955     return false;
956 
957   // Bail for large initializers in excess of 4K to avoid too many scans.
958   Constant *C = GV->getInitializer();
959   uint64_t GVSize = DL.getTypeAllocSize(C->getType());
960   if (!GVSize || 4096 < GVSize)
961     return false;
962 
963   Type *LoadTy = LI->getType();
964   unsigned BW = DL.getIndexTypeSizeInBits(PtrOp->getType());
965   auto [Stride, ConstOffset] = getStrideAndModOffsetOfGEP(PtrOp, DL);
966 
967   // Any possible offset could be multiple of GEP stride. And any valid
968   // offset is multiple of load alignment, so checking only multiples of bigger
969   // one is sufficient to say results' equality.
970   if (auto LA = LI->getAlign();
971       LA <= GV->getAlign().valueOrOne() && Stride.getZExtValue() < LA.value()) {
972     ConstOffset = APInt(BW, 0);
973     Stride = APInt(BW, LA.value());
974   }
975 
976   Constant *Ca = ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL);
977   if (!Ca)
978     return false;
979 
980   unsigned E = GVSize - DL.getTypeStoreSize(LoadTy);
981   for (; ConstOffset.getZExtValue() <= E; ConstOffset += Stride)
982     if (Ca != ConstantFoldLoadFromConst(C, LoadTy, ConstOffset, DL))
983       return false;
984 
985   I.replaceAllUsesWith(Ca);
986 
987   return true;
988 }
989 
990 namespace {
991 class StrNCmpInliner {
992 public:
993   StrNCmpInliner(CallInst *CI, LibFunc Func, DomTreeUpdater *DTU,
994                  const DataLayout &DL)
995       : CI(CI), Func(Func), DTU(DTU), DL(DL) {}
996 
997   bool optimizeStrNCmp();
998 
999 private:
1000   void inlineCompare(Value *LHS, StringRef RHS, uint64_t N, bool Swapped);
1001 
1002   CallInst *CI;
1003   LibFunc Func;
1004   DomTreeUpdater *DTU;
1005   const DataLayout &DL;
1006 };
1007 
1008 } // namespace
1009 
1010 /// First we normalize calls to strncmp/strcmp to the form of
1011 /// compare(s1, s2, N), which means comparing first N bytes of s1 and s2
1012 /// (without considering '\0').
1013 ///
1014 /// Examples:
1015 ///
1016 /// \code
1017 ///   strncmp(s, "a", 3) -> compare(s, "a", 2)
1018 ///   strncmp(s, "abc", 3) -> compare(s, "abc", 3)
1019 ///   strncmp(s, "a\0b", 3) -> compare(s, "a\0b", 2)
1020 ///   strcmp(s, "a") -> compare(s, "a", 2)
1021 ///
1022 ///   char s2[] = {'a'}
1023 ///   strncmp(s, s2, 3) -> compare(s, s2, 3)
1024 ///
1025 ///   char s2[] = {'a', 'b', 'c', 'd'}
1026 ///   strncmp(s, s2, 3) -> compare(s, s2, 3)
1027 /// \endcode
1028 ///
1029 /// We only handle cases where N and exactly one of s1 and s2 are constant.
1030 /// Cases that s1 and s2 are both constant are already handled by the
1031 /// instcombine pass.
1032 ///
1033 /// We do not handle cases where N > StrNCmpInlineThreshold.
1034 ///
1035 /// We also do not handles cases where N < 2, which are already
1036 /// handled by the instcombine pass.
1037 ///
1038 bool StrNCmpInliner::optimizeStrNCmp() {
1039   if (StrNCmpInlineThreshold < 2)
1040     return false;
1041 
1042   if (!isOnlyUsedInZeroComparison(CI))
1043     return false;
1044 
1045   Value *Str1P = CI->getArgOperand(0);
1046   Value *Str2P = CI->getArgOperand(1);
1047   // Should be handled elsewhere.
1048   if (Str1P == Str2P)
1049     return false;
1050 
1051   StringRef Str1, Str2;
1052   bool HasStr1 = getConstantStringInfo(Str1P, Str1, /*TrimAtNul=*/false);
1053   bool HasStr2 = getConstantStringInfo(Str2P, Str2, /*TrimAtNul=*/false);
1054   if (HasStr1 == HasStr2)
1055     return false;
1056 
1057   // Note that '\0' and characters after it are not trimmed.
1058   StringRef Str = HasStr1 ? Str1 : Str2;
1059   Value *StrP = HasStr1 ? Str2P : Str1P;
1060 
1061   size_t Idx = Str.find('\0');
1062   uint64_t N = Idx == StringRef::npos ? UINT64_MAX : Idx + 1;
1063   if (Func == LibFunc_strncmp) {
1064     if (auto *ConstInt = dyn_cast<ConstantInt>(CI->getArgOperand(2)))
1065       N = std::min(N, ConstInt->getZExtValue());
1066     else
1067       return false;
1068   }
1069   // Now N means how many bytes we need to compare at most.
1070   if (N > Str.size() || N < 2 || N > StrNCmpInlineThreshold)
1071     return false;
1072 
1073   // Cases where StrP has two or more dereferenceable bytes might be better
1074   // optimized elsewhere.
1075   bool CanBeNull = false, CanBeFreed = false;
1076   if (StrP->getPointerDereferenceableBytes(DL, CanBeNull, CanBeFreed) > 1)
1077     return false;
1078   inlineCompare(StrP, Str, N, HasStr1);
1079   return true;
1080 }
1081 
/// Convert
///
/// \code
///   ret = compare(s1, s2, N)
/// \endcode
///
/// into
///
/// \code
///   ret = (int)s1[0] - (int)s2[0]
///   if (ret != 0)
///     goto NE
///   ...
///   ret = (int)s1[N-2] - (int)s2[N-2]
///   if (ret != 0)
///     goto NE
///   ret = (int)s1[N-1] - (int)s2[N-1]
///   NE:
/// \endcode
///
/// CFG before and after the transformation:
///
/// (before)
/// BBCI
///
/// (after)
/// BBCI -> BBSubs[0] (sub,icmp) --NE-> BBNE -> BBTail
///                 |                    ^
///                 E                    |
///                 |                    |
///        BBSubs[1] (sub,icmp) --NE-----+
///                ...                   |
///        BBSubs[N-1]    (sub) ---------+
///
/// \param LHS     Pointer to the non-constant string; byte i is loaded from
///                LHS + i and zero-extended to the call's result type.
/// \param RHS     The constant string's bytes; RHS[i] is used as an unsigned
///                char. NOTE(review): optimizeStrNCmp only calls this with
///                2 <= N <= RHS.size().
/// \param N       Number of bytes to compare, i.e. the number of BBSubs blocks.
/// \param Swapped True when the constant string was the first argument, in
///                which case each difference is computed as RHS[i] - LHS[i].
///
/// CI is replaced by a PHI in BBNE and erased; DTU (if non-null) is updated.
void StrNCmpInliner::inlineCompare(Value *LHS, StringRef RHS, uint64_t N,
                                   bool Swapped) {
  auto &Ctx = CI->getContext();
  IRBuilder<> B(Ctx);
  // We want these instructions to be recognized as inlined instructions for the
  // compare call, but we don't have a source location for the definition of
  // that function, since we're generating that code now. Because the generated
  // code is a viable point for a memory access error, we make the pragmatic
  // choice here to directly use CI's location so that we have useful
  // attribution for the generated code.
  B.SetCurrentDebugLocation(CI->getDebugLoc());

  // Split at CI: CI and everything after it move into BBTail.
  BasicBlock *BBCI = CI->getParent();
  BasicBlock *BBTail =
      SplitBlock(BBCI, CI, DTU, nullptr, nullptr, BBCI->getName() + ".tail");

  // One block per compared byte, plus the shared exit block BBNE, all placed
  // before BBTail.
  SmallVector<BasicBlock *> BBSubs;
  for (uint64_t I = 0; I < N; ++I)
    BBSubs.push_back(
        BasicBlock::Create(Ctx, "sub_" + Twine(I), BBCI->getParent(), BBTail));
  BasicBlock *BBNE = BasicBlock::Create(Ctx, "ne", BBCI->getParent(), BBTail);

  // Redirect BBCI's branch (to BBTail after the split) into the first byte
  // compare; the matching DTU edge deletion is queued below.
  cast<BranchInst>(BBCI->getTerminator())->setSuccessor(0, BBSubs[0]);

  // BBNE merges the difference from whichever byte block exits; reserve one
  // incoming slot per byte block.
  B.SetInsertPoint(BBNE);
  PHINode *Phi = B.CreatePHI(CI->getType(), N);
  B.CreateBr(BBTail);

  Value *Base = LHS;
  for (uint64_t i = 0; i < N; ++i) {
    B.SetInsertPoint(BBSubs[i]);
    // Load byte i of the non-constant string, zero-extended (unsigned char).
    Value *VL =
        B.CreateZExt(B.CreateLoad(B.getInt8Ty(),
                                  B.CreateInBoundsPtrAdd(Base, B.getInt64(i))),
                     CI->getType());
    Value *VR =
        ConstantInt::get(CI->getType(), static_cast<unsigned char>(RHS[i]));
    // Keep the argument order of the original call when subtracting.
    Value *Sub = Swapped ? B.CreateSub(VR, VL) : B.CreateSub(VL, VR);
    // Every block but the last exits early on a nonzero difference; the last
    // one always falls through to BBNE.
    if (i < N - 1)
      B.CreateCondBr(B.CreateICmpNE(Sub, ConstantInt::get(CI->getType(), 0)),
                     BBNE, BBSubs[i + 1]);
    else
      B.CreateBr(BBNE);

    Phi->addIncoming(Sub, BBSubs[i]);
  }

  CI->replaceAllUsesWith(Phi);
  CI->eraseFromParent();

  // Mirror every edge created or removed above in the dominator tree.
  if (DTU) {
    SmallVector<DominatorTree::UpdateType, 8> Updates;
    Updates.push_back({DominatorTree::Insert, BBCI, BBSubs[0]});
    for (uint64_t i = 0; i < N; ++i) {
      if (i < N - 1)
        Updates.push_back({DominatorTree::Insert, BBSubs[i], BBSubs[i + 1]});
      Updates.push_back({DominatorTree::Insert, BBSubs[i], BBNE});
    }
    Updates.push_back({DominatorTree::Insert, BBNE, BBTail});
    Updates.push_back({DominatorTree::Delete, BBCI, BBTail});
    DTU->applyUpdates(Updates);
  }
}
1179 
1180 /// Convert memchr with a small constant string into a switch
1181 static bool foldMemChr(CallInst *Call, DomTreeUpdater *DTU,
1182                        const DataLayout &DL) {
1183   if (isa<Constant>(Call->getArgOperand(1)))
1184     return false;
1185 
1186   StringRef Str;
1187   Value *Base = Call->getArgOperand(0);
1188   if (!getConstantStringInfo(Base, Str, /*TrimAtNul=*/false))
1189     return false;
1190 
1191   uint64_t N = Str.size();
1192   if (auto *ConstInt = dyn_cast<ConstantInt>(Call->getArgOperand(2))) {
1193     uint64_t Val = ConstInt->getZExtValue();
1194     // Ignore the case that n is larger than the size of string.
1195     if (Val > N)
1196       return false;
1197     N = Val;
1198   } else
1199     return false;
1200 
1201   if (N > MemChrInlineThreshold)
1202     return false;
1203 
1204   BasicBlock *BB = Call->getParent();
1205   BasicBlock *BBNext = SplitBlock(BB, Call, DTU);
1206   IRBuilder<> IRB(BB);
1207   IRB.SetCurrentDebugLocation(Call->getDebugLoc());
1208   IntegerType *ByteTy = IRB.getInt8Ty();
1209   BB->getTerminator()->eraseFromParent();
1210   SwitchInst *SI = IRB.CreateSwitch(
1211       IRB.CreateTrunc(Call->getArgOperand(1), ByteTy), BBNext, N);
1212   Type *IndexTy = DL.getIndexType(Call->getType());
1213   SmallVector<DominatorTree::UpdateType, 8> Updates;
1214 
1215   BasicBlock *BBSuccess = BasicBlock::Create(
1216       Call->getContext(), "memchr.success", BB->getParent(), BBNext);
1217   IRB.SetInsertPoint(BBSuccess);
1218   PHINode *IndexPHI = IRB.CreatePHI(IndexTy, N, "memchr.idx");
1219   Value *FirstOccursLocation = IRB.CreateInBoundsPtrAdd(Base, IndexPHI);
1220   IRB.CreateBr(BBNext);
1221   if (DTU)
1222     Updates.push_back({DominatorTree::Insert, BBSuccess, BBNext});
1223 
1224   SmallPtrSet<ConstantInt *, 4> Cases;
1225   for (uint64_t I = 0; I < N; ++I) {
1226     ConstantInt *CaseVal = ConstantInt::get(ByteTy, Str[I]);
1227     if (!Cases.insert(CaseVal).second)
1228       continue;
1229 
1230     BasicBlock *BBCase = BasicBlock::Create(Call->getContext(), "memchr.case",
1231                                             BB->getParent(), BBSuccess);
1232     SI->addCase(CaseVal, BBCase);
1233     IRB.SetInsertPoint(BBCase);
1234     IndexPHI->addIncoming(ConstantInt::get(IndexTy, I), BBCase);
1235     IRB.CreateBr(BBSuccess);
1236     if (DTU) {
1237       Updates.push_back({DominatorTree::Insert, BB, BBCase});
1238       Updates.push_back({DominatorTree::Insert, BBCase, BBSuccess});
1239     }
1240   }
1241 
1242   PHINode *PHI =
1243       PHINode::Create(Call->getType(), 2, Call->getName(), BBNext->begin());
1244   PHI->addIncoming(Constant::getNullValue(Call->getType()), BB);
1245   PHI->addIncoming(FirstOccursLocation, BBSuccess);
1246 
1247   Call->replaceAllUsesWith(PHI);
1248   Call->eraseFromParent();
1249 
1250   if (DTU)
1251     DTU->applyUpdates(Updates);
1252 
1253   return true;
1254 }
1255 
1256 static bool foldLibCalls(Instruction &I, TargetTransformInfo &TTI,
1257                          TargetLibraryInfo &TLI, AssumptionCache &AC,
1258                          DominatorTree &DT, const DataLayout &DL,
1259                          bool &MadeCFGChange) {
1260 
1261   auto *CI = dyn_cast<CallInst>(&I);
1262   if (!CI || CI->isNoBuiltin())
1263     return false;
1264 
1265   Function *CalledFunc = CI->getCalledFunction();
1266   if (!CalledFunc)
1267     return false;
1268 
1269   LibFunc LF;
1270   if (!TLI.getLibFunc(*CalledFunc, LF) ||
1271       !isLibFuncEmittable(CI->getModule(), &TLI, LF))
1272     return false;
1273 
1274   DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Lazy);
1275 
1276   switch (LF) {
1277   case LibFunc_sqrt:
1278   case LibFunc_sqrtf:
1279   case LibFunc_sqrtl:
1280     return foldSqrt(CI, LF, TTI, TLI, AC, DT);
1281   case LibFunc_strcmp:
1282   case LibFunc_strncmp:
1283     if (StrNCmpInliner(CI, LF, &DTU, DL).optimizeStrNCmp()) {
1284       MadeCFGChange = true;
1285       return true;
1286     }
1287     break;
1288   case LibFunc_memchr:
1289     if (foldMemChr(CI, &DTU, DL)) {
1290       MadeCFGChange = true;
1291       return true;
1292     }
1293     break;
1294   default:;
1295   }
1296   return false;
1297 }
1298 
1299 /// This is the entry point for folds that could be implemented in regular
1300 /// InstCombine, but they are separated because they are not expected to
1301 /// occur frequently and/or have more than a constant-length pattern match.
1302 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
1303                                 TargetTransformInfo &TTI,
1304                                 TargetLibraryInfo &TLI, AliasAnalysis &AA,
1305                                 AssumptionCache &AC, bool &MadeCFGChange) {
1306   bool MadeChange = false;
1307   for (BasicBlock &BB : F) {
1308     // Ignore unreachable basic blocks.
1309     if (!DT.isReachableFromEntry(&BB))
1310       continue;
1311 
1312     const DataLayout &DL = F.getDataLayout();
1313 
1314     // Walk the block backwards for efficiency. We're matching a chain of
1315     // use->defs, so we're more likely to succeed by starting from the bottom.
1316     // Also, we want to avoid matching partial patterns.
1317     // TODO: It would be more efficient if we removed dead instructions
1318     // iteratively in this loop rather than waiting until the end.
1319     for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
1320       MadeChange |= foldAnyOrAllBitsSet(I);
1321       MadeChange |= foldGuardedFunnelShift(I, DT);
1322       MadeChange |= tryToRecognizePopCount(I);
1323       MadeChange |= tryToFPToSat(I, TTI);
1324       MadeChange |= tryToRecognizeTableBasedCttz(I);
1325       MadeChange |= foldConsecutiveLoads(I, DL, TTI, AA, DT);
1326       MadeChange |= foldPatternedLoads(I, DL);
1327       MadeChange |= foldICmpOrChain(I, DL, TTI, AA, DT);
1328       // NOTE: This function introduces erasing of the instruction `I`, so it
1329       // needs to be called at the end of this sequence, otherwise we may make
1330       // bugs.
1331       MadeChange |= foldLibCalls(I, TTI, TLI, AC, DT, DL, MadeCFGChange);
1332     }
1333   }
1334 
1335   // We're done with transforms, so remove dead instructions.
1336   if (MadeChange)
1337     for (BasicBlock &BB : F)
1338       SimplifyInstructionsInBlock(&BB);
1339 
1340   return MadeChange;
1341 }
1342 
1343 /// This is the entry point for all transforms. Pass manager differences are
1344 /// handled in the callers of this function.
1345 static bool runImpl(Function &F, AssumptionCache &AC, TargetTransformInfo &TTI,
1346                     TargetLibraryInfo &TLI, DominatorTree &DT,
1347                     AliasAnalysis &AA, bool &MadeCFGChange) {
1348   bool MadeChange = false;
1349   const DataLayout &DL = F.getDataLayout();
1350   TruncInstCombine TIC(AC, TLI, DL, DT);
1351   MadeChange |= TIC.run(F);
1352   MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI, AA, AC, MadeCFGChange);
1353   return MadeChange;
1354 }
1355 
1356 PreservedAnalyses AggressiveInstCombinePass::run(Function &F,
1357                                                  FunctionAnalysisManager &AM) {
1358   auto &AC = AM.getResult<AssumptionAnalysis>(F);
1359   auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
1360   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
1361   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
1362   auto &AA = AM.getResult<AAManager>(F);
1363   bool MadeCFGChange = false;
1364   if (!runImpl(F, AC, TTI, TLI, DT, AA, MadeCFGChange)) {
1365     // No changes, all analyses are preserved.
1366     return PreservedAnalyses::all();
1367   }
1368   // Mark all the analyses that instcombine updates as preserved.
1369   PreservedAnalyses PA;
1370   if (MadeCFGChange)
1371     PA.preserve<DominatorTreeAnalysis>();
1372   else
1373     PA.preserveSet<CFGAnalyses>();
1374   return PA;
1375 }
1376