xref: /freebsd/contrib/llvm-project/llvm/lib/Transforms/Utils/SimplifyLibCalls.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===------ SimplifyLibCalls.cpp - Library calls simplifier ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the library calls simplifier. It does not implement
10 // any pass, but can be used by other passes to do simplifications.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Utils/SimplifyLibCalls.h"
15 #include "llvm/ADT/APSInt.h"
16 #include "llvm/ADT/SmallString.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/Analysis/ConstantFolding.h"
19 #include "llvm/Analysis/Loads.h"
20 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
21 #include "llvm/Analysis/ValueTracking.h"
22 #include "llvm/IR/AttributeMask.h"
23 #include "llvm/IR/DataLayout.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/IntrinsicInst.h"
27 #include "llvm/IR/Intrinsics.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/IR/PatternMatch.h"
30 #include "llvm/Support/Casting.h"
31 #include "llvm/Support/CommandLine.h"
32 #include "llvm/Support/KnownBits.h"
33 #include "llvm/Support/MathExtras.h"
34 #include "llvm/TargetParser/Triple.h"
35 #include "llvm/Transforms/Utils/BuildLibCalls.h"
36 #include "llvm/Transforms/Utils/Local.h"
37 #include "llvm/Transforms/Utils/SizeOpts.h"
38 
39 #include <cmath>
40 
41 using namespace llvm;
42 using namespace PatternMatch;
43 
44 static cl::opt<bool>
45     EnableUnsafeFPShrink("enable-double-float-shrink", cl::Hidden,
46                          cl::init(false),
47                          cl::desc("Enable unsafe double to float "
48                                   "shrinking for math lib calls"));
49 
50 // Enable conversion of operator new calls with a MemProf hot or cold hint
51 // to an operator new call that takes a hot/cold hint. Off by default since
52 // not all allocators currently support this extension.
53 static cl::opt<bool>
54     OptimizeHotColdNew("optimize-hot-cold-new", cl::Hidden, cl::init(false),
55                        cl::desc("Enable hot/cold operator new library calls"));
56 static cl::opt<bool> OptimizeExistingHotColdNew(
57     "optimize-existing-hot-cold-new", cl::Hidden, cl::init(false),
58     cl::desc(
59         "Enable optimization of existing hot/cold operator new library calls"));
60 
61 namespace {
62 
63 // Specialized parser to ensure the hint is an 8 bit value (we can't specify
64 // uint8_t to opt<> as that is interpreted to mean that we are passing a char
65 // option with a specific set of values.
66 struct HotColdHintParser : public cl::parser<unsigned> {
HotColdHintParser__anon1eca78e10111::HotColdHintParser67   HotColdHintParser(cl::Option &O) : cl::parser<unsigned>(O) {}
68 
parse__anon1eca78e10111::HotColdHintParser69   bool parse(cl::Option &O, StringRef ArgName, StringRef Arg, unsigned &Value) {
70     if (Arg.getAsInteger(0, Value))
71       return O.error("'" + Arg + "' value invalid for uint argument!");
72 
73     if (Value > 255)
74       return O.error("'" + Arg + "' value must be in the range [0, 255]!");
75 
76     return false;
77   }
78 };
79 
80 } // end anonymous namespace
81 
82 // Hot/cold operator new takes an 8 bit hotness hint, where 0 is the coldest
83 // and 255 is the hottest. Default to 1 value away from the coldest and hottest
84 // hints, so that the compiler hinted allocations are slightly less strong than
85 // manually inserted hints at the two extremes.
86 static cl::opt<unsigned, false, HotColdHintParser> ColdNewHintValue(
87     "cold-new-hint-value", cl::Hidden, cl::init(1),
88     cl::desc("Value to pass to hot/cold operator new for cold allocation"));
89 static cl::opt<unsigned, false, HotColdHintParser>
90     NotColdNewHintValue("notcold-new-hint-value", cl::Hidden, cl::init(128),
91                         cl::desc("Value to pass to hot/cold operator new for "
92                                  "notcold (warm) allocation"));
93 static cl::opt<unsigned, false, HotColdHintParser> HotNewHintValue(
94     "hot-new-hint-value", cl::Hidden, cl::init(254),
95     cl::desc("Value to pass to hot/cold operator new for hot allocation"));
96 
97 //===----------------------------------------------------------------------===//
98 // Helper Functions
99 //===----------------------------------------------------------------------===//
100 
ignoreCallingConv(LibFunc Func)101 static bool ignoreCallingConv(LibFunc Func) {
102   return Func == LibFunc_abs || Func == LibFunc_labs ||
103          Func == LibFunc_llabs || Func == LibFunc_strlen;
104 }
105 
106 /// Return true if it is only used in equality comparisons with With.
isOnlyUsedInEqualityComparison(Value * V,Value * With)107 static bool isOnlyUsedInEqualityComparison(Value *V, Value *With) {
108   for (User *U : V->users()) {
109     if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
110       if (IC->isEquality() && IC->getOperand(1) == With)
111         continue;
112     // Unknown instruction.
113     return false;
114   }
115   return true;
116 }
117 
callHasFloatingPointArgument(const CallInst * CI)118 static bool callHasFloatingPointArgument(const CallInst *CI) {
119   return any_of(CI->operands(), [](const Use &OI) {
120     return OI->getType()->isFloatingPointTy();
121   });
122 }
123 
callHasFP128Argument(const CallInst * CI)124 static bool callHasFP128Argument(const CallInst *CI) {
125   return any_of(CI->operands(), [](const Use &OI) {
126     return OI->getType()->isFP128Ty();
127   });
128 }
129 
130 // Convert the entire string Str representing an integer in Base, up to
131 // the terminating nul if present, to a constant according to the rules
132 // of strtoul[l] or, when AsSigned is set, of strtol[l].  On success
133 // return the result, otherwise null.
134 // The function assumes the string is encoded in ASCII and carefully
135 // avoids converting sequences (including "") that the corresponding
136 // library call might fail and set errno for.
convertStrToInt(CallInst * CI,StringRef & Str,Value * EndPtr,uint64_t Base,bool AsSigned,IRBuilderBase & B)137 static Value *convertStrToInt(CallInst *CI, StringRef &Str, Value *EndPtr,
138                               uint64_t Base, bool AsSigned, IRBuilderBase &B) {
139   if (Base < 2 || Base > 36)
140     if (Base != 0)
141       // Fail for an invalid base (required by POSIX).
142       return nullptr;
143 
144   // Current offset into the original string to reflect in EndPtr.
145   size_t Offset = 0;
146   // Strip leading whitespace.
147   for ( ; Offset != Str.size(); ++Offset)
148     if (!isSpace((unsigned char)Str[Offset])) {
149       Str = Str.substr(Offset);
150       break;
151     }
152 
153   if (Str.empty())
154     // Fail for empty subject sequences (POSIX allows but doesn't require
155     // strtol[l]/strtoul[l] to fail with EINVAL).
156     return nullptr;
157 
158   // Strip but remember the sign.
159   bool Negate = Str[0] == '-';
160   if (Str[0] == '-' || Str[0] == '+') {
161     Str = Str.drop_front();
162     if (Str.empty())
163       // Fail for a sign with nothing after it.
164       return nullptr;
165     ++Offset;
166   }
167 
168   // Set Max to the absolute value of the minimum (for signed), or
169   // to the maximum (for unsigned) value representable in the type.
170   Type *RetTy = CI->getType();
171   unsigned NBits = RetTy->getPrimitiveSizeInBits();
172   uint64_t Max = AsSigned && Negate ? 1 : 0;
173   Max += AsSigned ? maxIntN(NBits) : maxUIntN(NBits);
174 
175   // Autodetect Base if it's zero and consume the "0x" prefix.
176   if (Str.size() > 1) {
177     if (Str[0] == '0') {
178       if (toUpper((unsigned char)Str[1]) == 'X') {
179         if (Str.size() == 2 || (Base && Base != 16))
180           // Fail if Base doesn't allow the "0x" prefix or for the prefix
181           // alone that implementations like BSD set errno to EINVAL for.
182           return nullptr;
183 
184         Str = Str.drop_front(2);
185         Offset += 2;
186         Base = 16;
187       }
188       else if (Base == 0)
189         Base = 8;
190     } else if (Base == 0)
191       Base = 10;
192   }
193   else if (Base == 0)
194     Base = 10;
195 
196   // Convert the rest of the subject sequence, not including the sign,
197   // to its uint64_t representation (this assumes the source character
198   // set is ASCII).
199   uint64_t Result = 0;
200   for (unsigned i = 0; i != Str.size(); ++i) {
201     unsigned char DigVal = Str[i];
202     if (isDigit(DigVal))
203       DigVal = DigVal - '0';
204     else {
205       DigVal = toUpper(DigVal);
206       if (isAlpha(DigVal))
207         DigVal = DigVal - 'A' + 10;
208       else
209         return nullptr;
210     }
211 
212     if (DigVal >= Base)
213       // Fail if the digit is not valid in the Base.
214       return nullptr;
215 
216     // Add the digit and fail if the result is not representable in
217     // the (unsigned form of the) destination type.
218     bool VFlow;
219     Result = SaturatingMultiplyAdd(Result, Base, (uint64_t)DigVal, &VFlow);
220     if (VFlow || Result > Max)
221       return nullptr;
222   }
223 
224   if (EndPtr) {
225     // Store the pointer to the end.
226     Value *Off = B.getInt64(Offset + Str.size());
227     Value *StrBeg = CI->getArgOperand(0);
228     Value *StrEnd = B.CreateInBoundsGEP(B.getInt8Ty(), StrBeg, Off, "endptr");
229     B.CreateStore(StrEnd, EndPtr);
230   }
231 
232   if (Negate)
233     // Unsigned negation doesn't overflow.
234     Result = -Result;
235 
236   return ConstantInt::get(RetTy, Result);
237 }
238 
isOnlyUsedInComparisonWithZero(Value * V)239 static bool isOnlyUsedInComparisonWithZero(Value *V) {
240   for (User *U : V->users()) {
241     if (ICmpInst *IC = dyn_cast<ICmpInst>(U))
242       if (Constant *C = dyn_cast<Constant>(IC->getOperand(1)))
243         if (C->isNullValue())
244           continue;
245     // Unknown instruction.
246     return false;
247   }
248   return true;
249 }
250 
canTransformToMemCmp(CallInst * CI,Value * Str,uint64_t Len,const DataLayout & DL)251 static bool canTransformToMemCmp(CallInst *CI, Value *Str, uint64_t Len,
252                                  const DataLayout &DL) {
253   if (!isOnlyUsedInComparisonWithZero(CI))
254     return false;
255 
256   if (!isDereferenceableAndAlignedPointer(Str, Align(1), APInt(64, Len), DL))
257     return false;
258 
259   if (CI->getFunction()->hasFnAttribute(Attribute::SanitizeMemory))
260     return false;
261 
262   return true;
263 }
264 
annotateDereferenceableBytes(CallInst * CI,ArrayRef<unsigned> ArgNos,uint64_t DereferenceableBytes)265 static void annotateDereferenceableBytes(CallInst *CI,
266                                          ArrayRef<unsigned> ArgNos,
267                                          uint64_t DereferenceableBytes) {
268   const Function *F = CI->getCaller();
269   if (!F)
270     return;
271   for (unsigned ArgNo : ArgNos) {
272     uint64_t DerefBytes = DereferenceableBytes;
273     unsigned AS = CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace();
274     if (!llvm::NullPointerIsDefined(F, AS) ||
275         CI->paramHasAttr(ArgNo, Attribute::NonNull))
276       DerefBytes = std::max(CI->getParamDereferenceableOrNullBytes(ArgNo),
277                             DereferenceableBytes);
278 
279     if (CI->getParamDereferenceableBytes(ArgNo) < DerefBytes) {
280       CI->removeParamAttr(ArgNo, Attribute::Dereferenceable);
281       if (!llvm::NullPointerIsDefined(F, AS) ||
282           CI->paramHasAttr(ArgNo, Attribute::NonNull))
283         CI->removeParamAttr(ArgNo, Attribute::DereferenceableOrNull);
284       CI->addParamAttr(ArgNo, Attribute::getWithDereferenceableBytes(
285                                   CI->getContext(), DerefBytes));
286     }
287   }
288 }
289 
annotateNonNullNoUndefBasedOnAccess(CallInst * CI,ArrayRef<unsigned> ArgNos)290 static void annotateNonNullNoUndefBasedOnAccess(CallInst *CI,
291                                          ArrayRef<unsigned> ArgNos) {
292   Function *F = CI->getCaller();
293   if (!F)
294     return;
295 
296   for (unsigned ArgNo : ArgNos) {
297     if (!CI->paramHasAttr(ArgNo, Attribute::NoUndef))
298       CI->addParamAttr(ArgNo, Attribute::NoUndef);
299 
300     if (!CI->paramHasAttr(ArgNo, Attribute::NonNull)) {
301       unsigned AS =
302           CI->getArgOperand(ArgNo)->getType()->getPointerAddressSpace();
303       if (llvm::NullPointerIsDefined(F, AS))
304         continue;
305       CI->addParamAttr(ArgNo, Attribute::NonNull);
306     }
307 
308     annotateDereferenceableBytes(CI, ArgNo, 1);
309   }
310 }
311 
annotateNonNullAndDereferenceable(CallInst * CI,ArrayRef<unsigned> ArgNos,Value * Size,const DataLayout & DL)312 static void annotateNonNullAndDereferenceable(CallInst *CI, ArrayRef<unsigned> ArgNos,
313                                Value *Size, const DataLayout &DL) {
314   if (ConstantInt *LenC = dyn_cast<ConstantInt>(Size)) {
315     annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
316     annotateDereferenceableBytes(CI, ArgNos, LenC->getZExtValue());
317   } else if (isKnownNonZero(Size, DL)) {
318     annotateNonNullNoUndefBasedOnAccess(CI, ArgNos);
319     const APInt *X, *Y;
320     uint64_t DerefMin = 1;
321     if (match(Size, m_Select(m_Value(), m_APInt(X), m_APInt(Y)))) {
322       DerefMin = std::min(X->getZExtValue(), Y->getZExtValue());
323       annotateDereferenceableBytes(CI, ArgNos, DerefMin);
324     }
325   }
326 }
327 
328 // Copy CallInst "flags" like musttail, notail, and tail. Return New param for
329 // easier chaining. Calls to emit* and B.createCall should probably be wrapped
330 // in this function when New is created to replace Old. Callers should take
331 // care to check Old.isMustTailCall() if they aren't replacing Old directly
332 // with New.
copyFlags(const CallInst & Old,Value * New)333 static Value *copyFlags(const CallInst &Old, Value *New) {
334   assert(!Old.isMustTailCall() && "do not copy musttail call flags");
335   assert(!Old.isNoTailCall() && "do not copy notail call flags");
336   if (auto *NewCI = dyn_cast_or_null<CallInst>(New))
337     NewCI->setTailCallKind(Old.getTailCallKind());
338   return New;
339 }
340 
mergeAttributesAndFlags(CallInst * NewCI,const CallInst & Old)341 static Value *mergeAttributesAndFlags(CallInst *NewCI, const CallInst &Old) {
342   NewCI->setAttributes(AttributeList::get(
343       NewCI->getContext(), {NewCI->getAttributes(), Old.getAttributes()}));
344   NewCI->removeRetAttrs(AttributeFuncs::typeIncompatible(NewCI->getType()));
345   return copyFlags(Old, NewCI);
346 }
347 
348 // Helper to avoid truncating the length if size_t is 32-bits.
substr(StringRef Str,uint64_t Len)349 static StringRef substr(StringRef Str, uint64_t Len) {
350   return Len >= Str.size() ? Str : Str.substr(0, Len);
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // String and Memory Library Call Optimizations
355 //===----------------------------------------------------------------------===//
356 
optimizeStrCat(CallInst * CI,IRBuilderBase & B)357 Value *LibCallSimplifier::optimizeStrCat(CallInst *CI, IRBuilderBase &B) {
358   // Extract some information from the instruction
359   Value *Dst = CI->getArgOperand(0);
360   Value *Src = CI->getArgOperand(1);
361   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
362 
363   // See if we can get the length of the input string.
364   uint64_t Len = GetStringLength(Src);
365   if (Len)
366     annotateDereferenceableBytes(CI, 1, Len);
367   else
368     return nullptr;
369   --Len; // Unbias length.
370 
371   // Handle the simple, do-nothing case: strcat(x, "") -> x
372   if (Len == 0)
373     return Dst;
374 
375   return copyFlags(*CI, emitStrLenMemCpy(Src, Dst, Len, B));
376 }
377 
emitStrLenMemCpy(Value * Src,Value * Dst,uint64_t Len,IRBuilderBase & B)378 Value *LibCallSimplifier::emitStrLenMemCpy(Value *Src, Value *Dst, uint64_t Len,
379                                            IRBuilderBase &B) {
380   // We need to find the end of the destination string.  That's where the
381   // memory is to be moved to. We just generate a call to strlen.
382   Value *DstLen = emitStrLen(Dst, B, DL, TLI);
383   if (!DstLen)
384     return nullptr;
385 
386   // Now that we have the destination's length, we must index into the
387   // destination's pointer to get the actual memcpy destination (end of
388   // the string .. we're concatenating).
389   Value *CpyDst = B.CreateInBoundsGEP(B.getInt8Ty(), Dst, DstLen, "endptr");
390 
391   // We have enough information to now generate the memcpy call to do the
392   // concatenation for us.  Make a memcpy to copy the nul byte with align = 1.
393   B.CreateMemCpy(
394       CpyDst, Align(1), Src, Align(1),
395       ConstantInt::get(DL.getIntPtrType(Src->getContext()), Len + 1));
396   return Dst;
397 }
398 
optimizeStrNCat(CallInst * CI,IRBuilderBase & B)399 Value *LibCallSimplifier::optimizeStrNCat(CallInst *CI, IRBuilderBase &B) {
400   // Extract some information from the instruction.
401   Value *Dst = CI->getArgOperand(0);
402   Value *Src = CI->getArgOperand(1);
403   Value *Size = CI->getArgOperand(2);
404   uint64_t Len;
405   annotateNonNullNoUndefBasedOnAccess(CI, 0);
406   if (isKnownNonZero(Size, DL))
407     annotateNonNullNoUndefBasedOnAccess(CI, 1);
408 
409   // We don't do anything if length is not constant.
410   ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size);
411   if (LengthArg) {
412     Len = LengthArg->getZExtValue();
413     // strncat(x, c, 0) -> x
414     if (!Len)
415       return Dst;
416   } else {
417     return nullptr;
418   }
419 
420   // See if we can get the length of the input string.
421   uint64_t SrcLen = GetStringLength(Src);
422   if (SrcLen) {
423     annotateDereferenceableBytes(CI, 1, SrcLen);
424     --SrcLen; // Unbias length.
425   } else {
426     return nullptr;
427   }
428 
429   // strncat(x, "", c) -> x
430   if (SrcLen == 0)
431     return Dst;
432 
433   // We don't optimize this case.
434   if (Len < SrcLen)
435     return nullptr;
436 
437   // strncat(x, s, c) -> strcat(x, s)
438   // s is constant so the strcat can be optimized further.
439   return copyFlags(*CI, emitStrLenMemCpy(Src, Dst, SrcLen, B));
440 }
441 
442 // Helper to transform memchr(S, C, N) == S to N && *S == C and, when
443 // NBytes is null, strchr(S, C) to *S == C.  A precondition of the function
444 // is that either S is dereferenceable or the value of N is nonzero.
memChrToCharCompare(CallInst * CI,Value * NBytes,IRBuilderBase & B,const DataLayout & DL)445 static Value* memChrToCharCompare(CallInst *CI, Value *NBytes,
446                                   IRBuilderBase &B, const DataLayout &DL)
447 {
448   Value *Src = CI->getArgOperand(0);
449   Value *CharVal = CI->getArgOperand(1);
450 
451   // Fold memchr(A, C, N) == A to N && *A == C.
452   Type *CharTy = B.getInt8Ty();
453   Value *Char0 = B.CreateLoad(CharTy, Src);
454   CharVal = B.CreateTrunc(CharVal, CharTy);
455   Value *Cmp = B.CreateICmpEQ(Char0, CharVal, "char0cmp");
456 
457   if (NBytes) {
458     Value *Zero = ConstantInt::get(NBytes->getType(), 0);
459     Value *And = B.CreateICmpNE(NBytes, Zero);
460     Cmp = B.CreateLogicalAnd(And, Cmp);
461   }
462 
463   Value *NullPtr = Constant::getNullValue(CI->getType());
464   return B.CreateSelect(Cmp, Src, NullPtr);
465 }
466 
optimizeStrChr(CallInst * CI,IRBuilderBase & B)467 Value *LibCallSimplifier::optimizeStrChr(CallInst *CI, IRBuilderBase &B) {
468   Value *SrcStr = CI->getArgOperand(0);
469   Value *CharVal = CI->getArgOperand(1);
470   annotateNonNullNoUndefBasedOnAccess(CI, 0);
471 
472   if (isOnlyUsedInEqualityComparison(CI, SrcStr))
473     return memChrToCharCompare(CI, nullptr, B, DL);
474 
475   // If the second operand is non-constant, see if we can compute the length
476   // of the input string and turn this into memchr.
477   ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
478   if (!CharC) {
479     uint64_t Len = GetStringLength(SrcStr);
480     if (Len)
481       annotateDereferenceableBytes(CI, 0, Len);
482     else
483       return nullptr;
484 
485     Function *Callee = CI->getCalledFunction();
486     FunctionType *FT = Callee->getFunctionType();
487     unsigned IntBits = TLI->getIntSize();
488     if (!FT->getParamType(1)->isIntegerTy(IntBits)) // memchr needs 'int'.
489       return nullptr;
490 
491     unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
492     Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
493     return copyFlags(*CI,
494                      emitMemChr(SrcStr, CharVal, // include nul.
495                                 ConstantInt::get(SizeTTy, Len), B,
496                                 DL, TLI));
497   }
498 
499   if (CharC->isZero()) {
500     Value *NullPtr = Constant::getNullValue(CI->getType());
501     if (isOnlyUsedInEqualityComparison(CI, NullPtr))
502       // Pre-empt the transformation to strlen below and fold
503       // strchr(A, '\0') == null to false.
504       return B.CreateIntToPtr(B.getTrue(), CI->getType());
505   }
506 
507   // Otherwise, the character is a constant, see if the first argument is
508   // a string literal.  If so, we can constant fold.
509   StringRef Str;
510   if (!getConstantStringInfo(SrcStr, Str)) {
511     if (CharC->isZero()) // strchr(p, 0) -> p + strlen(p)
512       if (Value *StrLen = emitStrLen(SrcStr, B, DL, TLI))
513         return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, StrLen, "strchr");
514     return nullptr;
515   }
516 
517   // Compute the offset, make sure to handle the case when we're searching for
518   // zero (a weird way to spell strlen).
519   size_t I = (0xFF & CharC->getSExtValue()) == 0
520                  ? Str.size()
521                  : Str.find(CharC->getSExtValue());
522   if (I == StringRef::npos) // Didn't find the char.  strchr returns null.
523     return Constant::getNullValue(CI->getType());
524 
525   // strchr(s+n,c)  -> gep(s+n+i,c)
526   return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(I), "strchr");
527 }
528 
optimizeStrRChr(CallInst * CI,IRBuilderBase & B)529 Value *LibCallSimplifier::optimizeStrRChr(CallInst *CI, IRBuilderBase &B) {
530   Value *SrcStr = CI->getArgOperand(0);
531   Value *CharVal = CI->getArgOperand(1);
532   ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
533   annotateNonNullNoUndefBasedOnAccess(CI, 0);
534 
535   StringRef Str;
536   if (!getConstantStringInfo(SrcStr, Str)) {
537     // strrchr(s, 0) -> strchr(s, 0)
538     if (CharC && CharC->isZero())
539       return copyFlags(*CI, emitStrChr(SrcStr, '\0', B, TLI));
540     return nullptr;
541   }
542 
543   unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
544   Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
545 
546   // Try to expand strrchr to the memrchr nonstandard extension if it's
547   // available, or simply fail otherwise.
548   uint64_t NBytes = Str.size() + 1;   // Include the terminating nul.
549   Value *Size = ConstantInt::get(SizeTTy, NBytes);
550   return copyFlags(*CI, emitMemRChr(SrcStr, CharVal, Size, B, DL, TLI));
551 }
552 
optimizeStrCmp(CallInst * CI,IRBuilderBase & B)553 Value *LibCallSimplifier::optimizeStrCmp(CallInst *CI, IRBuilderBase &B) {
554   Value *Str1P = CI->getArgOperand(0), *Str2P = CI->getArgOperand(1);
555   if (Str1P == Str2P) // strcmp(x,x)  -> 0
556     return ConstantInt::get(CI->getType(), 0);
557 
558   StringRef Str1, Str2;
559   bool HasStr1 = getConstantStringInfo(Str1P, Str1);
560   bool HasStr2 = getConstantStringInfo(Str2P, Str2);
561 
562   // strcmp(x, y)  -> cnst  (if both x and y are constant strings)
563   if (HasStr1 && HasStr2)
564     return ConstantInt::get(CI->getType(),
565                             std::clamp(Str1.compare(Str2), -1, 1));
566 
567   if (HasStr1 && Str1.empty()) // strcmp("", x) -> -*x
568     return B.CreateNeg(B.CreateZExt(
569         B.CreateLoad(B.getInt8Ty(), Str2P, "strcmpload"), CI->getType()));
570 
571   if (HasStr2 && Str2.empty()) // strcmp(x,"") -> *x
572     return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"),
573                         CI->getType());
574 
575   // strcmp(P, "x") -> memcmp(P, "x", 2)
576   uint64_t Len1 = GetStringLength(Str1P);
577   if (Len1)
578     annotateDereferenceableBytes(CI, 0, Len1);
579   uint64_t Len2 = GetStringLength(Str2P);
580   if (Len2)
581     annotateDereferenceableBytes(CI, 1, Len2);
582 
583   if (Len1 && Len2) {
584     return copyFlags(
585         *CI, emitMemCmp(Str1P, Str2P,
586                         ConstantInt::get(DL.getIntPtrType(CI->getContext()),
587                                          std::min(Len1, Len2)),
588                         B, DL, TLI));
589   }
590 
591   // strcmp to memcmp
592   if (!HasStr1 && HasStr2) {
593     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
594       return copyFlags(
595           *CI,
596           emitMemCmp(Str1P, Str2P,
597                      ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
598                      B, DL, TLI));
599   } else if (HasStr1 && !HasStr2) {
600     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
601       return copyFlags(
602           *CI,
603           emitMemCmp(Str1P, Str2P,
604                      ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
605                      B, DL, TLI));
606   }
607 
608   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
609   return nullptr;
610 }
611 
612 // Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant
613 // arrays LHS and RHS and nonconstant Size.
614 static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
615                                     Value *Size, bool StrNCmp,
616                                     IRBuilderBase &B, const DataLayout &DL);
617 
optimizeStrNCmp(CallInst * CI,IRBuilderBase & B)618 Value *LibCallSimplifier::optimizeStrNCmp(CallInst *CI, IRBuilderBase &B) {
619   Value *Str1P = CI->getArgOperand(0);
620   Value *Str2P = CI->getArgOperand(1);
621   Value *Size = CI->getArgOperand(2);
622   if (Str1P == Str2P) // strncmp(x,x,n)  -> 0
623     return ConstantInt::get(CI->getType(), 0);
624 
625   if (isKnownNonZero(Size, DL))
626     annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
627   // Get the length argument if it is constant.
628   uint64_t Length;
629   if (ConstantInt *LengthArg = dyn_cast<ConstantInt>(Size))
630     Length = LengthArg->getZExtValue();
631   else
632     return optimizeMemCmpVarSize(CI, Str1P, Str2P, Size, true, B, DL);
633 
634   if (Length == 0) // strncmp(x,y,0)   -> 0
635     return ConstantInt::get(CI->getType(), 0);
636 
637   if (Length == 1) // strncmp(x,y,1) -> memcmp(x,y,1)
638     return copyFlags(*CI, emitMemCmp(Str1P, Str2P, Size, B, DL, TLI));
639 
640   StringRef Str1, Str2;
641   bool HasStr1 = getConstantStringInfo(Str1P, Str1);
642   bool HasStr2 = getConstantStringInfo(Str2P, Str2);
643 
644   // strncmp(x, y)  -> cnst  (if both x and y are constant strings)
645   if (HasStr1 && HasStr2) {
646     // Avoid truncating the 64-bit Length to 32 bits in ILP32.
647     StringRef SubStr1 = substr(Str1, Length);
648     StringRef SubStr2 = substr(Str2, Length);
649     return ConstantInt::get(CI->getType(),
650                             std::clamp(SubStr1.compare(SubStr2), -1, 1));
651   }
652 
653   if (HasStr1 && Str1.empty()) // strncmp("", x, n) -> -*x
654     return B.CreateNeg(B.CreateZExt(
655         B.CreateLoad(B.getInt8Ty(), Str2P, "strcmpload"), CI->getType()));
656 
657   if (HasStr2 && Str2.empty()) // strncmp(x, "", n) -> *x
658     return B.CreateZExt(B.CreateLoad(B.getInt8Ty(), Str1P, "strcmpload"),
659                         CI->getType());
660 
661   uint64_t Len1 = GetStringLength(Str1P);
662   if (Len1)
663     annotateDereferenceableBytes(CI, 0, Len1);
664   uint64_t Len2 = GetStringLength(Str2P);
665   if (Len2)
666     annotateDereferenceableBytes(CI, 1, Len2);
667 
668   // strncmp to memcmp
669   if (!HasStr1 && HasStr2) {
670     Len2 = std::min(Len2, Length);
671     if (canTransformToMemCmp(CI, Str1P, Len2, DL))
672       return copyFlags(
673           *CI,
674           emitMemCmp(Str1P, Str2P,
675                      ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len2),
676                      B, DL, TLI));
677   } else if (HasStr1 && !HasStr2) {
678     Len1 = std::min(Len1, Length);
679     if (canTransformToMemCmp(CI, Str2P, Len1, DL))
680       return copyFlags(
681           *CI,
682           emitMemCmp(Str1P, Str2P,
683                      ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len1),
684                      B, DL, TLI));
685   }
686 
687   return nullptr;
688 }
689 
optimizeStrNDup(CallInst * CI,IRBuilderBase & B)690 Value *LibCallSimplifier::optimizeStrNDup(CallInst *CI, IRBuilderBase &B) {
691   Value *Src = CI->getArgOperand(0);
692   ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
693   uint64_t SrcLen = GetStringLength(Src);
694   if (SrcLen && Size) {
695     annotateDereferenceableBytes(CI, 0, SrcLen);
696     if (SrcLen <= Size->getZExtValue() + 1)
697       return copyFlags(*CI, emitStrDup(Src, B, TLI));
698   }
699 
700   return nullptr;
701 }
702 
optimizeStrCpy(CallInst * CI,IRBuilderBase & B)703 Value *LibCallSimplifier::optimizeStrCpy(CallInst *CI, IRBuilderBase &B) {
704   Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
705   if (Dst == Src) // strcpy(x,x)  -> x
706     return Src;
707 
708   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
709   // See if we can get the length of the input string.
710   uint64_t Len = GetStringLength(Src);
711   if (Len)
712     annotateDereferenceableBytes(CI, 1, Len);
713   else
714     return nullptr;
715 
716   // We have enough information to now generate the memcpy call to do the
717   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
718   CallInst *NewCI =
719       B.CreateMemCpy(Dst, Align(1), Src, Align(1),
720                      ConstantInt::get(DL.getIntPtrType(CI->getContext()), Len));
721   mergeAttributesAndFlags(NewCI, *CI);
722   return Dst;
723 }
724 
optimizeStpCpy(CallInst * CI,IRBuilderBase & B)725 Value *LibCallSimplifier::optimizeStpCpy(CallInst *CI, IRBuilderBase &B) {
726   Function *Callee = CI->getCalledFunction();
727   Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1);
728 
729   // stpcpy(d,s) -> strcpy(d,s) if the result is not used.
730   if (CI->use_empty())
731     return copyFlags(*CI, emitStrCpy(Dst, Src, B, TLI));
732 
733   if (Dst == Src) { // stpcpy(x,x)  -> x+strlen(x)
734     Value *StrLen = emitStrLen(Src, B, DL, TLI);
735     return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr;
736   }
737 
738   // See if we can get the length of the input string.
739   uint64_t Len = GetStringLength(Src);
740   if (Len)
741     annotateDereferenceableBytes(CI, 1, Len);
742   else
743     return nullptr;
744 
745   Type *PT = Callee->getFunctionType()->getParamType(0);
746   Value *LenV = ConstantInt::get(DL.getIntPtrType(PT), Len);
747   Value *DstEnd = B.CreateInBoundsGEP(
748       B.getInt8Ty(), Dst, ConstantInt::get(DL.getIntPtrType(PT), Len - 1));
749 
750   // We have enough information to now generate the memcpy call to do the
751   // copy for us.  Make a memcpy to copy the nul byte with align = 1.
752   CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1), LenV);
753   mergeAttributesAndFlags(NewCI, *CI);
754   return DstEnd;
755 }
756 
757 // Optimize a call to size_t strlcpy(char*, const char*, size_t).
758 
optimizeStrLCpy(CallInst * CI,IRBuilderBase & B)759 Value *LibCallSimplifier::optimizeStrLCpy(CallInst *CI, IRBuilderBase &B) {
760   Value *Size = CI->getArgOperand(2);
761   if (isKnownNonZero(Size, DL))
762     // Like snprintf, the function stores into the destination only when
763     // the size argument is nonzero.
764     annotateNonNullNoUndefBasedOnAccess(CI, 0);
765   // The function reads the source argument regardless of Size (it returns
766   // its length).
767   annotateNonNullNoUndefBasedOnAccess(CI, 1);
768 
769   uint64_t NBytes;
770   if (ConstantInt *SizeC = dyn_cast<ConstantInt>(Size))
771     NBytes = SizeC->getZExtValue();
772   else
773     return nullptr;
774 
775   Value *Dst = CI->getArgOperand(0);
776   Value *Src = CI->getArgOperand(1);
777   if (NBytes <= 1) {
778     if (NBytes == 1)
779       // For a call to strlcpy(D, S, 1) first store a nul in *D.
780       B.CreateStore(B.getInt8(0), Dst);
781 
782     // Transform strlcpy(D, S, 0) to a call to strlen(S).
783     return copyFlags(*CI, emitStrLen(Src, B, DL, TLI));
784   }
785 
786   // Try to determine the length of the source, substituting its size
787   // when it's not nul-terminated (as it's required to be) to avoid
788   // reading past its end.
789   StringRef Str;
790   if (!getConstantStringInfo(Src, Str, /*TrimAtNul=*/false))
791     return nullptr;
792 
793   uint64_t SrcLen = Str.find('\0');
794   // Set if the terminating nul should be copied by the call to memcpy
795   // below.
796   bool NulTerm = SrcLen < NBytes;
797 
798   if (NulTerm)
799     // Overwrite NBytes with the number of bytes to copy, including
800     // the terminating nul.
801     NBytes = SrcLen + 1;
802   else {
803     // Set the length of the source for the function to return to its
804     // size, and cap NBytes at the same.
805     SrcLen = std::min(SrcLen, uint64_t(Str.size()));
806     NBytes = std::min(NBytes - 1, SrcLen);
807   }
808 
809   if (SrcLen == 0) {
810     // Transform strlcpy(D, "", N) to (*D = '\0, 0).
811     B.CreateStore(B.getInt8(0), Dst);
812     return ConstantInt::get(CI->getType(), 0);
813   }
814 
815   Function *Callee = CI->getCalledFunction();
816   Type *PT = Callee->getFunctionType()->getParamType(0);
817   // Transform strlcpy(D, S, N) to memcpy(D, S, N') where N' is the lower
818   // bound on strlen(S) + 1 and N, optionally followed by a nul store to
819   // D[N' - 1] if necessary.
820   CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
821                         ConstantInt::get(DL.getIntPtrType(PT), NBytes));
822   mergeAttributesAndFlags(NewCI, *CI);
823 
824   if (!NulTerm) {
825     Value *EndOff = ConstantInt::get(CI->getType(), NBytes);
826     Value *EndPtr = B.CreateInBoundsGEP(B.getInt8Ty(), Dst, EndOff);
827     B.CreateStore(B.getInt8(0), EndPtr);
828   }
829 
830   // Like snprintf, strlcpy returns the number of nonzero bytes that would
831   // have been copied if the bound had been sufficiently big (which in this
832   // case is strlen(Src)).
833   return ConstantInt::get(CI->getType(), SrcLen);
834 }
835 
836 // Optimize a call CI to either stpncpy when RetEnd is true, or to strncpy
837 // otherwise.
optimizeStringNCpy(CallInst * CI,bool RetEnd,IRBuilderBase & B)838 Value *LibCallSimplifier::optimizeStringNCpy(CallInst *CI, bool RetEnd,
839                                              IRBuilderBase &B) {
840   Function *Callee = CI->getCalledFunction();
841   Value *Dst = CI->getArgOperand(0);
842   Value *Src = CI->getArgOperand(1);
843   Value *Size = CI->getArgOperand(2);
844 
845   if (isKnownNonZero(Size, DL)) {
846     // Both st{p,r}ncpy(D, S, N) access the source and destination arrays
847     // only when N is nonzero.
848     annotateNonNullNoUndefBasedOnAccess(CI, 0);
849     annotateNonNullNoUndefBasedOnAccess(CI, 1);
850   }
851 
852   // If the "bound" argument is known set N to it.  Otherwise set it to
853   // UINT64_MAX and handle it later.
854   uint64_t N = UINT64_MAX;
855   if (ConstantInt *SizeC = dyn_cast<ConstantInt>(Size))
856     N = SizeC->getZExtValue();
857 
858   if (N == 0)
859     // Fold st{p,r}ncpy(D, S, 0) to D.
860     return Dst;
861 
862   if (N == 1) {
863     Type *CharTy = B.getInt8Ty();
864     Value *CharVal = B.CreateLoad(CharTy, Src, "stxncpy.char0");
865     B.CreateStore(CharVal, Dst);
866     if (!RetEnd)
867       // Transform strncpy(D, S, 1) to return (*D = *S), D.
868       return Dst;
869 
870     // Transform stpncpy(D, S, 1) to return (*D = *S) ? D + 1 : D.
871     Value *ZeroChar = ConstantInt::get(CharTy, 0);
872     Value *Cmp = B.CreateICmpEQ(CharVal, ZeroChar, "stpncpy.char0cmp");
873 
874     Value *Off1 = B.getInt32(1);
875     Value *EndPtr = B.CreateInBoundsGEP(CharTy, Dst, Off1, "stpncpy.end");
876     return B.CreateSelect(Cmp, Dst, EndPtr, "stpncpy.sel");
877   }
878 
879   // If the length of the input string is known set SrcLen to it.
880   uint64_t SrcLen = GetStringLength(Src);
881   if (SrcLen)
882     annotateDereferenceableBytes(CI, 1, SrcLen);
883   else
884     return nullptr;
885 
886   --SrcLen; // Unbias length.
887 
888   if (SrcLen == 0) {
889     // Transform st{p,r}ncpy(D, "", N) to memset(D, '\0', N) for any N.
890     Align MemSetAlign =
891       CI->getAttributes().getParamAttrs(0).getAlignment().valueOrOne();
892     CallInst *NewCI = B.CreateMemSet(Dst, B.getInt8('\0'), Size, MemSetAlign);
893     AttrBuilder ArgAttrs(CI->getContext(), CI->getAttributes().getParamAttrs(0));
894     NewCI->setAttributes(NewCI->getAttributes().addParamAttributes(
895         CI->getContext(), 0, ArgAttrs));
896     copyFlags(*CI, NewCI);
897     return Dst;
898   }
899 
900   if (N > SrcLen + 1) {
901     if (N > 128)
902       // Bail if N is large or unknown.
903       return nullptr;
904 
905     // st{p,r}ncpy(D, "a", N) -> memcpy(D, "a\0\0\0", N) for N <= 128.
906     StringRef Str;
907     if (!getConstantStringInfo(Src, Str))
908       return nullptr;
909     std::string SrcStr = Str.str();
910     // Create a bigger, nul-padded array with the same length, SrcLen,
911     // as the original string.
912     SrcStr.resize(N, '\0');
913     Src = B.CreateGlobalString(SrcStr, "str");
914   }
915 
916   Type *PT = Callee->getFunctionType()->getParamType(0);
917   // st{p,r}ncpy(D, S, N) -> memcpy(align 1 D, align 1 S, N) when both
918   // S and N are constant.
919   CallInst *NewCI = B.CreateMemCpy(Dst, Align(1), Src, Align(1),
920                                    ConstantInt::get(DL.getIntPtrType(PT), N));
921   mergeAttributesAndFlags(NewCI, *CI);
922   if (!RetEnd)
923     return Dst;
924 
925   // stpncpy(D, S, N) returns the address of the first null in D if it writes
926   // one, otherwise D + N.
927   Value *Off = B.getInt64(std::min(SrcLen, N));
928   return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, Off, "endptr");
929 }
930 
optimizeStringLength(CallInst * CI,IRBuilderBase & B,unsigned CharSize,Value * Bound)931 Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilderBase &B,
932                                                unsigned CharSize,
933                                                Value *Bound) {
934   Value *Src = CI->getArgOperand(0);
935   Type *CharTy = B.getIntNTy(CharSize);
936 
937   if (isOnlyUsedInZeroEqualityComparison(CI) &&
938       (!Bound || isKnownNonZero(Bound, DL))) {
939     // Fold strlen:
940     //   strlen(x) != 0 --> *x != 0
941     //   strlen(x) == 0 --> *x == 0
942     // and likewise strnlen with constant N > 0:
943     //   strnlen(x, N) != 0 --> *x != 0
944     //   strnlen(x, N) == 0 --> *x == 0
945     return B.CreateZExt(B.CreateLoad(CharTy, Src, "char0"),
946                         CI->getType());
947   }
948 
949   if (Bound) {
950     if (ConstantInt *BoundCst = dyn_cast<ConstantInt>(Bound)) {
951       if (BoundCst->isZero())
952         // Fold strnlen(s, 0) -> 0 for any s, constant or otherwise.
953         return ConstantInt::get(CI->getType(), 0);
954 
955       if (BoundCst->isOne()) {
956         // Fold strnlen(s, 1) -> *s ? 1 : 0 for any s.
957         Value *CharVal = B.CreateLoad(CharTy, Src, "strnlen.char0");
958         Value *ZeroChar = ConstantInt::get(CharTy, 0);
959         Value *Cmp = B.CreateICmpNE(CharVal, ZeroChar, "strnlen.char0cmp");
960         return B.CreateZExt(Cmp, CI->getType());
961       }
962     }
963   }
964 
965   if (uint64_t Len = GetStringLength(Src, CharSize)) {
966     Value *LenC = ConstantInt::get(CI->getType(), Len - 1);
967     // Fold strlen("xyz") -> 3 and strnlen("xyz", 2) -> 2
968     // and strnlen("xyz", Bound) -> min(3, Bound) for nonconstant Bound.
969     if (Bound)
970       return B.CreateBinaryIntrinsic(Intrinsic::umin, LenC, Bound);
971     return LenC;
972   }
973 
974   if (Bound)
975     // Punt for strnlen for now.
976     return nullptr;
977 
978   // If s is a constant pointer pointing to a string literal, we can fold
979   // strlen(s + x) to strlen(s) - x, when x is known to be in the range
980   // [0, strlen(s)] or the string has a single null terminator '\0' at the end.
981   // We only try to simplify strlen when the pointer s points to an array
982   // of CharSize elements. Otherwise, we would need to scale the offset x before
983   // doing the subtraction. This will make the optimization more complex, and
984   // it's not very useful because calling strlen for a pointer of other types is
985   // very uncommon.
986   if (GEPOperator *GEP = dyn_cast<GEPOperator>(Src)) {
987     // TODO: Handle subobjects.
988     if (!isGEPBasedOnPointerToString(GEP, CharSize))
989       return nullptr;
990 
991     ConstantDataArraySlice Slice;
992     if (getConstantDataArrayInfo(GEP->getOperand(0), Slice, CharSize)) {
993       uint64_t NullTermIdx;
994       if (Slice.Array == nullptr) {
995         NullTermIdx = 0;
996       } else {
997         NullTermIdx = ~((uint64_t)0);
998         for (uint64_t I = 0, E = Slice.Length; I < E; ++I) {
999           if (Slice.Array->getElementAsInteger(I + Slice.Offset) == 0) {
1000             NullTermIdx = I;
1001             break;
1002           }
1003         }
1004         // If the string does not have '\0', leave it to strlen to compute
1005         // its length.
1006         if (NullTermIdx == ~((uint64_t)0))
1007           return nullptr;
1008       }
1009 
1010       Value *Offset = GEP->getOperand(2);
1011       KnownBits Known = computeKnownBits(Offset, DL, 0, nullptr, CI, nullptr);
1012       uint64_t ArrSize =
1013              cast<ArrayType>(GEP->getSourceElementType())->getNumElements();
1014 
1015       // If Offset is not provably in the range [0, NullTermIdx], we can still
1016       // optimize if we can prove that the program has undefined behavior when
1017       // Offset is outside that range. That is the case when GEP->getOperand(0)
1018       // is a pointer to an object whose memory extent is NullTermIdx+1.
1019       if ((Known.isNonNegative() && Known.getMaxValue().ule(NullTermIdx)) ||
1020           (isa<GlobalVariable>(GEP->getOperand(0)) &&
1021            NullTermIdx == ArrSize - 1)) {
1022         Offset = B.CreateSExtOrTrunc(Offset, CI->getType());
1023         return B.CreateSub(ConstantInt::get(CI->getType(), NullTermIdx),
1024                            Offset);
1025       }
1026     }
1027   }
1028 
1029   // strlen(x?"foo":"bars") --> x ? 3 : 4
1030   if (SelectInst *SI = dyn_cast<SelectInst>(Src)) {
1031     uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize);
1032     uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize);
1033     if (LenTrue && LenFalse) {
1034       ORE.emit([&]() {
1035         return OptimizationRemark("instcombine", "simplify-libcalls", CI)
1036                << "folded strlen(select) to select of constants";
1037       });
1038       return B.CreateSelect(SI->getCondition(),
1039                             ConstantInt::get(CI->getType(), LenTrue - 1),
1040                             ConstantInt::get(CI->getType(), LenFalse - 1));
1041     }
1042   }
1043 
1044   return nullptr;
1045 }
1046 
optimizeStrLen(CallInst * CI,IRBuilderBase & B)1047 Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilderBase &B) {
1048   if (Value *V = optimizeStringLength(CI, B, 8))
1049     return V;
1050   annotateNonNullNoUndefBasedOnAccess(CI, 0);
1051   return nullptr;
1052 }
1053 
optimizeStrNLen(CallInst * CI,IRBuilderBase & B)1054 Value *LibCallSimplifier::optimizeStrNLen(CallInst *CI, IRBuilderBase &B) {
1055   Value *Bound = CI->getArgOperand(1);
1056   if (Value *V = optimizeStringLength(CI, B, 8, Bound))
1057     return V;
1058 
1059   if (isKnownNonZero(Bound, DL))
1060     annotateNonNullNoUndefBasedOnAccess(CI, 0);
1061   return nullptr;
1062 }
1063 
optimizeWcslen(CallInst * CI,IRBuilderBase & B)1064 Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilderBase &B) {
1065   Module &M = *CI->getModule();
1066   unsigned WCharSize = TLI->getWCharSize(M) * 8;
1067   // We cannot perform this optimization without wchar_size metadata.
1068   if (WCharSize == 0)
1069     return nullptr;
1070 
1071   return optimizeStringLength(CI, B, WCharSize);
1072 }
1073 
optimizeStrPBrk(CallInst * CI,IRBuilderBase & B)1074 Value *LibCallSimplifier::optimizeStrPBrk(CallInst *CI, IRBuilderBase &B) {
1075   StringRef S1, S2;
1076   bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1);
1077   bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2);
1078 
1079   // strpbrk(s, "") -> nullptr
1080   // strpbrk("", s) -> nullptr
1081   if ((HasS1 && S1.empty()) || (HasS2 && S2.empty()))
1082     return Constant::getNullValue(CI->getType());
1083 
1084   // Constant folding.
1085   if (HasS1 && HasS2) {
1086     size_t I = S1.find_first_of(S2);
1087     if (I == StringRef::npos) // No match.
1088       return Constant::getNullValue(CI->getType());
1089 
1090     return B.CreateInBoundsGEP(B.getInt8Ty(), CI->getArgOperand(0),
1091                                B.getInt64(I), "strpbrk");
1092   }
1093 
1094   // strpbrk(s, "a") -> strchr(s, 'a')
1095   if (HasS2 && S2.size() == 1)
1096     return copyFlags(*CI, emitStrChr(CI->getArgOperand(0), S2[0], B, TLI));
1097 
1098   return nullptr;
1099 }
1100 
optimizeStrTo(CallInst * CI,IRBuilderBase & B)1101 Value *LibCallSimplifier::optimizeStrTo(CallInst *CI, IRBuilderBase &B) {
1102   Value *EndPtr = CI->getArgOperand(1);
1103   if (isa<ConstantPointerNull>(EndPtr)) {
1104     // With a null EndPtr, this function won't capture the main argument.
1105     // It would be readonly too, except that it still may write to errno.
1106     CI->addParamAttr(0, Attribute::NoCapture);
1107   }
1108 
1109   return nullptr;
1110 }
1111 
optimizeStrSpn(CallInst * CI,IRBuilderBase & B)1112 Value *LibCallSimplifier::optimizeStrSpn(CallInst *CI, IRBuilderBase &B) {
1113   StringRef S1, S2;
1114   bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1);
1115   bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2);
1116 
1117   // strspn(s, "") -> 0
1118   // strspn("", s) -> 0
1119   if ((HasS1 && S1.empty()) || (HasS2 && S2.empty()))
1120     return Constant::getNullValue(CI->getType());
1121 
1122   // Constant folding.
1123   if (HasS1 && HasS2) {
1124     size_t Pos = S1.find_first_not_of(S2);
1125     if (Pos == StringRef::npos)
1126       Pos = S1.size();
1127     return ConstantInt::get(CI->getType(), Pos);
1128   }
1129 
1130   return nullptr;
1131 }
1132 
optimizeStrCSpn(CallInst * CI,IRBuilderBase & B)1133 Value *LibCallSimplifier::optimizeStrCSpn(CallInst *CI, IRBuilderBase &B) {
1134   StringRef S1, S2;
1135   bool HasS1 = getConstantStringInfo(CI->getArgOperand(0), S1);
1136   bool HasS2 = getConstantStringInfo(CI->getArgOperand(1), S2);
1137 
1138   // strcspn("", s) -> 0
1139   if (HasS1 && S1.empty())
1140     return Constant::getNullValue(CI->getType());
1141 
1142   // Constant folding.
1143   if (HasS1 && HasS2) {
1144     size_t Pos = S1.find_first_of(S2);
1145     if (Pos == StringRef::npos)
1146       Pos = S1.size();
1147     return ConstantInt::get(CI->getType(), Pos);
1148   }
1149 
1150   // strcspn(s, "") -> strlen(s)
1151   if (HasS2 && S2.empty())
1152     return copyFlags(*CI, emitStrLen(CI->getArgOperand(0), B, DL, TLI));
1153 
1154   return nullptr;
1155 }
1156 
optimizeStrStr(CallInst * CI,IRBuilderBase & B)1157 Value *LibCallSimplifier::optimizeStrStr(CallInst *CI, IRBuilderBase &B) {
1158   // fold strstr(x, x) -> x.
1159   if (CI->getArgOperand(0) == CI->getArgOperand(1))
1160     return CI->getArgOperand(0);
1161 
1162   // fold strstr(a, b) == a -> strncmp(a, b, strlen(b)) == 0
1163   if (isOnlyUsedInEqualityComparison(CI, CI->getArgOperand(0))) {
1164     Value *StrLen = emitStrLen(CI->getArgOperand(1), B, DL, TLI);
1165     if (!StrLen)
1166       return nullptr;
1167     Value *StrNCmp = emitStrNCmp(CI->getArgOperand(0), CI->getArgOperand(1),
1168                                  StrLen, B, DL, TLI);
1169     if (!StrNCmp)
1170       return nullptr;
1171     for (User *U : llvm::make_early_inc_range(CI->users())) {
1172       ICmpInst *Old = cast<ICmpInst>(U);
1173       Value *Cmp =
1174           B.CreateICmp(Old->getPredicate(), StrNCmp,
1175                        ConstantInt::getNullValue(StrNCmp->getType()), "cmp");
1176       replaceAllUsesWith(Old, Cmp);
1177     }
1178     return CI;
1179   }
1180 
1181   // See if either input string is a constant string.
1182   StringRef SearchStr, ToFindStr;
1183   bool HasStr1 = getConstantStringInfo(CI->getArgOperand(0), SearchStr);
1184   bool HasStr2 = getConstantStringInfo(CI->getArgOperand(1), ToFindStr);
1185 
1186   // fold strstr(x, "") -> x.
1187   if (HasStr2 && ToFindStr.empty())
1188     return CI->getArgOperand(0);
1189 
1190   // If both strings are known, constant fold it.
1191   if (HasStr1 && HasStr2) {
1192     size_t Offset = SearchStr.find(ToFindStr);
1193 
1194     if (Offset == StringRef::npos) // strstr("foo", "bar") -> null
1195       return Constant::getNullValue(CI->getType());
1196 
1197     // strstr("abcd", "bc") -> gep((char*)"abcd", 1)
1198     return B.CreateConstInBoundsGEP1_64(B.getInt8Ty(), CI->getArgOperand(0),
1199                                         Offset, "strstr");
1200   }
1201 
1202   // fold strstr(x, "y") -> strchr(x, 'y').
1203   if (HasStr2 && ToFindStr.size() == 1) {
1204     return emitStrChr(CI->getArgOperand(0), ToFindStr[0], B, TLI);
1205   }
1206 
1207   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
1208   return nullptr;
1209 }
1210 
optimizeMemRChr(CallInst * CI,IRBuilderBase & B)1211 Value *LibCallSimplifier::optimizeMemRChr(CallInst *CI, IRBuilderBase &B) {
1212   Value *SrcStr = CI->getArgOperand(0);
1213   Value *Size = CI->getArgOperand(2);
1214   annotateNonNullAndDereferenceable(CI, 0, Size, DL);
1215   Value *CharVal = CI->getArgOperand(1);
1216   ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
1217   Value *NullPtr = Constant::getNullValue(CI->getType());
1218 
1219   if (LenC) {
1220     if (LenC->isZero())
1221       // Fold memrchr(x, y, 0) --> null.
1222       return NullPtr;
1223 
1224     if (LenC->isOne()) {
1225       // Fold memrchr(x, y, 1) --> *x == y ? x : null for any x and y,
1226       // constant or otherwise.
1227       Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memrchr.char0");
1228       // Slice off the character's high end bits.
1229       CharVal = B.CreateTrunc(CharVal, B.getInt8Ty());
1230       Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memrchr.char0cmp");
1231       return B.CreateSelect(Cmp, SrcStr, NullPtr, "memrchr.sel");
1232     }
1233   }
1234 
1235   StringRef Str;
1236   if (!getConstantStringInfo(SrcStr, Str, /*TrimAtNul=*/false))
1237     return nullptr;
1238 
1239   if (Str.size() == 0)
1240     // If the array is empty fold memrchr(A, C, N) to null for any value
1241     // of C and N on the basis that the only valid value of N is zero
1242     // (otherwise the call is undefined).
1243     return NullPtr;
1244 
1245   uint64_t EndOff = UINT64_MAX;
1246   if (LenC) {
1247     EndOff = LenC->getZExtValue();
1248     if (Str.size() < EndOff)
1249       // Punt out-of-bounds accesses to sanitizers and/or libc.
1250       return nullptr;
1251   }
1252 
1253   if (ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal)) {
1254     // Fold memrchr(S, C, N) for a constant C.
1255     size_t Pos = Str.rfind(CharC->getZExtValue(), EndOff);
1256     if (Pos == StringRef::npos)
1257       // When the character is not in the source array fold the result
1258       // to null regardless of Size.
1259       return NullPtr;
1260 
1261     if (LenC)
1262       // Fold memrchr(s, c, N) --> s + Pos for constant N > Pos.
1263       return B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos));
1264 
1265     if (Str.find(Str[Pos]) == Pos) {
1266       // When there is just a single occurrence of C in S, i.e., the one
1267       // in Str[Pos], fold
1268       //   memrchr(s, c, N) --> N <= Pos ? null : s + Pos
1269       // for nonconstant N.
1270       Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos),
1271                                    "memrchr.cmp");
1272       Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr,
1273                                            B.getInt64(Pos), "memrchr.ptr_plus");
1274       return B.CreateSelect(Cmp, NullPtr, SrcPlus, "memrchr.sel");
1275     }
1276   }
1277 
1278   // Truncate the string to search at most EndOff characters.
1279   Str = Str.substr(0, EndOff);
1280   if (Str.find_first_not_of(Str[0]) != StringRef::npos)
1281     return nullptr;
1282 
1283   // If the source array consists of all equal characters, then for any
1284   // C and N (whether in bounds or not), fold memrchr(S, C, N) to
1285   //   N != 0 && *S == C ? S + N - 1 : null
1286   Type *SizeTy = Size->getType();
1287   Type *Int8Ty = B.getInt8Ty();
1288   Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0));
1289   // Slice off the sought character's high end bits.
1290   CharVal = B.CreateTrunc(CharVal, Int8Ty);
1291   Value *CEqS0 = B.CreateICmpEQ(ConstantInt::get(Int8Ty, Str[0]), CharVal);
1292   Value *And = B.CreateLogicalAnd(NNeZ, CEqS0);
1293   Value *SizeM1 = B.CreateSub(Size, ConstantInt::get(SizeTy, 1));
1294   Value *SrcPlus =
1295       B.CreateInBoundsGEP(Int8Ty, SrcStr, SizeM1, "memrchr.ptr_plus");
1296   return B.CreateSelect(And, SrcPlus, NullPtr, "memrchr.sel");
1297 }
1298 
optimizeMemChr(CallInst * CI,IRBuilderBase & B)1299 Value *LibCallSimplifier::optimizeMemChr(CallInst *CI, IRBuilderBase &B) {
1300   Value *SrcStr = CI->getArgOperand(0);
1301   Value *Size = CI->getArgOperand(2);
1302 
1303   if (isKnownNonZero(Size, DL)) {
1304     annotateNonNullNoUndefBasedOnAccess(CI, 0);
1305     if (isOnlyUsedInEqualityComparison(CI, SrcStr))
1306       return memChrToCharCompare(CI, Size, B, DL);
1307   }
1308 
1309   Value *CharVal = CI->getArgOperand(1);
1310   ConstantInt *CharC = dyn_cast<ConstantInt>(CharVal);
1311   ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
1312   Value *NullPtr = Constant::getNullValue(CI->getType());
1313 
1314   // memchr(x, y, 0) -> null
1315   if (LenC) {
1316     if (LenC->isZero())
1317       return NullPtr;
1318 
1319     if (LenC->isOne()) {
1320       // Fold memchr(x, y, 1) --> *x == y ? x : null for any x and y,
1321       // constant or otherwise.
1322       Value *Val = B.CreateLoad(B.getInt8Ty(), SrcStr, "memchr.char0");
1323       // Slice off the character's high end bits.
1324       CharVal = B.CreateTrunc(CharVal, B.getInt8Ty());
1325       Value *Cmp = B.CreateICmpEQ(Val, CharVal, "memchr.char0cmp");
1326       return B.CreateSelect(Cmp, SrcStr, NullPtr, "memchr.sel");
1327     }
1328   }
1329 
1330   StringRef Str;
1331   if (!getConstantStringInfo(SrcStr, Str, /*TrimAtNul=*/false))
1332     return nullptr;
1333 
1334   if (CharC) {
1335     size_t Pos = Str.find(CharC->getZExtValue());
1336     if (Pos == StringRef::npos)
1337       // When the character is not in the source array fold the result
1338       // to null regardless of Size.
1339       return NullPtr;
1340 
1341     // Fold memchr(s, c, n) -> n <= Pos ? null : s + Pos
1342     // When the constant Size is less than or equal to the character
1343     // position also fold the result to null.
1344     Value *Cmp = B.CreateICmpULE(Size, ConstantInt::get(Size->getType(), Pos),
1345                                  "memchr.cmp");
1346     Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, B.getInt64(Pos),
1347                                          "memchr.ptr");
1348     return B.CreateSelect(Cmp, NullPtr, SrcPlus);
1349   }
1350 
1351   if (Str.size() == 0)
1352     // If the array is empty fold memchr(A, C, N) to null for any value
1353     // of C and N on the basis that the only valid value of N is zero
1354     // (otherwise the call is undefined).
1355     return NullPtr;
1356 
1357   if (LenC)
1358     Str = substr(Str, LenC->getZExtValue());
1359 
1360   size_t Pos = Str.find_first_not_of(Str[0]);
1361   if (Pos == StringRef::npos
1362       || Str.find_first_not_of(Str[Pos], Pos) == StringRef::npos) {
1363     // If the source array consists of at most two consecutive sequences
1364     // of the same characters, then for any C and N (whether in bounds or
1365     // not), fold memchr(S, C, N) to
1366     //   N != 0 && *S == C ? S : null
1367     // or for the two sequences to:
1368     //   N != 0 && *S == C ? S : (N > Pos && S[Pos] == C ? S + Pos : null)
1369     //   ^Sel2                   ^Sel1 are denoted above.
1370     // The latter makes it also possible to fold strchr() calls with strings
1371     // of the same characters.
1372     Type *SizeTy = Size->getType();
1373     Type *Int8Ty = B.getInt8Ty();
1374 
1375     // Slice off the sought character's high end bits.
1376     CharVal = B.CreateTrunc(CharVal, Int8Ty);
1377 
1378     Value *Sel1 = NullPtr;
1379     if (Pos != StringRef::npos) {
1380       // Handle two consecutive sequences of the same characters.
1381       Value *PosVal = ConstantInt::get(SizeTy, Pos);
1382       Value *StrPos = ConstantInt::get(Int8Ty, Str[Pos]);
1383       Value *CEqSPos = B.CreateICmpEQ(CharVal, StrPos);
1384       Value *NGtPos = B.CreateICmp(ICmpInst::ICMP_UGT, Size, PosVal);
1385       Value *And = B.CreateAnd(CEqSPos, NGtPos);
1386       Value *SrcPlus = B.CreateInBoundsGEP(B.getInt8Ty(), SrcStr, PosVal);
1387       Sel1 = B.CreateSelect(And, SrcPlus, NullPtr, "memchr.sel1");
1388     }
1389 
1390     Value *Str0 = ConstantInt::get(Int8Ty, Str[0]);
1391     Value *CEqS0 = B.CreateICmpEQ(Str0, CharVal);
1392     Value *NNeZ = B.CreateICmpNE(Size, ConstantInt::get(SizeTy, 0));
1393     Value *And = B.CreateAnd(NNeZ, CEqS0);
1394     return B.CreateSelect(And, SrcStr, Sel1, "memchr.sel2");
1395   }
1396 
1397   if (!LenC) {
1398     if (isOnlyUsedInEqualityComparison(CI, SrcStr))
1399       // S is dereferenceable so it's safe to load from it and fold
1400       //   memchr(S, C, N) == S to N && *S == C for any C and N.
1401       // TODO: This is safe even for nonconstant S.
1402       return memChrToCharCompare(CI, Size, B, DL);
1403 
1404     // From now on we need a constant length and constant array.
1405     return nullptr;
1406   }
1407 
1408   bool OptForSize = CI->getFunction()->hasOptSize() ||
1409                     llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI,
1410                                                 PGSOQueryType::IRPass);
1411 
1412   // If the char is variable but the input str and length are not we can turn
1413   // this memchr call into a simple bit field test. Of course this only works
1414   // when the return value is only checked against null.
1415   //
1416   // It would be really nice to reuse switch lowering here but we can't change
1417   // the CFG at this point.
1418   //
1419   // memchr("\r\n", C, 2) != nullptr -> (1 << C & ((1 << '\r') | (1 << '\n')))
1420   // != 0
1421   //   after bounds check.
1422   if (OptForSize || Str.empty() || !isOnlyUsedInZeroEqualityComparison(CI))
1423     return nullptr;
1424 
1425   unsigned char Max =
1426       *std::max_element(reinterpret_cast<const unsigned char *>(Str.begin()),
1427                         reinterpret_cast<const unsigned char *>(Str.end()));
1428 
1429   // Make sure the bit field we're about to create fits in a register on the
1430   // target.
1431   // FIXME: On a 64 bit architecture this prevents us from using the
1432   // interesting range of alpha ascii chars. We could do better by emitting
1433   // two bitfields or shifting the range by 64 if no lower chars are used.
1434   if (!DL.fitsInLegalInteger(Max + 1)) {
1435     // Build chain of ORs
1436     // Transform:
1437     //    memchr("abcd", C, 4) != nullptr
1438     // to:
1439     //    (C == 'a' || C == 'b' || C == 'c' || C == 'd') != 0
1440     std::string SortedStr = Str.str();
1441     llvm::sort(SortedStr);
1442     // Compute the number of of non-contiguous ranges.
1443     unsigned NonContRanges = 1;
1444     for (size_t i = 1; i < SortedStr.size(); ++i) {
1445       if (SortedStr[i] > SortedStr[i - 1] + 1) {
1446         NonContRanges++;
1447       }
1448     }
1449 
1450     // Restrict this optimization to profitable cases with one or two range
1451     // checks.
1452     if (NonContRanges > 2)
1453       return nullptr;
1454 
1455     SmallVector<Value *> CharCompares;
1456     for (unsigned char C : SortedStr)
1457       CharCompares.push_back(
1458           B.CreateICmpEQ(CharVal, ConstantInt::get(CharVal->getType(), C)));
1459 
1460     return B.CreateIntToPtr(B.CreateOr(CharCompares), CI->getType());
1461   }
1462 
1463   // For the bit field use a power-of-2 type with at least 8 bits to avoid
1464   // creating unnecessary illegal types.
1465   unsigned char Width = NextPowerOf2(std::max((unsigned char)7, Max));
1466 
1467   // Now build the bit field.
1468   APInt Bitfield(Width, 0);
1469   for (char C : Str)
1470     Bitfield.setBit((unsigned char)C);
1471   Value *BitfieldC = B.getInt(Bitfield);
1472 
1473   // Adjust width of "C" to the bitfield width, then mask off the high bits.
1474   Value *C = B.CreateZExtOrTrunc(CharVal, BitfieldC->getType());
1475   C = B.CreateAnd(C, B.getIntN(Width, 0xFF));
1476 
1477   // First check that the bit field access is within bounds.
1478   Value *Bounds = B.CreateICmp(ICmpInst::ICMP_ULT, C, B.getIntN(Width, Width),
1479                                "memchr.bounds");
1480 
1481   // Create code that checks if the given bit is set in the field.
1482   Value *Shl = B.CreateShl(B.getIntN(Width, 1ULL), C);
1483   Value *Bits = B.CreateIsNotNull(B.CreateAnd(Shl, BitfieldC), "memchr.bits");
1484 
1485   // Finally merge both checks and cast to pointer type. The inttoptr
1486   // implicitly zexts the i1 to intptr type.
1487   return B.CreateIntToPtr(B.CreateLogicalAnd(Bounds, Bits, "memchr"),
1488                           CI->getType());
1489 }
1490 
1491 // Optimize a memcmp or, when StrNCmp is true, strncmp call CI with constant
1492 // arrays LHS and RHS and nonconstant Size.
optimizeMemCmpVarSize(CallInst * CI,Value * LHS,Value * RHS,Value * Size,bool StrNCmp,IRBuilderBase & B,const DataLayout & DL)1493 static Value *optimizeMemCmpVarSize(CallInst *CI, Value *LHS, Value *RHS,
1494                                     Value *Size, bool StrNCmp,
1495                                     IRBuilderBase &B, const DataLayout &DL) {
1496   if (LHS == RHS) // memcmp(s,s,x) -> 0
1497     return Constant::getNullValue(CI->getType());
1498 
1499   StringRef LStr, RStr;
1500   if (!getConstantStringInfo(LHS, LStr, /*TrimAtNul=*/false) ||
1501       !getConstantStringInfo(RHS, RStr, /*TrimAtNul=*/false))
1502     return nullptr;
1503 
1504   // If the contents of both constant arrays are known, fold a call to
1505   // memcmp(A, B, N) to
1506   //   N <= Pos ? 0 : (A < B ? -1 : B < A ? +1 : 0)
1507   // where Pos is the first mismatch between A and B, determined below.
1508 
1509   uint64_t Pos = 0;
1510   Value *Zero = ConstantInt::get(CI->getType(), 0);
1511   for (uint64_t MinSize = std::min(LStr.size(), RStr.size()); ; ++Pos) {
1512     if (Pos == MinSize ||
1513         (StrNCmp && (LStr[Pos] == '\0' && RStr[Pos] == '\0'))) {
1514       // One array is a leading part of the other of equal or greater
1515       // size, or for strncmp, the arrays are equal strings.
1516       // Fold the result to zero.  Size is assumed to be in bounds, since
1517       // otherwise the call would be undefined.
1518       return Zero;
1519     }
1520 
1521     if (LStr[Pos] != RStr[Pos])
1522       break;
1523   }
1524 
1525   // Normalize the result.
1526   typedef unsigned char UChar;
1527   int IRes = UChar(LStr[Pos]) < UChar(RStr[Pos]) ? -1 : 1;
1528   Value *MaxSize = ConstantInt::get(Size->getType(), Pos);
1529   Value *Cmp = B.CreateICmp(ICmpInst::ICMP_ULE, Size, MaxSize);
1530   Value *Res = ConstantInt::get(CI->getType(), IRes);
1531   return B.CreateSelect(Cmp, Zero, Res);
1532 }
1533 
1534 // Optimize a memcmp call CI with constant size Len.
optimizeMemCmpConstantSize(CallInst * CI,Value * LHS,Value * RHS,uint64_t Len,IRBuilderBase & B,const DataLayout & DL)1535 static Value *optimizeMemCmpConstantSize(CallInst *CI, Value *LHS, Value *RHS,
1536                                          uint64_t Len, IRBuilderBase &B,
1537                                          const DataLayout &DL) {
1538   if (Len == 0) // memcmp(s1,s2,0) -> 0
1539     return Constant::getNullValue(CI->getType());
1540 
1541   // memcmp(S1,S2,1) -> *(unsigned char*)LHS - *(unsigned char*)RHS
1542   if (Len == 1) {
1543     Value *LHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), LHS, "lhsc"),
1544                                CI->getType(), "lhsv");
1545     Value *RHSV = B.CreateZExt(B.CreateLoad(B.getInt8Ty(), RHS, "rhsc"),
1546                                CI->getType(), "rhsv");
1547     return B.CreateSub(LHSV, RHSV, "chardiff");
1548   }
1549 
1550   // memcmp(S1,S2,N/8)==0 -> (*(intN_t*)S1 != *(intN_t*)S2)==0
1551   // TODO: The case where both inputs are constants does not need to be limited
1552   // to legal integers or equality comparison. See block below this.
1553   if (DL.isLegalInteger(Len * 8) && isOnlyUsedInZeroEqualityComparison(CI)) {
1554     IntegerType *IntType = IntegerType::get(CI->getContext(), Len * 8);
1555     Align PrefAlignment = DL.getPrefTypeAlign(IntType);
1556 
1557     // First, see if we can fold either argument to a constant.
1558     Value *LHSV = nullptr;
1559     if (auto *LHSC = dyn_cast<Constant>(LHS))
1560       LHSV = ConstantFoldLoadFromConstPtr(LHSC, IntType, DL);
1561 
1562     Value *RHSV = nullptr;
1563     if (auto *RHSC = dyn_cast<Constant>(RHS))
1564       RHSV = ConstantFoldLoadFromConstPtr(RHSC, IntType, DL);
1565 
1566     // Don't generate unaligned loads. If either source is constant data,
1567     // alignment doesn't matter for that source because there is no load.
1568     if ((LHSV || getKnownAlignment(LHS, DL, CI) >= PrefAlignment) &&
1569         (RHSV || getKnownAlignment(RHS, DL, CI) >= PrefAlignment)) {
1570       if (!LHSV)
1571         LHSV = B.CreateLoad(IntType, LHS, "lhsv");
1572       if (!RHSV)
1573         RHSV = B.CreateLoad(IntType, RHS, "rhsv");
1574       return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp");
1575     }
1576   }
1577 
1578   return nullptr;
1579 }
1580 
1581 // Most simplifications for memcmp also apply to bcmp.
optimizeMemCmpBCmpCommon(CallInst * CI,IRBuilderBase & B)1582 Value *LibCallSimplifier::optimizeMemCmpBCmpCommon(CallInst *CI,
1583                                                    IRBuilderBase &B) {
1584   Value *LHS = CI->getArgOperand(0), *RHS = CI->getArgOperand(1);
1585   Value *Size = CI->getArgOperand(2);
1586 
1587   annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
1588 
1589   if (Value *Res = optimizeMemCmpVarSize(CI, LHS, RHS, Size, false, B, DL))
1590     return Res;
1591 
1592   // Handle constant Size.
1593   ConstantInt *LenC = dyn_cast<ConstantInt>(Size);
1594   if (!LenC)
1595     return nullptr;
1596 
1597   return optimizeMemCmpConstantSize(CI, LHS, RHS, LenC->getZExtValue(), B, DL);
1598 }
1599 
optimizeMemCmp(CallInst * CI,IRBuilderBase & B)1600 Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilderBase &B) {
1601   Module *M = CI->getModule();
1602   if (Value *V = optimizeMemCmpBCmpCommon(CI, B))
1603     return V;
1604 
1605   // memcmp(x, y, Len) == 0 -> bcmp(x, y, Len) == 0
1606   // bcmp can be more efficient than memcmp because it only has to know that
1607   // there is a difference, not how different one is to the other.
1608   if (isLibFuncEmittable(M, TLI, LibFunc_bcmp) &&
1609       isOnlyUsedInZeroEqualityComparison(CI)) {
1610     Value *LHS = CI->getArgOperand(0);
1611     Value *RHS = CI->getArgOperand(1);
1612     Value *Size = CI->getArgOperand(2);
1613     return copyFlags(*CI, emitBCmp(LHS, RHS, Size, B, DL, TLI));
1614   }
1615 
1616   return nullptr;
1617 }
1618 
optimizeBCmp(CallInst * CI,IRBuilderBase & B)1619 Value *LibCallSimplifier::optimizeBCmp(CallInst *CI, IRBuilderBase &B) {
1620   return optimizeMemCmpBCmpCommon(CI, B);
1621 }
1622 
optimizeMemCpy(CallInst * CI,IRBuilderBase & B)1623 Value *LibCallSimplifier::optimizeMemCpy(CallInst *CI, IRBuilderBase &B) {
1624   Value *Size = CI->getArgOperand(2);
1625   annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
1626   if (isa<IntrinsicInst>(CI))
1627     return nullptr;
1628 
1629   // memcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n)
1630   CallInst *NewCI = B.CreateMemCpy(CI->getArgOperand(0), Align(1),
1631                                    CI->getArgOperand(1), Align(1), Size);
1632   mergeAttributesAndFlags(NewCI, *CI);
1633   return CI->getArgOperand(0);
1634 }
1635 
optimizeMemCCpy(CallInst * CI,IRBuilderBase & B)1636 Value *LibCallSimplifier::optimizeMemCCpy(CallInst *CI, IRBuilderBase &B) {
1637   Value *Dst = CI->getArgOperand(0);
1638   Value *Src = CI->getArgOperand(1);
1639   ConstantInt *StopChar = dyn_cast<ConstantInt>(CI->getArgOperand(2));
1640   ConstantInt *N = dyn_cast<ConstantInt>(CI->getArgOperand(3));
1641   StringRef SrcStr;
1642   if (CI->use_empty() && Dst == Src)
1643     return Dst;
1644   // memccpy(d, s, c, 0) -> nullptr
1645   if (N) {
1646     if (N->isNullValue())
1647       return Constant::getNullValue(CI->getType());
1648     if (!getConstantStringInfo(Src, SrcStr, /*TrimAtNul=*/false) ||
1649         // TODO: Handle zeroinitializer.
1650         !StopChar)
1651       return nullptr;
1652   } else {
1653     return nullptr;
1654   }
1655 
1656   // Wrap arg 'c' of type int to char
1657   size_t Pos = SrcStr.find(StopChar->getSExtValue() & 0xFF);
1658   if (Pos == StringRef::npos) {
1659     if (N->getZExtValue() <= SrcStr.size()) {
1660       copyFlags(*CI, B.CreateMemCpy(Dst, Align(1), Src, Align(1),
1661                                     CI->getArgOperand(3)));
1662       return Constant::getNullValue(CI->getType());
1663     }
1664     return nullptr;
1665   }
1666 
1667   Value *NewN =
1668       ConstantInt::get(N->getType(), std::min(uint64_t(Pos + 1), N->getZExtValue()));
1669   // memccpy -> llvm.memcpy
1670   copyFlags(*CI, B.CreateMemCpy(Dst, Align(1), Src, Align(1), NewN));
1671   return Pos + 1 <= N->getZExtValue()
1672              ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, NewN)
1673              : Constant::getNullValue(CI->getType());
1674 }
1675 
optimizeMemPCpy(CallInst * CI,IRBuilderBase & B)1676 Value *LibCallSimplifier::optimizeMemPCpy(CallInst *CI, IRBuilderBase &B) {
1677   Value *Dst = CI->getArgOperand(0);
1678   Value *N = CI->getArgOperand(2);
1679   // mempcpy(x, y, n) -> llvm.memcpy(align 1 x, align 1 y, n), x + n
1680   CallInst *NewCI =
1681       B.CreateMemCpy(Dst, Align(1), CI->getArgOperand(1), Align(1), N);
1682   // Propagate attributes, but memcpy has no return value, so make sure that
1683   // any return attributes are compliant.
1684   // TODO: Attach return value attributes to the 1st operand to preserve them?
1685   mergeAttributesAndFlags(NewCI, *CI);
1686   return B.CreateInBoundsGEP(B.getInt8Ty(), Dst, N);
1687 }
1688 
optimizeMemMove(CallInst * CI,IRBuilderBase & B)1689 Value *LibCallSimplifier::optimizeMemMove(CallInst *CI, IRBuilderBase &B) {
1690   Value *Size = CI->getArgOperand(2);
1691   annotateNonNullAndDereferenceable(CI, {0, 1}, Size, DL);
1692   if (isa<IntrinsicInst>(CI))
1693     return nullptr;
1694 
1695   // memmove(x, y, n) -> llvm.memmove(align 1 x, align 1 y, n)
1696   CallInst *NewCI = B.CreateMemMove(CI->getArgOperand(0), Align(1),
1697                                     CI->getArgOperand(1), Align(1), Size);
1698   mergeAttributesAndFlags(NewCI, *CI);
1699   return CI->getArgOperand(0);
1700 }
1701 
optimizeMemSet(CallInst * CI,IRBuilderBase & B)1702 Value *LibCallSimplifier::optimizeMemSet(CallInst *CI, IRBuilderBase &B) {
1703   Value *Size = CI->getArgOperand(2);
1704   annotateNonNullAndDereferenceable(CI, 0, Size, DL);
1705   if (isa<IntrinsicInst>(CI))
1706     return nullptr;
1707 
1708   // memset(p, v, n) -> llvm.memset(align 1 p, v, n)
1709   Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
1710   CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val, Size, Align(1));
1711   mergeAttributesAndFlags(NewCI, *CI);
1712   return CI->getArgOperand(0);
1713 }
1714 
optimizeRealloc(CallInst * CI,IRBuilderBase & B)1715 Value *LibCallSimplifier::optimizeRealloc(CallInst *CI, IRBuilderBase &B) {
1716   if (isa<ConstantPointerNull>(CI->getArgOperand(0)))
1717     return copyFlags(*CI, emitMalloc(CI->getArgOperand(1), B, DL, TLI));
1718 
1719   return nullptr;
1720 }
1721 
1722 // When enabled, replace operator new() calls marked with a hot or cold memprof
1723 // attribute with an operator new() call that takes a __hot_cold_t parameter.
1724 // Currently this is supported by the open source version of tcmalloc, see:
1725 // https://github.com/google/tcmalloc/blob/master/tcmalloc/new_extension.h
optimizeNew(CallInst * CI,IRBuilderBase & B,LibFunc & Func)1726 Value *LibCallSimplifier::optimizeNew(CallInst *CI, IRBuilderBase &B,
1727                                       LibFunc &Func) {
1728   if (!OptimizeHotColdNew)
1729     return nullptr;
1730 
1731   uint8_t HotCold;
1732   if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "cold")
1733     HotCold = ColdNewHintValue;
1734   else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() ==
1735            "notcold")
1736     HotCold = NotColdNewHintValue;
1737   else if (CI->getAttributes().getFnAttr("memprof").getValueAsString() == "hot")
1738     HotCold = HotNewHintValue;
1739   else
1740     return nullptr;
1741 
1742   // For calls that already pass a hot/cold hint, only update the hint if
1743   // directed by OptimizeExistingHotColdNew. For other calls to new, add a hint
1744   // if cold or hot, and leave as-is for default handling if "notcold" aka warm.
1745   // Note that in cases where we decide it is "notcold", it might be slightly
1746   // better to replace the hinted call with a non hinted call, to avoid the
1747   // extra paramter and the if condition check of the hint value in the
1748   // allocator. This can be considered in the future.
1749   switch (Func) {
1750   case LibFunc_Znwm12__hot_cold_t:
1751     if (OptimizeExistingHotColdNew)
1752       return emitHotColdNew(CI->getArgOperand(0), B, TLI,
1753                             LibFunc_Znwm12__hot_cold_t, HotCold);
1754     break;
1755   case LibFunc_Znwm:
1756     if (HotCold != NotColdNewHintValue)
1757       return emitHotColdNew(CI->getArgOperand(0), B, TLI,
1758                             LibFunc_Znwm12__hot_cold_t, HotCold);
1759     break;
1760   case LibFunc_Znam12__hot_cold_t:
1761     if (OptimizeExistingHotColdNew)
1762       return emitHotColdNew(CI->getArgOperand(0), B, TLI,
1763                             LibFunc_Znam12__hot_cold_t, HotCold);
1764     break;
1765   case LibFunc_Znam:
1766     if (HotCold != NotColdNewHintValue)
1767       return emitHotColdNew(CI->getArgOperand(0), B, TLI,
1768                             LibFunc_Znam12__hot_cold_t, HotCold);
1769     break;
1770   case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t:
1771     if (OptimizeExistingHotColdNew)
1772       return emitHotColdNewNoThrow(
1773           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1774           LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold);
1775     break;
1776   case LibFunc_ZnwmRKSt9nothrow_t:
1777     if (HotCold != NotColdNewHintValue)
1778       return emitHotColdNewNoThrow(
1779           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1780           LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t, HotCold);
1781     break;
1782   case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t:
1783     if (OptimizeExistingHotColdNew)
1784       return emitHotColdNewNoThrow(
1785           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1786           LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold);
1787     break;
1788   case LibFunc_ZnamRKSt9nothrow_t:
1789     if (HotCold != NotColdNewHintValue)
1790       return emitHotColdNewNoThrow(
1791           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1792           LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t, HotCold);
1793     break;
1794   case LibFunc_ZnwmSt11align_val_t12__hot_cold_t:
1795     if (OptimizeExistingHotColdNew)
1796       return emitHotColdNewAligned(
1797           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1798           LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold);
1799     break;
1800   case LibFunc_ZnwmSt11align_val_t:
1801     if (HotCold != NotColdNewHintValue)
1802       return emitHotColdNewAligned(
1803           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1804           LibFunc_ZnwmSt11align_val_t12__hot_cold_t, HotCold);
1805     break;
1806   case LibFunc_ZnamSt11align_val_t12__hot_cold_t:
1807     if (OptimizeExistingHotColdNew)
1808       return emitHotColdNewAligned(
1809           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1810           LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold);
1811     break;
1812   case LibFunc_ZnamSt11align_val_t:
1813     if (HotCold != NotColdNewHintValue)
1814       return emitHotColdNewAligned(
1815           CI->getArgOperand(0), CI->getArgOperand(1), B, TLI,
1816           LibFunc_ZnamSt11align_val_t12__hot_cold_t, HotCold);
1817     break;
1818   case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
1819     if (OptimizeExistingHotColdNew)
1820       return emitHotColdNewAlignedNoThrow(
1821           CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
1822           TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t,
1823           HotCold);
1824     break;
1825   case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
1826     if (HotCold != NotColdNewHintValue)
1827       return emitHotColdNewAlignedNoThrow(
1828           CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
1829           TLI, LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t,
1830           HotCold);
1831     break;
1832   case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
1833     if (OptimizeExistingHotColdNew)
1834       return emitHotColdNewAlignedNoThrow(
1835           CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
1836           TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t,
1837           HotCold);
1838     break;
1839   case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
1840     if (HotCold != NotColdNewHintValue)
1841       return emitHotColdNewAlignedNoThrow(
1842           CI->getArgOperand(0), CI->getArgOperand(1), CI->getArgOperand(2), B,
1843           TLI, LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t,
1844           HotCold);
1845     break;
1846   default:
1847     return nullptr;
1848   }
1849   return nullptr;
1850 }
1851 
1852 //===----------------------------------------------------------------------===//
1853 // Math Library Optimizations
1854 //===----------------------------------------------------------------------===//
1855 
1856 // Replace a libcall \p CI with a call to intrinsic \p IID
replaceUnaryCall(CallInst * CI,IRBuilderBase & B,Intrinsic::ID IID)1857 static Value *replaceUnaryCall(CallInst *CI, IRBuilderBase &B,
1858                                Intrinsic::ID IID) {
1859   CallInst *NewCall = B.CreateUnaryIntrinsic(IID, CI->getArgOperand(0), CI);
1860   NewCall->takeName(CI);
1861   return copyFlags(*CI, NewCall);
1862 }
1863 
1864 /// Return a variant of Val with float type.
1865 /// Currently this works in two cases: If Val is an FPExtension of a float
1866 /// value to something bigger, simply return the operand.
1867 /// If Val is a ConstantFP but can be converted to a float ConstantFP without
1868 /// loss of precision do so.
valueHasFloatPrecision(Value * Val)1869 static Value *valueHasFloatPrecision(Value *Val) {
1870   if (FPExtInst *Cast = dyn_cast<FPExtInst>(Val)) {
1871     Value *Op = Cast->getOperand(0);
1872     if (Op->getType()->isFloatTy())
1873       return Op;
1874   }
1875   if (ConstantFP *Const = dyn_cast<ConstantFP>(Val)) {
1876     APFloat F = Const->getValueAPF();
1877     bool losesInfo;
1878     (void)F.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven,
1879                     &losesInfo);
1880     if (!losesInfo)
1881       return ConstantFP::get(Const->getContext(), F);
1882   }
1883   return nullptr;
1884 }
1885 
1886 /// Shrink double -> float functions.
optimizeDoubleFP(CallInst * CI,IRBuilderBase & B,bool isBinary,const TargetLibraryInfo * TLI,bool isPrecise=false)1887 static Value *optimizeDoubleFP(CallInst *CI, IRBuilderBase &B,
1888                                bool isBinary, const TargetLibraryInfo *TLI,
1889                                bool isPrecise = false) {
1890   Function *CalleeFn = CI->getCalledFunction();
1891   if (!CI->getType()->isDoubleTy() || !CalleeFn)
1892     return nullptr;
1893 
1894   // If not all the uses of the function are converted to float, then bail out.
1895   // This matters if the precision of the result is more important than the
1896   // precision of the arguments.
1897   if (isPrecise)
1898     for (User *U : CI->users()) {
1899       FPTruncInst *Cast = dyn_cast<FPTruncInst>(U);
1900       if (!Cast || !Cast->getType()->isFloatTy())
1901         return nullptr;
1902     }
1903 
1904   // If this is something like 'g((double) float)', convert to 'gf(float)'.
1905   Value *V[2];
1906   V[0] = valueHasFloatPrecision(CI->getArgOperand(0));
1907   V[1] = isBinary ? valueHasFloatPrecision(CI->getArgOperand(1)) : nullptr;
1908   if (!V[0] || (isBinary && !V[1]))
1909     return nullptr;
1910 
1911   // If call isn't an intrinsic, check that it isn't within a function with the
1912   // same name as the float version of this call, otherwise the result is an
1913   // infinite loop.  For example, from MinGW-w64:
1914   //
1915   // float expf(float val) { return (float) exp((double) val); }
1916   StringRef CalleeName = CalleeFn->getName();
1917   bool IsIntrinsic = CalleeFn->isIntrinsic();
1918   if (!IsIntrinsic) {
1919     StringRef CallerName = CI->getFunction()->getName();
1920     if (!CallerName.empty() && CallerName.back() == 'f' &&
1921         CallerName.size() == (CalleeName.size() + 1) &&
1922         CallerName.starts_with(CalleeName))
1923       return nullptr;
1924   }
1925 
1926   // Propagate the math semantics from the current function to the new function.
1927   IRBuilderBase::FastMathFlagGuard Guard(B);
1928   B.setFastMathFlags(CI->getFastMathFlags());
1929 
1930   // g((double) float) -> (double) gf(float)
1931   Value *R;
1932   if (IsIntrinsic) {
1933     Module *M = CI->getModule();
1934     Intrinsic::ID IID = CalleeFn->getIntrinsicID();
1935     Function *Fn = Intrinsic::getDeclaration(M, IID, B.getFloatTy());
1936     R = isBinary ? B.CreateCall(Fn, V) : B.CreateCall(Fn, V[0]);
1937   } else {
1938     AttributeList CalleeAttrs = CalleeFn->getAttributes();
1939     R = isBinary ? emitBinaryFloatFnCall(V[0], V[1], TLI, CalleeName, B,
1940                                          CalleeAttrs)
1941                  : emitUnaryFloatFnCall(V[0], TLI, CalleeName, B, CalleeAttrs);
1942   }
1943   return B.CreateFPExt(R, B.getDoubleTy());
1944 }
1945 
1946 /// Shrink double -> float for unary functions.
optimizeUnaryDoubleFP(CallInst * CI,IRBuilderBase & B,const TargetLibraryInfo * TLI,bool isPrecise=false)1947 static Value *optimizeUnaryDoubleFP(CallInst *CI, IRBuilderBase &B,
1948                                     const TargetLibraryInfo *TLI,
1949                                     bool isPrecise = false) {
1950   return optimizeDoubleFP(CI, B, false, TLI, isPrecise);
1951 }
1952 
1953 /// Shrink double -> float for binary functions.
optimizeBinaryDoubleFP(CallInst * CI,IRBuilderBase & B,const TargetLibraryInfo * TLI,bool isPrecise=false)1954 static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilderBase &B,
1955                                      const TargetLibraryInfo *TLI,
1956                                      bool isPrecise = false) {
1957   return optimizeDoubleFP(CI, B, true, TLI, isPrecise);
1958 }
1959 
1960 // cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z)))
optimizeCAbs(CallInst * CI,IRBuilderBase & B)1961 Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilderBase &B) {
1962   Value *Real, *Imag;
1963 
1964   if (CI->arg_size() == 1) {
1965 
1966     if (!CI->isFast())
1967       return nullptr;
1968 
1969     Value *Op = CI->getArgOperand(0);
1970     assert(Op->getType()->isArrayTy() && "Unexpected signature for cabs!");
1971 
1972     Real = B.CreateExtractValue(Op, 0, "real");
1973     Imag = B.CreateExtractValue(Op, 1, "imag");
1974 
1975   } else {
1976     assert(CI->arg_size() == 2 && "Unexpected signature for cabs!");
1977 
1978     Real = CI->getArgOperand(0);
1979     Imag = CI->getArgOperand(1);
1980 
1981     // if real or imaginary part is zero, simplify to abs(cimag(z))
1982     // or abs(creal(z))
1983     Value *AbsOp = nullptr;
1984     if (ConstantFP *ConstReal = dyn_cast<ConstantFP>(Real)) {
1985       if (ConstReal->isZero())
1986         AbsOp = Imag;
1987 
1988     } else if (ConstantFP *ConstImag = dyn_cast<ConstantFP>(Imag)) {
1989       if (ConstImag->isZero())
1990         AbsOp = Real;
1991     }
1992 
1993     if (AbsOp) {
1994       IRBuilderBase::FastMathFlagGuard Guard(B);
1995       B.setFastMathFlags(CI->getFastMathFlags());
1996 
1997       return copyFlags(
1998           *CI, B.CreateUnaryIntrinsic(Intrinsic::fabs, AbsOp, nullptr, "cabs"));
1999     }
2000 
2001     if (!CI->isFast())
2002       return nullptr;
2003   }
2004 
2005   // Propagate fast-math flags from the existing call to new instructions.
2006   IRBuilderBase::FastMathFlagGuard Guard(B);
2007   B.setFastMathFlags(CI->getFastMathFlags());
2008 
2009   Value *RealReal = B.CreateFMul(Real, Real);
2010   Value *ImagImag = B.CreateFMul(Imag, Imag);
2011 
2012   return copyFlags(*CI, B.CreateUnaryIntrinsic(Intrinsic::sqrt,
2013                                                B.CreateFAdd(RealReal, ImagImag),
2014                                                nullptr, "cabs"));
2015 }
2016 
2017 // Return a properly extended integer (DstWidth bits wide) if the operation is
2018 // an itofp.
getIntToFPVal(Value * I2F,IRBuilderBase & B,unsigned DstWidth)2019 static Value *getIntToFPVal(Value *I2F, IRBuilderBase &B, unsigned DstWidth) {
2020   if (isa<SIToFPInst>(I2F) || isa<UIToFPInst>(I2F)) {
2021     Value *Op = cast<Instruction>(I2F)->getOperand(0);
2022     // Make sure that the exponent fits inside an "int" of size DstWidth,
2023     // thus avoiding any range issues that FP has not.
2024     unsigned BitWidth = Op->getType()->getScalarSizeInBits();
2025     if (BitWidth < DstWidth || (BitWidth == DstWidth && isa<SIToFPInst>(I2F))) {
2026       Type *IntTy = Op->getType()->getWithNewBitWidth(DstWidth);
2027       return isa<SIToFPInst>(I2F) ? B.CreateSExt(Op, IntTy)
2028                                   : B.CreateZExt(Op, IntTy);
2029     }
2030   }
2031 
2032   return nullptr;
2033 }
2034 
2035 /// Use exp{,2}(x * y) for pow(exp{,2}(x), y);
2036 /// ldexp(1.0, x) for pow(2.0, itofp(x)); exp2(n * x) for pow(2.0 ** n, x);
2037 /// exp10(x) for pow(10.0, x); exp2(log2(n) * x) for pow(n, x).
replacePowWithExp(CallInst * Pow,IRBuilderBase & B)2038 Value *LibCallSimplifier::replacePowWithExp(CallInst *Pow, IRBuilderBase &B) {
2039   Module *M = Pow->getModule();
2040   Value *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
2041   Type *Ty = Pow->getType();
2042   bool Ignored;
2043 
2044   // Evaluate special cases related to a nested function as the base.
2045 
2046   // pow(exp(x), y) -> exp(x * y)
2047   // pow(exp2(x), y) -> exp2(x * y)
2048   // If exp{,2}() is used only once, it is better to fold two transcendental
2049   // math functions into one.  If used again, exp{,2}() would still have to be
2050   // called with the original argument, then keep both original transcendental
2051   // functions.  However, this transformation is only safe with fully relaxed
2052   // math semantics, since, besides rounding differences, it changes overflow
2053   // and underflow behavior quite dramatically.  For example:
2054   //   pow(exp(1000), 0.001) = pow(inf, 0.001) = inf
2055   // Whereas:
2056   //   exp(1000 * 0.001) = exp(1)
2057   // TODO: Loosen the requirement for fully relaxed math semantics.
2058   // TODO: Handle exp10() when more targets have it available.
2059   CallInst *BaseFn = dyn_cast<CallInst>(Base);
2060   if (BaseFn && BaseFn->hasOneUse() && BaseFn->isFast() && Pow->isFast()) {
2061     LibFunc LibFn;
2062 
2063     Function *CalleeFn = BaseFn->getCalledFunction();
2064     if (CalleeFn && TLI->getLibFunc(CalleeFn->getName(), LibFn) &&
2065         isLibFuncEmittable(M, TLI, LibFn)) {
2066       StringRef ExpName;
2067       Intrinsic::ID ID;
2068       Value *ExpFn;
2069       LibFunc LibFnFloat, LibFnDouble, LibFnLongDouble;
2070 
2071       switch (LibFn) {
2072       default:
2073         return nullptr;
2074       case LibFunc_expf:
2075       case LibFunc_exp:
2076       case LibFunc_expl:
2077         ExpName = TLI->getName(LibFunc_exp);
2078         ID = Intrinsic::exp;
2079         LibFnFloat = LibFunc_expf;
2080         LibFnDouble = LibFunc_exp;
2081         LibFnLongDouble = LibFunc_expl;
2082         break;
2083       case LibFunc_exp2f:
2084       case LibFunc_exp2:
2085       case LibFunc_exp2l:
2086         ExpName = TLI->getName(LibFunc_exp2);
2087         ID = Intrinsic::exp2;
2088         LibFnFloat = LibFunc_exp2f;
2089         LibFnDouble = LibFunc_exp2;
2090         LibFnLongDouble = LibFunc_exp2l;
2091         break;
2092       }
2093 
2094       // Create new exp{,2}() with the product as its argument.
2095       Value *FMul = B.CreateFMul(BaseFn->getArgOperand(0), Expo, "mul");
2096       ExpFn = BaseFn->doesNotAccessMemory()
2097                   ? B.CreateUnaryIntrinsic(ID, FMul, nullptr, ExpName)
2098                   : emitUnaryFloatFnCall(FMul, TLI, LibFnDouble, LibFnFloat,
2099                                          LibFnLongDouble, B,
2100                                          BaseFn->getAttributes());
2101 
2102       // Since the new exp{,2}() is different from the original one, dead code
2103       // elimination cannot be trusted to remove it, since it may have side
2104       // effects (e.g., errno).  When the only consumer for the original
2105       // exp{,2}() is pow(), then it has to be explicitly erased.
2106       substituteInParent(BaseFn, ExpFn);
2107       return ExpFn;
2108     }
2109   }
2110 
2111   // Evaluate special cases related to a constant base.
2112 
2113   const APFloat *BaseF;
2114   if (!match(Base, m_APFloat(BaseF)))
2115     return nullptr;
2116 
2117   AttributeList NoAttrs; // Attributes are only meaningful on the original call
2118 
2119   const bool UseIntrinsic = Pow->doesNotAccessMemory();
2120 
2121   // pow(2.0, itofp(x)) -> ldexp(1.0, x)
2122   if ((UseIntrinsic || !Ty->isVectorTy()) && BaseF->isExactlyValue(2.0) &&
2123       (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo)) &&
2124       (UseIntrinsic ||
2125        hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl))) {
2126 
2127     // TODO: Shouldn't really need to depend on getIntToFPVal for intrinsic. Can
2128     // just directly use the original integer type.
2129     if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize())) {
2130       Constant *One = ConstantFP::get(Ty, 1.0);
2131 
2132       if (UseIntrinsic) {
2133         return copyFlags(*Pow, B.CreateIntrinsic(Intrinsic::ldexp,
2134                                                  {Ty, ExpoI->getType()},
2135                                                  {One, ExpoI}, Pow, "exp2"));
2136       }
2137 
2138       return copyFlags(*Pow, emitBinaryFloatFnCall(
2139                                  One, ExpoI, TLI, LibFunc_ldexp, LibFunc_ldexpf,
2140                                  LibFunc_ldexpl, B, NoAttrs));
2141     }
2142   }
2143 
2144   // pow(2.0 ** n, x) -> exp2(n * x)
2145   if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f, LibFunc_exp2l)) {
2146     APFloat BaseR = APFloat(1.0);
2147     BaseR.convert(BaseF->getSemantics(), APFloat::rmTowardZero, &Ignored);
2148     BaseR = BaseR / *BaseF;
2149     bool IsInteger = BaseF->isInteger(), IsReciprocal = BaseR.isInteger();
2150     const APFloat *NF = IsReciprocal ? &BaseR : BaseF;
2151     APSInt NI(64, false);
2152     if ((IsInteger || IsReciprocal) &&
2153         NF->convertToInteger(NI, APFloat::rmTowardZero, &Ignored) ==
2154             APFloat::opOK &&
2155         NI > 1 && NI.isPowerOf2()) {
2156       double N = NI.logBase2() * (IsReciprocal ? -1.0 : 1.0);
2157       Value *FMul = B.CreateFMul(Expo, ConstantFP::get(Ty, N), "mul");
2158       if (Pow->doesNotAccessMemory())
2159         return copyFlags(*Pow, B.CreateUnaryIntrinsic(Intrinsic::exp2, FMul,
2160                                                       nullptr, "exp2"));
2161       else
2162         return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2,
2163                                                     LibFunc_exp2f,
2164                                                     LibFunc_exp2l, B, NoAttrs));
2165     }
2166   }
2167 
2168   // pow(10.0, x) -> exp10(x)
2169   if (BaseF->isExactlyValue(10.0) &&
2170       hasFloatFn(M, TLI, Ty, LibFunc_exp10, LibFunc_exp10f, LibFunc_exp10l)) {
2171 
2172     if (Pow->doesNotAccessMemory()) {
2173       CallInst *NewExp10 =
2174           B.CreateIntrinsic(Intrinsic::exp10, {Ty}, {Expo}, Pow, "exp10");
2175       return copyFlags(*Pow, NewExp10);
2176     }
2177 
2178     return copyFlags(*Pow, emitUnaryFloatFnCall(Expo, TLI, LibFunc_exp10,
2179                                                 LibFunc_exp10f, LibFunc_exp10l,
2180                                                 B, NoAttrs));
2181   }
2182 
2183   // pow(x, y) -> exp2(log2(x) * y)
2184   if (Pow->hasApproxFunc() && Pow->hasNoNaNs() && BaseF->isFiniteNonZero() &&
2185       !BaseF->isNegative()) {
2186     // pow(1, inf) is defined to be 1 but exp2(log2(1) * inf) evaluates to NaN.
2187     // Luckily optimizePow has already handled the x == 1 case.
2188     assert(!match(Base, m_FPOne()) &&
2189            "pow(1.0, y) should have been simplified earlier!");
2190 
2191     Value *Log = nullptr;
2192     if (Ty->isFloatTy())
2193       Log = ConstantFP::get(Ty, std::log2(BaseF->convertToFloat()));
2194     else if (Ty->isDoubleTy())
2195       Log = ConstantFP::get(Ty, std::log2(BaseF->convertToDouble()));
2196 
2197     if (Log) {
2198       Value *FMul = B.CreateFMul(Log, Expo, "mul");
2199       if (Pow->doesNotAccessMemory())
2200         return copyFlags(*Pow, B.CreateUnaryIntrinsic(Intrinsic::exp2, FMul,
2201                                                       nullptr, "exp2"));
2202       else if (hasFloatFn(M, TLI, Ty, LibFunc_exp2, LibFunc_exp2f,
2203                           LibFunc_exp2l))
2204         return copyFlags(*Pow, emitUnaryFloatFnCall(FMul, TLI, LibFunc_exp2,
2205                                                     LibFunc_exp2f,
2206                                                     LibFunc_exp2l, B, NoAttrs));
2207     }
2208   }
2209 
2210   return nullptr;
2211 }
2212 
getSqrtCall(Value * V,AttributeList Attrs,bool NoErrno,Module * M,IRBuilderBase & B,const TargetLibraryInfo * TLI)2213 static Value *getSqrtCall(Value *V, AttributeList Attrs, bool NoErrno,
2214                           Module *M, IRBuilderBase &B,
2215                           const TargetLibraryInfo *TLI) {
2216   // If errno is never set, then use the intrinsic for sqrt().
2217   if (NoErrno)
2218     return B.CreateUnaryIntrinsic(Intrinsic::sqrt, V, nullptr, "sqrt");
2219 
2220   // Otherwise, use the libcall for sqrt().
2221   if (hasFloatFn(M, TLI, V->getType(), LibFunc_sqrt, LibFunc_sqrtf,
2222                  LibFunc_sqrtl))
2223     // TODO: We also should check that the target can in fact lower the sqrt()
2224     // libcall. We currently have no way to ask this question, so we ask if
2225     // the target has a sqrt() libcall, which is not exactly the same.
2226     return emitUnaryFloatFnCall(V, TLI, LibFunc_sqrt, LibFunc_sqrtf,
2227                                 LibFunc_sqrtl, B, Attrs);
2228 
2229   return nullptr;
2230 }
2231 
2232 /// Use square root in place of pow(x, +/-0.5).
replacePowWithSqrt(CallInst * Pow,IRBuilderBase & B)2233 Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilderBase &B) {
2234   Value *Sqrt, *Base = Pow->getArgOperand(0), *Expo = Pow->getArgOperand(1);
2235   Module *Mod = Pow->getModule();
2236   Type *Ty = Pow->getType();
2237 
2238   const APFloat *ExpoF;
2239   if (!match(Expo, m_APFloat(ExpoF)) ||
2240       (!ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)))
2241     return nullptr;
2242 
2243   // Converting pow(X, -0.5) to 1/sqrt(X) may introduce an extra rounding step,
2244   // so that requires fast-math-flags (afn or reassoc).
2245   if (ExpoF->isNegative() && (!Pow->hasApproxFunc() && !Pow->hasAllowReassoc()))
2246     return nullptr;
2247 
2248   // If we have a pow() library call (accesses memory) and we can't guarantee
2249   // that the base is not an infinity, give up:
2250   // pow(-Inf, 0.5) is optionally required to have a result of +Inf (not setting
2251   // errno), but sqrt(-Inf) is required by various standards to set errno.
2252   if (!Pow->doesNotAccessMemory() && !Pow->hasNoInfs() &&
2253       !isKnownNeverInfinity(Base, 0,
2254                             SimplifyQuery(DL, TLI, /*DT=*/nullptr, AC, Pow)))
2255     return nullptr;
2256 
2257   Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), Mod, B,
2258                      TLI);
2259   if (!Sqrt)
2260     return nullptr;
2261 
2262   // Handle signed zero base by expanding to fabs(sqrt(x)).
2263   if (!Pow->hasNoSignedZeros())
2264     Sqrt = B.CreateUnaryIntrinsic(Intrinsic::fabs, Sqrt, nullptr, "abs");
2265 
2266   Sqrt = copyFlags(*Pow, Sqrt);
2267 
2268   // Handle non finite base by expanding to
2269   // (x == -infinity ? +infinity : sqrt(x)).
2270   if (!Pow->hasNoInfs()) {
2271     Value *PosInf = ConstantFP::getInfinity(Ty),
2272           *NegInf = ConstantFP::getInfinity(Ty, true);
2273     Value *FCmp = B.CreateFCmpOEQ(Base, NegInf, "isinf");
2274     Sqrt = B.CreateSelect(FCmp, PosInf, Sqrt);
2275   }
2276 
2277   // If the exponent is negative, then get the reciprocal.
2278   if (ExpoF->isNegative())
2279     Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt, "reciprocal");
2280 
2281   return Sqrt;
2282 }
2283 
createPowWithIntegerExponent(Value * Base,Value * Expo,Module * M,IRBuilderBase & B)2284 static Value *createPowWithIntegerExponent(Value *Base, Value *Expo, Module *M,
2285                                            IRBuilderBase &B) {
2286   Value *Args[] = {Base, Expo};
2287   Type *Types[] = {Base->getType(), Expo->getType()};
2288   return B.CreateIntrinsic(Intrinsic::powi, Types, Args);
2289 }
2290 
optimizePow(CallInst * Pow,IRBuilderBase & B)2291 Value *LibCallSimplifier::optimizePow(CallInst *Pow, IRBuilderBase &B) {
2292   Value *Base = Pow->getArgOperand(0);
2293   Value *Expo = Pow->getArgOperand(1);
2294   Function *Callee = Pow->getCalledFunction();
2295   StringRef Name = Callee->getName();
2296   Type *Ty = Pow->getType();
2297   Module *M = Pow->getModule();
2298   bool AllowApprox = Pow->hasApproxFunc();
2299   bool Ignored;
2300 
2301   // Propagate the math semantics from the call to any created instructions.
2302   IRBuilderBase::FastMathFlagGuard Guard(B);
2303   B.setFastMathFlags(Pow->getFastMathFlags());
2304   // Evaluate special cases related to the base.
2305 
2306   // pow(1.0, x) -> 1.0
2307   if (match(Base, m_FPOne()))
2308     return Base;
2309 
2310   if (Value *Exp = replacePowWithExp(Pow, B))
2311     return Exp;
2312 
2313   // Evaluate special cases related to the exponent.
2314 
2315   // pow(x, -1.0) -> 1.0 / x
2316   if (match(Expo, m_SpecificFP(-1.0)))
2317     return B.CreateFDiv(ConstantFP::get(Ty, 1.0), Base, "reciprocal");
2318 
2319   // pow(x, +/-0.0) -> 1.0
2320   if (match(Expo, m_AnyZeroFP()))
2321     return ConstantFP::get(Ty, 1.0);
2322 
2323   // pow(x, 1.0) -> x
2324   if (match(Expo, m_FPOne()))
2325     return Base;
2326 
2327   // pow(x, 2.0) -> x * x
2328   if (match(Expo, m_SpecificFP(2.0)))
2329     return B.CreateFMul(Base, Base, "square");
2330 
2331   if (Value *Sqrt = replacePowWithSqrt(Pow, B))
2332     return Sqrt;
2333 
2334   // If we can approximate pow:
2335   // pow(x, n) -> powi(x, n) * sqrt(x) if n has exactly a 0.5 fraction
2336   // pow(x, n) -> powi(x, n) if n is a constant signed integer value
2337   const APFloat *ExpoF;
2338   if (AllowApprox && match(Expo, m_APFloat(ExpoF)) &&
2339       !ExpoF->isExactlyValue(0.5) && !ExpoF->isExactlyValue(-0.5)) {
2340     APFloat ExpoA(abs(*ExpoF));
2341     APFloat ExpoI(*ExpoF);
2342     Value *Sqrt = nullptr;
2343     if (!ExpoA.isInteger()) {
2344       APFloat Expo2 = ExpoA;
2345       // To check if ExpoA is an integer + 0.5, we add it to itself. If there
2346       // is no floating point exception and the result is an integer, then
2347       // ExpoA == integer + 0.5
2348       if (Expo2.add(ExpoA, APFloat::rmNearestTiesToEven) != APFloat::opOK)
2349         return nullptr;
2350 
2351       if (!Expo2.isInteger())
2352         return nullptr;
2353 
2354       if (ExpoI.roundToIntegral(APFloat::rmTowardNegative) !=
2355           APFloat::opInexact)
2356         return nullptr;
2357       if (!ExpoI.isInteger())
2358         return nullptr;
2359       ExpoF = &ExpoI;
2360 
2361       Sqrt = getSqrtCall(Base, AttributeList(), Pow->doesNotAccessMemory(), M,
2362                          B, TLI);
2363       if (!Sqrt)
2364         return nullptr;
2365     }
2366 
2367     // 0.5 fraction is now optionally handled.
2368     // Do pow -> powi for remaining integer exponent
2369     APSInt IntExpo(TLI->getIntSize(), /*isUnsigned=*/false);
2370     if (ExpoF->isInteger() &&
2371         ExpoF->convertToInteger(IntExpo, APFloat::rmTowardZero, &Ignored) ==
2372             APFloat::opOK) {
2373       Value *PowI = copyFlags(
2374           *Pow,
2375           createPowWithIntegerExponent(
2376               Base, ConstantInt::get(B.getIntNTy(TLI->getIntSize()), IntExpo),
2377               M, B));
2378 
2379       if (PowI && Sqrt)
2380         return B.CreateFMul(PowI, Sqrt);
2381 
2382       return PowI;
2383     }
2384   }
2385 
2386   // powf(x, itofp(y)) -> powi(x, y)
2387   if (AllowApprox && (isa<SIToFPInst>(Expo) || isa<UIToFPInst>(Expo))) {
2388     if (Value *ExpoI = getIntToFPVal(Expo, B, TLI->getIntSize()))
2389       return copyFlags(*Pow, createPowWithIntegerExponent(Base, ExpoI, M, B));
2390   }
2391 
2392   // Shrink pow() to powf() if the arguments are single precision,
2393   // unless the result is expected to be double precision.
2394   if (UnsafeFPShrink && Name == TLI->getName(LibFunc_pow) &&
2395       hasFloatVersion(M, Name)) {
2396     if (Value *Shrunk = optimizeBinaryDoubleFP(Pow, B, TLI, true))
2397       return Shrunk;
2398   }
2399 
2400   return nullptr;
2401 }
2402 
optimizeExp2(CallInst * CI,IRBuilderBase & B)2403 Value *LibCallSimplifier::optimizeExp2(CallInst *CI, IRBuilderBase &B) {
2404   Module *M = CI->getModule();
2405   Function *Callee = CI->getCalledFunction();
2406   StringRef Name = Callee->getName();
2407   Value *Ret = nullptr;
2408   if (UnsafeFPShrink && Name == TLI->getName(LibFunc_exp2) &&
2409       hasFloatVersion(M, Name))
2410     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
2411 
2412   // If we have an llvm.exp2 intrinsic, emit the llvm.ldexp intrinsic. If we
2413   // have the libcall, emit the libcall.
2414   //
2415   // TODO: In principle we should be able to just always use the intrinsic for
2416   // any doesNotAccessMemory callsite.
2417 
2418   const bool UseIntrinsic = Callee->isIntrinsic();
2419   // Bail out for vectors because the code below only expects scalars.
2420   Type *Ty = CI->getType();
2421   if (!UseIntrinsic && Ty->isVectorTy())
2422     return Ret;
2423 
2424   // exp2(sitofp(x)) -> ldexp(1.0, sext(x))  if sizeof(x) <= IntSize
2425   // exp2(uitofp(x)) -> ldexp(1.0, zext(x))  if sizeof(x) < IntSize
2426   Value *Op = CI->getArgOperand(0);
2427   if ((isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) &&
2428       (UseIntrinsic ||
2429        hasFloatFn(M, TLI, Ty, LibFunc_ldexp, LibFunc_ldexpf, LibFunc_ldexpl))) {
2430     if (Value *Exp = getIntToFPVal(Op, B, TLI->getIntSize())) {
2431       Constant *One = ConstantFP::get(Ty, 1.0);
2432 
2433       if (UseIntrinsic) {
2434         return copyFlags(*CI, B.CreateIntrinsic(Intrinsic::ldexp,
2435                                                 {Ty, Exp->getType()},
2436                                                 {One, Exp}, CI));
2437       }
2438 
2439       IRBuilderBase::FastMathFlagGuard Guard(B);
2440       B.setFastMathFlags(CI->getFastMathFlags());
2441       return copyFlags(*CI, emitBinaryFloatFnCall(
2442                                 One, Exp, TLI, LibFunc_ldexp, LibFunc_ldexpf,
2443                                 LibFunc_ldexpl, B, AttributeList()));
2444     }
2445   }
2446 
2447   return Ret;
2448 }
2449 
optimizeFMinFMax(CallInst * CI,IRBuilderBase & B)2450 Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilderBase &B) {
2451   Module *M = CI->getModule();
2452 
2453   // If we can shrink the call to a float function rather than a double
2454   // function, do that first.
2455   Function *Callee = CI->getCalledFunction();
2456   StringRef Name = Callee->getName();
2457   if ((Name == "fmin" || Name == "fmax") && hasFloatVersion(M, Name))
2458     if (Value *Ret = optimizeBinaryDoubleFP(CI, B, TLI))
2459       return Ret;
2460 
2461   // The LLVM intrinsics minnum/maxnum correspond to fmin/fmax. Canonicalize to
2462   // the intrinsics for improved optimization (for example, vectorization).
2463   // No-signed-zeros is implied by the definitions of fmax/fmin themselves.
2464   // From the C standard draft WG14/N1256:
2465   // "Ideally, fmax would be sensitive to the sign of zero, for example
2466   // fmax(-0.0, +0.0) would return +0; however, implementation in software
2467   // might be impractical."
2468   IRBuilderBase::FastMathFlagGuard Guard(B);
2469   FastMathFlags FMF = CI->getFastMathFlags();
2470   FMF.setNoSignedZeros();
2471   B.setFastMathFlags(FMF);
2472 
2473   Intrinsic::ID IID = Callee->getName().starts_with("fmin") ? Intrinsic::minnum
2474                                                             : Intrinsic::maxnum;
2475   return copyFlags(*CI, B.CreateBinaryIntrinsic(IID, CI->getArgOperand(0),
2476                                                 CI->getArgOperand(1)));
2477 }
2478 
optimizeLog(CallInst * Log,IRBuilderBase & B)2479 Value *LibCallSimplifier::optimizeLog(CallInst *Log, IRBuilderBase &B) {
2480   Function *LogFn = Log->getCalledFunction();
2481   StringRef LogNm = LogFn->getName();
2482   Intrinsic::ID LogID = LogFn->getIntrinsicID();
2483   Module *Mod = Log->getModule();
2484   Type *Ty = Log->getType();
2485   Value *Ret = nullptr;
2486 
2487   if (UnsafeFPShrink && hasFloatVersion(Mod, LogNm))
2488     Ret = optimizeUnaryDoubleFP(Log, B, TLI, true);
2489 
2490   // The earlier call must also be 'fast' in order to do these transforms.
2491   CallInst *Arg = dyn_cast<CallInst>(Log->getArgOperand(0));
2492   if (!Log->isFast() || !Arg || !Arg->isFast() || !Arg->hasOneUse())
2493     return Ret;
2494 
2495   LibFunc LogLb, ExpLb, Exp2Lb, Exp10Lb, PowLb;
2496 
2497   // This is only applicable to log(), log2(), log10().
2498   if (TLI->getLibFunc(LogNm, LogLb))
2499     switch (LogLb) {
2500     case LibFunc_logf:
2501       LogID = Intrinsic::log;
2502       ExpLb = LibFunc_expf;
2503       Exp2Lb = LibFunc_exp2f;
2504       Exp10Lb = LibFunc_exp10f;
2505       PowLb = LibFunc_powf;
2506       break;
2507     case LibFunc_log:
2508       LogID = Intrinsic::log;
2509       ExpLb = LibFunc_exp;
2510       Exp2Lb = LibFunc_exp2;
2511       Exp10Lb = LibFunc_exp10;
2512       PowLb = LibFunc_pow;
2513       break;
2514     case LibFunc_logl:
2515       LogID = Intrinsic::log;
2516       ExpLb = LibFunc_expl;
2517       Exp2Lb = LibFunc_exp2l;
2518       Exp10Lb = LibFunc_exp10l;
2519       PowLb = LibFunc_powl;
2520       break;
2521     case LibFunc_log2f:
2522       LogID = Intrinsic::log2;
2523       ExpLb = LibFunc_expf;
2524       Exp2Lb = LibFunc_exp2f;
2525       Exp10Lb = LibFunc_exp10f;
2526       PowLb = LibFunc_powf;
2527       break;
2528     case LibFunc_log2:
2529       LogID = Intrinsic::log2;
2530       ExpLb = LibFunc_exp;
2531       Exp2Lb = LibFunc_exp2;
2532       Exp10Lb = LibFunc_exp10;
2533       PowLb = LibFunc_pow;
2534       break;
2535     case LibFunc_log2l:
2536       LogID = Intrinsic::log2;
2537       ExpLb = LibFunc_expl;
2538       Exp2Lb = LibFunc_exp2l;
2539       Exp10Lb = LibFunc_exp10l;
2540       PowLb = LibFunc_powl;
2541       break;
2542     case LibFunc_log10f:
2543       LogID = Intrinsic::log10;
2544       ExpLb = LibFunc_expf;
2545       Exp2Lb = LibFunc_exp2f;
2546       Exp10Lb = LibFunc_exp10f;
2547       PowLb = LibFunc_powf;
2548       break;
2549     case LibFunc_log10:
2550       LogID = Intrinsic::log10;
2551       ExpLb = LibFunc_exp;
2552       Exp2Lb = LibFunc_exp2;
2553       Exp10Lb = LibFunc_exp10;
2554       PowLb = LibFunc_pow;
2555       break;
2556     case LibFunc_log10l:
2557       LogID = Intrinsic::log10;
2558       ExpLb = LibFunc_expl;
2559       Exp2Lb = LibFunc_exp2l;
2560       Exp10Lb = LibFunc_exp10l;
2561       PowLb = LibFunc_powl;
2562       break;
2563     default:
2564       return Ret;
2565     }
2566   else if (LogID == Intrinsic::log || LogID == Intrinsic::log2 ||
2567            LogID == Intrinsic::log10) {
2568     if (Ty->getScalarType()->isFloatTy()) {
2569       ExpLb = LibFunc_expf;
2570       Exp2Lb = LibFunc_exp2f;
2571       Exp10Lb = LibFunc_exp10f;
2572       PowLb = LibFunc_powf;
2573     } else if (Ty->getScalarType()->isDoubleTy()) {
2574       ExpLb = LibFunc_exp;
2575       Exp2Lb = LibFunc_exp2;
2576       Exp10Lb = LibFunc_exp10;
2577       PowLb = LibFunc_pow;
2578     } else
2579       return Ret;
2580   } else
2581     return Ret;
2582 
2583   IRBuilderBase::FastMathFlagGuard Guard(B);
2584   B.setFastMathFlags(FastMathFlags::getFast());
2585 
2586   Intrinsic::ID ArgID = Arg->getIntrinsicID();
2587   LibFunc ArgLb = NotLibFunc;
2588   TLI->getLibFunc(*Arg, ArgLb);
2589 
2590   // log(pow(x,y)) -> y*log(x)
2591   AttributeList NoAttrs;
2592   if (ArgLb == PowLb || ArgID == Intrinsic::pow || ArgID == Intrinsic::powi) {
2593     Value *LogX =
2594         Log->doesNotAccessMemory()
2595             ? B.CreateUnaryIntrinsic(LogID, Arg->getOperand(0), nullptr, "log")
2596             : emitUnaryFloatFnCall(Arg->getOperand(0), TLI, LogNm, B, NoAttrs);
2597     Value *Y = Arg->getArgOperand(1);
2598     // Cast exponent to FP if integer.
2599     if (ArgID == Intrinsic::powi)
2600       Y = B.CreateSIToFP(Y, Ty, "cast");
2601     Value *MulY = B.CreateFMul(Y, LogX, "mul");
2602     // Since pow() may have side effects, e.g. errno,
2603     // dead code elimination may not be trusted to remove it.
2604     substituteInParent(Arg, MulY);
2605     return MulY;
2606   }
2607 
2608   // log(exp{,2,10}(y)) -> y*log({e,2,10})
2609   // TODO: There is no exp10() intrinsic yet.
2610   if (ArgLb == ExpLb || ArgLb == Exp2Lb || ArgLb == Exp10Lb ||
2611            ArgID == Intrinsic::exp || ArgID == Intrinsic::exp2) {
2612     Constant *Eul;
2613     if (ArgLb == ExpLb || ArgID == Intrinsic::exp)
2614       // FIXME: Add more precise value of e for long double.
2615       Eul = ConstantFP::get(Log->getType(), numbers::e);
2616     else if (ArgLb == Exp2Lb || ArgID == Intrinsic::exp2)
2617       Eul = ConstantFP::get(Log->getType(), 2.0);
2618     else
2619       Eul = ConstantFP::get(Log->getType(), 10.0);
2620     Value *LogE = Log->doesNotAccessMemory()
2621                       ? B.CreateUnaryIntrinsic(LogID, Eul, nullptr, "log")
2622                       : emitUnaryFloatFnCall(Eul, TLI, LogNm, B, NoAttrs);
2623     Value *MulY = B.CreateFMul(Arg->getArgOperand(0), LogE, "mul");
2624     // Since exp() may have side effects, e.g. errno,
2625     // dead code elimination may not be trusted to remove it.
2626     substituteInParent(Arg, MulY);
2627     return MulY;
2628   }
2629 
2630   return Ret;
2631 }
2632 
2633 // sqrt(exp(X)) -> exp(X * 0.5)
mergeSqrtToExp(CallInst * CI,IRBuilderBase & B)2634 Value *LibCallSimplifier::mergeSqrtToExp(CallInst *CI, IRBuilderBase &B) {
2635   if (!CI->hasAllowReassoc())
2636     return nullptr;
2637 
2638   Function *SqrtFn = CI->getCalledFunction();
2639   CallInst *Arg = dyn_cast<CallInst>(CI->getArgOperand(0));
2640   if (!Arg || !Arg->hasAllowReassoc() || !Arg->hasOneUse())
2641     return nullptr;
2642   Intrinsic::ID ArgID = Arg->getIntrinsicID();
2643   LibFunc ArgLb = NotLibFunc;
2644   TLI->getLibFunc(*Arg, ArgLb);
2645 
2646   LibFunc SqrtLb, ExpLb, Exp2Lb, Exp10Lb;
2647 
2648   if (TLI->getLibFunc(SqrtFn->getName(), SqrtLb))
2649     switch (SqrtLb) {
2650     case LibFunc_sqrtf:
2651       ExpLb = LibFunc_expf;
2652       Exp2Lb = LibFunc_exp2f;
2653       Exp10Lb = LibFunc_exp10f;
2654       break;
2655     case LibFunc_sqrt:
2656       ExpLb = LibFunc_exp;
2657       Exp2Lb = LibFunc_exp2;
2658       Exp10Lb = LibFunc_exp10;
2659       break;
2660     case LibFunc_sqrtl:
2661       ExpLb = LibFunc_expl;
2662       Exp2Lb = LibFunc_exp2l;
2663       Exp10Lb = LibFunc_exp10l;
2664       break;
2665     default:
2666       return nullptr;
2667     }
2668   else if (SqrtFn->getIntrinsicID() == Intrinsic::sqrt) {
2669     if (CI->getType()->getScalarType()->isFloatTy()) {
2670       ExpLb = LibFunc_expf;
2671       Exp2Lb = LibFunc_exp2f;
2672       Exp10Lb = LibFunc_exp10f;
2673     } else if (CI->getType()->getScalarType()->isDoubleTy()) {
2674       ExpLb = LibFunc_exp;
2675       Exp2Lb = LibFunc_exp2;
2676       Exp10Lb = LibFunc_exp10;
2677     } else
2678       return nullptr;
2679   } else
2680     return nullptr;
2681 
2682   if (ArgLb != ExpLb && ArgLb != Exp2Lb && ArgLb != Exp10Lb &&
2683       ArgID != Intrinsic::exp && ArgID != Intrinsic::exp2)
2684     return nullptr;
2685 
2686   IRBuilderBase::InsertPointGuard Guard(B);
2687   B.SetInsertPoint(Arg);
2688   auto *ExpOperand = Arg->getOperand(0);
2689   auto *FMul =
2690       B.CreateFMulFMF(ExpOperand, ConstantFP::get(ExpOperand->getType(), 0.5),
2691                       CI, "merged.sqrt");
2692 
2693   Arg->setOperand(0, FMul);
2694   return Arg;
2695 }
2696 
optimizeSqrt(CallInst * CI,IRBuilderBase & B)2697 Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilderBase &B) {
2698   Module *M = CI->getModule();
2699   Function *Callee = CI->getCalledFunction();
2700   Value *Ret = nullptr;
2701   // TODO: Once we have a way (other than checking for the existince of the
2702   // libcall) to tell whether our target can lower @llvm.sqrt, relax the
2703   // condition below.
2704   if (isLibFuncEmittable(M, TLI, LibFunc_sqrtf) &&
2705       (Callee->getName() == "sqrt" ||
2706        Callee->getIntrinsicID() == Intrinsic::sqrt))
2707     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
2708 
2709   if (Value *Opt = mergeSqrtToExp(CI, B))
2710     return Opt;
2711 
2712   if (!CI->isFast())
2713     return Ret;
2714 
2715   Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0));
2716   if (!I || I->getOpcode() != Instruction::FMul || !I->isFast())
2717     return Ret;
2718 
2719   // We're looking for a repeated factor in a multiplication tree,
2720   // so we can do this fold: sqrt(x * x) -> fabs(x);
2721   // or this fold: sqrt((x * x) * y) -> fabs(x) * sqrt(y).
2722   Value *Op0 = I->getOperand(0);
2723   Value *Op1 = I->getOperand(1);
2724   Value *RepeatOp = nullptr;
2725   Value *OtherOp = nullptr;
2726   if (Op0 == Op1) {
2727     // Simple match: the operands of the multiply are identical.
2728     RepeatOp = Op0;
2729   } else {
2730     // Look for a more complicated pattern: one of the operands is itself
2731     // a multiply, so search for a common factor in that multiply.
2732     // Note: We don't bother looking any deeper than this first level or for
2733     // variations of this pattern because instcombine's visitFMUL and/or the
2734     // reassociation pass should give us this form.
2735     Value *OtherMul0, *OtherMul1;
2736     if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) {
2737       // Pattern: sqrt((x * y) * z)
2738       if (OtherMul0 == OtherMul1 && cast<Instruction>(Op0)->isFast()) {
2739         // Matched: sqrt((x * x) * z)
2740         RepeatOp = OtherMul0;
2741         OtherOp = Op1;
2742       }
2743     }
2744   }
2745   if (!RepeatOp)
2746     return Ret;
2747 
2748   // Fast math flags for any created instructions should match the sqrt
2749   // and multiply.
2750   IRBuilderBase::FastMathFlagGuard Guard(B);
2751   B.setFastMathFlags(I->getFastMathFlags());
2752 
2753   // If we found a repeated factor, hoist it out of the square root and
2754   // replace it with the fabs of that factor.
2755   Value *FabsCall =
2756       B.CreateUnaryIntrinsic(Intrinsic::fabs, RepeatOp, nullptr, "fabs");
2757   if (OtherOp) {
2758     // If we found a non-repeated factor, we still need to get its square
2759     // root. We then multiply that by the value that was simplified out
2760     // of the square root calculation.
2761     Value *SqrtCall =
2762         B.CreateUnaryIntrinsic(Intrinsic::sqrt, OtherOp, nullptr, "sqrt");
2763     return copyFlags(*CI, B.CreateFMul(FabsCall, SqrtCall));
2764   }
2765   return copyFlags(*CI, FabsCall);
2766 }
2767 
optimizeTrigInversionPairs(CallInst * CI,IRBuilderBase & B)2768 Value *LibCallSimplifier::optimizeTrigInversionPairs(CallInst *CI,
2769                                                      IRBuilderBase &B) {
2770   Module *M = CI->getModule();
2771   Function *Callee = CI->getCalledFunction();
2772   Value *Ret = nullptr;
2773   StringRef Name = Callee->getName();
2774   if (UnsafeFPShrink &&
2775       (Name == "tan" || Name == "atanh" || Name == "sinh" || Name == "cosh" ||
2776        Name == "asinh") &&
2777       hasFloatVersion(M, Name))
2778     Ret = optimizeUnaryDoubleFP(CI, B, TLI, true);
2779 
2780   Value *Op1 = CI->getArgOperand(0);
2781   auto *OpC = dyn_cast<CallInst>(Op1);
2782   if (!OpC)
2783     return Ret;
2784 
2785   // Both calls must be 'fast' in order to remove them.
2786   if (!CI->isFast() || !OpC->isFast())
2787     return Ret;
2788 
2789   // tan(atan(x)) -> x
2790   // atanh(tanh(x)) -> x
2791   // sinh(asinh(x)) -> x
2792   // asinh(sinh(x)) -> x
2793   // cosh(acosh(x)) -> x
2794   LibFunc Func;
2795   Function *F = OpC->getCalledFunction();
2796   if (F && TLI->getLibFunc(F->getName(), Func) &&
2797       isLibFuncEmittable(M, TLI, Func)) {
2798     LibFunc inverseFunc = llvm::StringSwitch<LibFunc>(Callee->getName())
2799                               .Case("tan", LibFunc_atan)
2800                               .Case("atanh", LibFunc_tanh)
2801                               .Case("sinh", LibFunc_asinh)
2802                               .Case("cosh", LibFunc_acosh)
2803                               .Case("tanf", LibFunc_atanf)
2804                               .Case("atanhf", LibFunc_tanhf)
2805                               .Case("sinhf", LibFunc_asinhf)
2806                               .Case("coshf", LibFunc_acoshf)
2807                               .Case("tanl", LibFunc_atanl)
2808                               .Case("atanhl", LibFunc_tanhl)
2809                               .Case("sinhl", LibFunc_asinhl)
2810                               .Case("coshl", LibFunc_acoshl)
2811                               .Case("asinh", LibFunc_sinh)
2812                               .Case("asinhf", LibFunc_sinhf)
2813                               .Case("asinhl", LibFunc_sinhl)
2814                               .Default(NumLibFuncs); // Used as error value
2815     if (Func == inverseFunc)
2816       Ret = OpC->getArgOperand(0);
2817   }
2818   return Ret;
2819 }
2820 
isTrigLibCall(CallInst * CI)2821 static bool isTrigLibCall(CallInst *CI) {
2822   // We can only hope to do anything useful if we can ignore things like errno
2823   // and floating-point exceptions.
2824   // We already checked the prototype.
2825   return CI->doesNotThrow() && CI->doesNotAccessMemory();
2826 }
2827 
insertSinCosCall(IRBuilderBase & B,Function * OrigCallee,Value * Arg,bool UseFloat,Value * & Sin,Value * & Cos,Value * & SinCos,const TargetLibraryInfo * TLI)2828 static bool insertSinCosCall(IRBuilderBase &B, Function *OrigCallee, Value *Arg,
2829                              bool UseFloat, Value *&Sin, Value *&Cos,
2830                              Value *&SinCos, const TargetLibraryInfo *TLI) {
2831   Module *M = OrigCallee->getParent();
2832   Type *ArgTy = Arg->getType();
2833   Type *ResTy;
2834   StringRef Name;
2835 
2836   Triple T(OrigCallee->getParent()->getTargetTriple());
2837   if (UseFloat) {
2838     Name = "__sincospif_stret";
2839 
2840     assert(T.getArch() != Triple::x86 && "x86 messy and unsupported for now");
2841     // x86_64 can't use {float, float} since that would be returned in both
2842     // xmm0 and xmm1, which isn't what a real struct would do.
2843     ResTy = T.getArch() == Triple::x86_64
2844                 ? static_cast<Type *>(FixedVectorType::get(ArgTy, 2))
2845                 : static_cast<Type *>(StructType::get(ArgTy, ArgTy));
2846   } else {
2847     Name = "__sincospi_stret";
2848     ResTy = StructType::get(ArgTy, ArgTy);
2849   }
2850 
2851   if (!isLibFuncEmittable(M, TLI, Name))
2852     return false;
2853   LibFunc TheLibFunc;
2854   TLI->getLibFunc(Name, TheLibFunc);
2855   FunctionCallee Callee = getOrInsertLibFunc(
2856       M, *TLI, TheLibFunc, OrigCallee->getAttributes(), ResTy, ArgTy);
2857 
2858   if (Instruction *ArgInst = dyn_cast<Instruction>(Arg)) {
2859     // If the argument is an instruction, it must dominate all uses so put our
2860     // sincos call there.
2861     B.SetInsertPoint(ArgInst->getParent(), ++ArgInst->getIterator());
2862   } else {
2863     // Otherwise (e.g. for a constant) the beginning of the function is as
2864     // good a place as any.
2865     BasicBlock &EntryBB = B.GetInsertBlock()->getParent()->getEntryBlock();
2866     B.SetInsertPoint(&EntryBB, EntryBB.begin());
2867   }
2868 
2869   SinCos = B.CreateCall(Callee, Arg, "sincospi");
2870 
2871   if (SinCos->getType()->isStructTy()) {
2872     Sin = B.CreateExtractValue(SinCos, 0, "sinpi");
2873     Cos = B.CreateExtractValue(SinCos, 1, "cospi");
2874   } else {
2875     Sin = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 0),
2876                                  "sinpi");
2877     Cos = B.CreateExtractElement(SinCos, ConstantInt::get(B.getInt32Ty(), 1),
2878                                  "cospi");
2879   }
2880 
2881   return true;
2882 }
2883 
optimizeSymmetricCall(CallInst * CI,bool IsEven,IRBuilderBase & B)2884 static Value *optimizeSymmetricCall(CallInst *CI, bool IsEven,
2885                                     IRBuilderBase &B) {
2886   Value *X;
2887   Value *Src = CI->getArgOperand(0);
2888 
2889   if (match(Src, m_OneUse(m_FNeg(m_Value(X))))) {
2890     IRBuilderBase::FastMathFlagGuard Guard(B);
2891     B.setFastMathFlags(CI->getFastMathFlags());
2892 
2893     auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
2894     if (IsEven) {
2895       // Even function: f(-x) = f(x)
2896       return CallInst;
2897     }
2898     // Odd function: f(-x) = -f(x)
2899     return B.CreateFNeg(CallInst);
2900   }
2901 
2902   // Even function: f(abs(x)) = f(x), f(copysign(x, y)) = f(x)
2903   if (IsEven && (match(Src, m_FAbs(m_Value(X))) ||
2904                  match(Src, m_CopySign(m_Value(X), m_Value())))) {
2905     IRBuilderBase::FastMathFlagGuard Guard(B);
2906     B.setFastMathFlags(CI->getFastMathFlags());
2907 
2908     auto *CallInst = copyFlags(*CI, B.CreateCall(CI->getCalledFunction(), {X}));
2909     return CallInst;
2910   }
2911 
2912   return nullptr;
2913 }
2914 
optimizeSymmetric(CallInst * CI,LibFunc Func,IRBuilderBase & B)2915 Value *LibCallSimplifier::optimizeSymmetric(CallInst *CI, LibFunc Func,
2916                                             IRBuilderBase &B) {
2917   switch (Func) {
2918   case LibFunc_cos:
2919   case LibFunc_cosf:
2920   case LibFunc_cosl:
2921     return optimizeSymmetricCall(CI, /*IsEven*/ true, B);
2922 
2923   case LibFunc_sin:
2924   case LibFunc_sinf:
2925   case LibFunc_sinl:
2926 
2927   case LibFunc_tan:
2928   case LibFunc_tanf:
2929   case LibFunc_tanl:
2930 
2931   case LibFunc_erf:
2932   case LibFunc_erff:
2933   case LibFunc_erfl:
2934     return optimizeSymmetricCall(CI, /*IsEven*/ false, B);
2935 
2936   default:
2937     return nullptr;
2938   }
2939 }
2940 
optimizeSinCosPi(CallInst * CI,bool IsSin,IRBuilderBase & B)2941 Value *LibCallSimplifier::optimizeSinCosPi(CallInst *CI, bool IsSin, IRBuilderBase &B) {
2942   // Make sure the prototype is as expected, otherwise the rest of the
2943   // function is probably invalid and likely to abort.
2944   if (!isTrigLibCall(CI))
2945     return nullptr;
2946 
2947   Value *Arg = CI->getArgOperand(0);
2948   SmallVector<CallInst *, 1> SinCalls;
2949   SmallVector<CallInst *, 1> CosCalls;
2950   SmallVector<CallInst *, 1> SinCosCalls;
2951 
2952   bool IsFloat = Arg->getType()->isFloatTy();
2953 
2954   // Look for all compatible sinpi, cospi and sincospi calls with the same
2955   // argument. If there are enough (in some sense) we can make the
2956   // substitution.
2957   Function *F = CI->getFunction();
2958   for (User *U : Arg->users())
2959     classifyArgUse(U, F, IsFloat, SinCalls, CosCalls, SinCosCalls);
2960 
2961   // It's only worthwhile if both sinpi and cospi are actually used.
2962   if (SinCalls.empty() || CosCalls.empty())
2963     return nullptr;
2964 
2965   Value *Sin, *Cos, *SinCos;
2966   if (!insertSinCosCall(B, CI->getCalledFunction(), Arg, IsFloat, Sin, Cos,
2967                         SinCos, TLI))
2968     return nullptr;
2969 
2970   auto replaceTrigInsts = [this](SmallVectorImpl<CallInst *> &Calls,
2971                                  Value *Res) {
2972     for (CallInst *C : Calls)
2973       replaceAllUsesWith(C, Res);
2974   };
2975 
2976   replaceTrigInsts(SinCalls, Sin);
2977   replaceTrigInsts(CosCalls, Cos);
2978   replaceTrigInsts(SinCosCalls, SinCos);
2979 
2980   return IsSin ? Sin : Cos;
2981 }
2982 
classifyArgUse(Value * Val,Function * F,bool IsFloat,SmallVectorImpl<CallInst * > & SinCalls,SmallVectorImpl<CallInst * > & CosCalls,SmallVectorImpl<CallInst * > & SinCosCalls)2983 void LibCallSimplifier::classifyArgUse(
2984     Value *Val, Function *F, bool IsFloat,
2985     SmallVectorImpl<CallInst *> &SinCalls,
2986     SmallVectorImpl<CallInst *> &CosCalls,
2987     SmallVectorImpl<CallInst *> &SinCosCalls) {
2988   auto *CI = dyn_cast<CallInst>(Val);
2989   if (!CI || CI->use_empty())
2990     return;
2991 
2992   // Don't consider calls in other functions.
2993   if (CI->getFunction() != F)
2994     return;
2995 
2996   Module *M = CI->getModule();
2997   Function *Callee = CI->getCalledFunction();
2998   LibFunc Func;
2999   if (!Callee || !TLI->getLibFunc(*Callee, Func) ||
3000       !isLibFuncEmittable(M, TLI, Func) ||
3001       !isTrigLibCall(CI))
3002     return;
3003 
3004   if (IsFloat) {
3005     if (Func == LibFunc_sinpif)
3006       SinCalls.push_back(CI);
3007     else if (Func == LibFunc_cospif)
3008       CosCalls.push_back(CI);
3009     else if (Func == LibFunc_sincospif_stret)
3010       SinCosCalls.push_back(CI);
3011   } else {
3012     if (Func == LibFunc_sinpi)
3013       SinCalls.push_back(CI);
3014     else if (Func == LibFunc_cospi)
3015       CosCalls.push_back(CI);
3016     else if (Func == LibFunc_sincospi_stret)
3017       SinCosCalls.push_back(CI);
3018   }
3019 }
3020 
3021 //===----------------------------------------------------------------------===//
3022 // Integer Library Call Optimizations
3023 //===----------------------------------------------------------------------===//
3024 
optimizeFFS(CallInst * CI,IRBuilderBase & B)3025 Value *LibCallSimplifier::optimizeFFS(CallInst *CI, IRBuilderBase &B) {
3026   // All variants of ffs return int which need not be 32 bits wide.
3027   // ffs{,l,ll}(x) -> x != 0 ? (int)llvm.cttz(x)+1 : 0
3028   Type *RetType = CI->getType();
3029   Value *Op = CI->getArgOperand(0);
3030   Type *ArgType = Op->getType();
3031   Value *V = B.CreateIntrinsic(Intrinsic::cttz, {ArgType}, {Op, B.getTrue()},
3032                                nullptr, "cttz");
3033   V = B.CreateAdd(V, ConstantInt::get(V->getType(), 1));
3034   V = B.CreateIntCast(V, RetType, false);
3035 
3036   Value *Cond = B.CreateICmpNE(Op, Constant::getNullValue(ArgType));
3037   return B.CreateSelect(Cond, V, ConstantInt::get(RetType, 0));
3038 }
3039 
optimizeFls(CallInst * CI,IRBuilderBase & B)3040 Value *LibCallSimplifier::optimizeFls(CallInst *CI, IRBuilderBase &B) {
3041   // All variants of fls return int which need not be 32 bits wide.
3042   // fls{,l,ll}(x) -> (int)(sizeInBits(x) - llvm.ctlz(x, false))
3043   Value *Op = CI->getArgOperand(0);
3044   Type *ArgType = Op->getType();
3045   Value *V = B.CreateIntrinsic(Intrinsic::ctlz, {ArgType}, {Op, B.getFalse()},
3046                                nullptr, "ctlz");
3047   V = B.CreateSub(ConstantInt::get(V->getType(), ArgType->getIntegerBitWidth()),
3048                   V);
3049   return B.CreateIntCast(V, CI->getType(), false);
3050 }
3051 
optimizeAbs(CallInst * CI,IRBuilderBase & B)3052 Value *LibCallSimplifier::optimizeAbs(CallInst *CI, IRBuilderBase &B) {
3053   // abs(x) -> x <s 0 ? -x : x
3054   // The negation has 'nsw' because abs of INT_MIN is undefined.
3055   Value *X = CI->getArgOperand(0);
3056   Value *IsNeg = B.CreateIsNeg(X);
3057   Value *NegX = B.CreateNSWNeg(X, "neg");
3058   return B.CreateSelect(IsNeg, NegX, X);
3059 }
3060 
optimizeIsDigit(CallInst * CI,IRBuilderBase & B)3061 Value *LibCallSimplifier::optimizeIsDigit(CallInst *CI, IRBuilderBase &B) {
3062   // isdigit(c) -> (c-'0') <u 10
3063   Value *Op = CI->getArgOperand(0);
3064   Type *ArgType = Op->getType();
3065   Op = B.CreateSub(Op, ConstantInt::get(ArgType, '0'), "isdigittmp");
3066   Op = B.CreateICmpULT(Op, ConstantInt::get(ArgType, 10), "isdigit");
3067   return B.CreateZExt(Op, CI->getType());
3068 }
3069 
optimizeIsAscii(CallInst * CI,IRBuilderBase & B)3070 Value *LibCallSimplifier::optimizeIsAscii(CallInst *CI, IRBuilderBase &B) {
3071   // isascii(c) -> c <u 128
3072   Value *Op = CI->getArgOperand(0);
3073   Type *ArgType = Op->getType();
3074   Op = B.CreateICmpULT(Op, ConstantInt::get(ArgType, 128), "isascii");
3075   return B.CreateZExt(Op, CI->getType());
3076 }
3077 
optimizeToAscii(CallInst * CI,IRBuilderBase & B)3078 Value *LibCallSimplifier::optimizeToAscii(CallInst *CI, IRBuilderBase &B) {
3079   // toascii(c) -> c & 0x7f
3080   return B.CreateAnd(CI->getArgOperand(0),
3081                      ConstantInt::get(CI->getType(), 0x7F));
3082 }
3083 
3084 // Fold calls to atoi, atol, and atoll.
optimizeAtoi(CallInst * CI,IRBuilderBase & B)3085 Value *LibCallSimplifier::optimizeAtoi(CallInst *CI, IRBuilderBase &B) {
3086   CI->addParamAttr(0, Attribute::NoCapture);
3087 
3088   StringRef Str;
3089   if (!getConstantStringInfo(CI->getArgOperand(0), Str))
3090     return nullptr;
3091 
3092   return convertStrToInt(CI, Str, nullptr, 10, /*AsSigned=*/true, B);
3093 }
3094 
3095 // Fold calls to strtol, strtoll, strtoul, and strtoull.
optimizeStrToInt(CallInst * CI,IRBuilderBase & B,bool AsSigned)3096 Value *LibCallSimplifier::optimizeStrToInt(CallInst *CI, IRBuilderBase &B,
3097                                            bool AsSigned) {
3098   Value *EndPtr = CI->getArgOperand(1);
3099   if (isa<ConstantPointerNull>(EndPtr)) {
3100     // With a null EndPtr, this function won't capture the main argument.
3101     // It would be readonly too, except that it still may write to errno.
3102     CI->addParamAttr(0, Attribute::NoCapture);
3103     EndPtr = nullptr;
3104   } else if (!isKnownNonZero(EndPtr, DL))
3105     return nullptr;
3106 
3107   StringRef Str;
3108   if (!getConstantStringInfo(CI->getArgOperand(0), Str))
3109     return nullptr;
3110 
3111   if (ConstantInt *CInt = dyn_cast<ConstantInt>(CI->getArgOperand(2))) {
3112     return convertStrToInt(CI, Str, EndPtr, CInt->getSExtValue(), AsSigned, B);
3113   }
3114 
3115   return nullptr;
3116 }
3117 
3118 //===----------------------------------------------------------------------===//
3119 // Formatting and IO Library Call Optimizations
3120 //===----------------------------------------------------------------------===//
3121 
3122 static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg);
3123 
optimizeErrorReporting(CallInst * CI,IRBuilderBase & B,int StreamArg)3124 Value *LibCallSimplifier::optimizeErrorReporting(CallInst *CI, IRBuilderBase &B,
3125                                                  int StreamArg) {
3126   Function *Callee = CI->getCalledFunction();
3127   // Error reporting calls should be cold, mark them as such.
3128   // This applies even to non-builtin calls: it is only a hint and applies to
3129   // functions that the frontend might not understand as builtins.
3130 
3131   // This heuristic was suggested in:
3132   // Improving Static Branch Prediction in a Compiler
3133   // Brian L. Deitrich, Ben-Chung Cheng, Wen-mei W. Hwu
3134   // Proceedings of PACT'98, Oct. 1998, IEEE
3135   if (!CI->hasFnAttr(Attribute::Cold) &&
3136       isReportingError(Callee, CI, StreamArg)) {
3137     CI->addFnAttr(Attribute::Cold);
3138   }
3139 
3140   return nullptr;
3141 }
3142 
isReportingError(Function * Callee,CallInst * CI,int StreamArg)3143 static bool isReportingError(Function *Callee, CallInst *CI, int StreamArg) {
3144   if (!Callee || !Callee->isDeclaration())
3145     return false;
3146 
3147   if (StreamArg < 0)
3148     return true;
3149 
3150   // These functions might be considered cold, but only if their stream
3151   // argument is stderr.
3152 
3153   if (StreamArg >= (int)CI->arg_size())
3154     return false;
3155   LoadInst *LI = dyn_cast<LoadInst>(CI->getArgOperand(StreamArg));
3156   if (!LI)
3157     return false;
3158   GlobalVariable *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand());
3159   if (!GV || !GV->isDeclaration())
3160     return false;
3161   return GV->getName() == "stderr";
3162 }
3163 
optimizePrintFString(CallInst * CI,IRBuilderBase & B)3164 Value *LibCallSimplifier::optimizePrintFString(CallInst *CI, IRBuilderBase &B) {
3165   // Check for a fixed format string.
3166   StringRef FormatStr;
3167   if (!getConstantStringInfo(CI->getArgOperand(0), FormatStr))
3168     return nullptr;
3169 
3170   // Empty format string -> noop.
3171   if (FormatStr.empty()) // Tolerate printf's declared void.
3172     return CI->use_empty() ? (Value *)CI : ConstantInt::get(CI->getType(), 0);
3173 
3174   // Do not do any of the following transformations if the printf return value
3175   // is used, in general the printf return value is not compatible with either
3176   // putchar() or puts().
3177   if (!CI->use_empty())
3178     return nullptr;
3179 
3180   Type *IntTy = CI->getType();
3181   // printf("x") -> putchar('x'), even for "%" and "%%".
3182   if (FormatStr.size() == 1 || FormatStr == "%%") {
3183     // Convert the character to unsigned char before passing it to putchar
3184     // to avoid host-specific sign extension in the IR.  Putchar converts
3185     // it to unsigned char regardless.
3186     Value *IntChar = ConstantInt::get(IntTy, (unsigned char)FormatStr[0]);
3187     return copyFlags(*CI, emitPutChar(IntChar, B, TLI));
3188   }
3189 
3190   // Try to remove call or emit putchar/puts.
3191   if (FormatStr == "%s" && CI->arg_size() > 1) {
3192     StringRef OperandStr;
3193     if (!getConstantStringInfo(CI->getOperand(1), OperandStr))
3194       return nullptr;
3195     // printf("%s", "") --> NOP
3196     if (OperandStr.empty())
3197       return (Value *)CI;
3198     // printf("%s", "a") --> putchar('a')
3199     if (OperandStr.size() == 1) {
3200       // Convert the character to unsigned char before passing it to putchar
3201       // to avoid host-specific sign extension in the IR.  Putchar converts
3202       // it to unsigned char regardless.
3203       Value *IntChar = ConstantInt::get(IntTy, (unsigned char)OperandStr[0]);
3204       return copyFlags(*CI, emitPutChar(IntChar, B, TLI));
3205     }
3206     // printf("%s", str"\n") --> puts(str)
3207     if (OperandStr.back() == '\n') {
3208       OperandStr = OperandStr.drop_back();
3209       Value *GV = B.CreateGlobalString(OperandStr, "str");
3210       return copyFlags(*CI, emitPutS(GV, B, TLI));
3211     }
3212     return nullptr;
3213   }
3214 
3215   // printf("foo\n") --> puts("foo")
3216   if (FormatStr.back() == '\n' &&
3217       !FormatStr.contains('%')) { // No format characters.
3218     // Create a string literal with no \n on it.  We expect the constant merge
3219     // pass to be run after this pass, to merge duplicate strings.
3220     FormatStr = FormatStr.drop_back();
3221     Value *GV = B.CreateGlobalString(FormatStr, "str");
3222     return copyFlags(*CI, emitPutS(GV, B, TLI));
3223   }
3224 
3225   // Optimize specific format strings.
3226   // printf("%c", chr) --> putchar(chr)
3227   if (FormatStr == "%c" && CI->arg_size() > 1 &&
3228       CI->getArgOperand(1)->getType()->isIntegerTy()) {
3229     // Convert the argument to the type expected by putchar, i.e., int, which
3230     // need not be 32 bits wide but which is the same as printf's return type.
3231     Value *IntChar = B.CreateIntCast(CI->getArgOperand(1), IntTy, false);
3232     return copyFlags(*CI, emitPutChar(IntChar, B, TLI));
3233   }
3234 
3235   // printf("%s\n", str) --> puts(str)
3236   if (FormatStr == "%s\n" && CI->arg_size() > 1 &&
3237       CI->getArgOperand(1)->getType()->isPointerTy())
3238     return copyFlags(*CI, emitPutS(CI->getArgOperand(1), B, TLI));
3239   return nullptr;
3240 }
3241 
optimizePrintF(CallInst * CI,IRBuilderBase & B)3242 Value *LibCallSimplifier::optimizePrintF(CallInst *CI, IRBuilderBase &B) {
3243 
3244   Module *M = CI->getModule();
3245   Function *Callee = CI->getCalledFunction();
3246   FunctionType *FT = Callee->getFunctionType();
3247   if (Value *V = optimizePrintFString(CI, B)) {
3248     return V;
3249   }
3250 
3251   annotateNonNullNoUndefBasedOnAccess(CI, 0);
3252 
3253   // printf(format, ...) -> iprintf(format, ...) if no floating point
3254   // arguments.
3255   if (isLibFuncEmittable(M, TLI, LibFunc_iprintf) &&
3256       !callHasFloatingPointArgument(CI)) {
3257     FunctionCallee IPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_iprintf, FT,
3258                                                   Callee->getAttributes());
3259     CallInst *New = cast<CallInst>(CI->clone());
3260     New->setCalledFunction(IPrintFFn);
3261     B.Insert(New);
3262     return New;
3263   }
3264 
3265   // printf(format, ...) -> __small_printf(format, ...) if no 128-bit floating point
3266   // arguments.
3267   if (isLibFuncEmittable(M, TLI, LibFunc_small_printf) &&
3268       !callHasFP128Argument(CI)) {
3269     auto SmallPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_printf, FT,
3270                                             Callee->getAttributes());
3271     CallInst *New = cast<CallInst>(CI->clone());
3272     New->setCalledFunction(SmallPrintFFn);
3273     B.Insert(New);
3274     return New;
3275   }
3276 
3277   return nullptr;
3278 }
3279 
optimizeSPrintFString(CallInst * CI,IRBuilderBase & B)3280 Value *LibCallSimplifier::optimizeSPrintFString(CallInst *CI,
3281                                                 IRBuilderBase &B) {
3282   // Check for a fixed format string.
3283   StringRef FormatStr;
3284   if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr))
3285     return nullptr;
3286 
3287   // If we just have a format string (nothing else crazy) transform it.
3288   Value *Dest = CI->getArgOperand(0);
3289   if (CI->arg_size() == 2) {
3290     // Make sure there's no % in the constant array.  We could try to handle
3291     // %% -> % in the future if we cared.
3292     if (FormatStr.contains('%'))
3293       return nullptr; // we found a format specifier, bail out.
3294 
3295     // sprintf(str, fmt) -> llvm.memcpy(align 1 str, align 1 fmt, strlen(fmt)+1)
3296     B.CreateMemCpy(
3297         Dest, Align(1), CI->getArgOperand(1), Align(1),
3298         ConstantInt::get(DL.getIntPtrType(CI->getContext()),
3299                          FormatStr.size() + 1)); // Copy the null byte.
3300     return ConstantInt::get(CI->getType(), FormatStr.size());
3301   }
3302 
3303   // The remaining optimizations require the format string to be "%s" or "%c"
3304   // and have an extra operand.
3305   if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() < 3)
3306     return nullptr;
3307 
3308   // Decode the second character of the format string.
3309   if (FormatStr[1] == 'c') {
3310     // sprintf(dst, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0
3311     if (!CI->getArgOperand(2)->getType()->isIntegerTy())
3312       return nullptr;
3313     Value *V = B.CreateTrunc(CI->getArgOperand(2), B.getInt8Ty(), "char");
3314     Value *Ptr = Dest;
3315     B.CreateStore(V, Ptr);
3316     Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
3317     B.CreateStore(B.getInt8(0), Ptr);
3318 
3319     return ConstantInt::get(CI->getType(), 1);
3320   }
3321 
3322   if (FormatStr[1] == 's') {
3323     // sprintf(dest, "%s", str) -> llvm.memcpy(align 1 dest, align 1 str,
3324     // strlen(str)+1)
3325     if (!CI->getArgOperand(2)->getType()->isPointerTy())
3326       return nullptr;
3327 
3328     if (CI->use_empty())
3329       // sprintf(dest, "%s", str) -> strcpy(dest, str)
3330       return copyFlags(*CI, emitStrCpy(Dest, CI->getArgOperand(2), B, TLI));
3331 
3332     uint64_t SrcLen = GetStringLength(CI->getArgOperand(2));
3333     if (SrcLen) {
3334       B.CreateMemCpy(
3335           Dest, Align(1), CI->getArgOperand(2), Align(1),
3336           ConstantInt::get(DL.getIntPtrType(CI->getContext()), SrcLen));
3337       // Returns total number of characters written without null-character.
3338       return ConstantInt::get(CI->getType(), SrcLen - 1);
3339     } else if (Value *V = emitStpCpy(Dest, CI->getArgOperand(2), B, TLI)) {
3340       // sprintf(dest, "%s", str) -> stpcpy(dest, str) - dest
3341       Value *PtrDiff = B.CreatePtrDiff(B.getInt8Ty(), V, Dest);
3342       return B.CreateIntCast(PtrDiff, CI->getType(), false);
3343     }
3344 
3345     bool OptForSize = CI->getFunction()->hasOptSize() ||
3346                       llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI,
3347                                                   PGSOQueryType::IRPass);
3348     if (OptForSize)
3349       return nullptr;
3350 
3351     Value *Len = emitStrLen(CI->getArgOperand(2), B, DL, TLI);
3352     if (!Len)
3353       return nullptr;
3354     Value *IncLen =
3355         B.CreateAdd(Len, ConstantInt::get(Len->getType(), 1), "leninc");
3356     B.CreateMemCpy(Dest, Align(1), CI->getArgOperand(2), Align(1), IncLen);
3357 
3358     // The sprintf result is the unincremented number of bytes in the string.
3359     return B.CreateIntCast(Len, CI->getType(), false);
3360   }
3361   return nullptr;
3362 }
3363 
optimizeSPrintF(CallInst * CI,IRBuilderBase & B)3364 Value *LibCallSimplifier::optimizeSPrintF(CallInst *CI, IRBuilderBase &B) {
3365   Module *M = CI->getModule();
3366   Function *Callee = CI->getCalledFunction();
3367   FunctionType *FT = Callee->getFunctionType();
3368   if (Value *V = optimizeSPrintFString(CI, B)) {
3369     return V;
3370   }
3371 
3372   annotateNonNullNoUndefBasedOnAccess(CI, {0, 1});
3373 
3374   // sprintf(str, format, ...) -> siprintf(str, format, ...) if no floating
3375   // point arguments.
3376   if (isLibFuncEmittable(M, TLI, LibFunc_siprintf) &&
3377       !callHasFloatingPointArgument(CI)) {
3378     FunctionCallee SIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_siprintf,
3379                                                    FT, Callee->getAttributes());
3380     CallInst *New = cast<CallInst>(CI->clone());
3381     New->setCalledFunction(SIPrintFFn);
3382     B.Insert(New);
3383     return New;
3384   }
3385 
3386   // sprintf(str, format, ...) -> __small_sprintf(str, format, ...) if no 128-bit
3387   // floating point arguments.
3388   if (isLibFuncEmittable(M, TLI, LibFunc_small_sprintf) &&
3389       !callHasFP128Argument(CI)) {
3390     auto SmallSPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_small_sprintf, FT,
3391                                              Callee->getAttributes());
3392     CallInst *New = cast<CallInst>(CI->clone());
3393     New->setCalledFunction(SmallSPrintFFn);
3394     B.Insert(New);
3395     return New;
3396   }
3397 
3398   return nullptr;
3399 }
3400 
3401 // Transform an snprintf call CI with the bound N to format the string Str
3402 // either to a call to memcpy, or to single character a store, or to nothing,
3403 // and fold the result to a constant.  A nonnull StrArg refers to the string
3404 // argument being formatted.  Otherwise the call is one with N < 2 and
3405 // the "%c" directive to format a single character.
emitSnPrintfMemCpy(CallInst * CI,Value * StrArg,StringRef Str,uint64_t N,IRBuilderBase & B)3406 Value *LibCallSimplifier::emitSnPrintfMemCpy(CallInst *CI, Value *StrArg,
3407                                              StringRef Str, uint64_t N,
3408                                              IRBuilderBase &B) {
3409   assert(StrArg || (N < 2 && Str.size() == 1));
3410 
3411   unsigned IntBits = TLI->getIntSize();
3412   uint64_t IntMax = maxIntN(IntBits);
3413   if (Str.size() > IntMax)
3414     // Bail if the string is longer than INT_MAX.  POSIX requires
3415     // implementations to set errno to EOVERFLOW in this case, in
3416     // addition to when N is larger than that (checked by the caller).
3417     return nullptr;
3418 
3419   Value *StrLen = ConstantInt::get(CI->getType(), Str.size());
3420   if (N == 0)
3421     return StrLen;
3422 
3423   // Set to the number of bytes to copy fron StrArg which is also
3424   // the offset of the terinating nul.
3425   uint64_t NCopy;
3426   if (N > Str.size())
3427     // Copy the full string, including the terminating nul (which must
3428     // be present regardless of the bound).
3429     NCopy = Str.size() + 1;
3430   else
3431     NCopy = N - 1;
3432 
3433   Value *DstArg = CI->getArgOperand(0);
3434   if (NCopy && StrArg)
3435     // Transform the call to lvm.memcpy(dst, fmt, N).
3436     copyFlags(
3437          *CI,
3438           B.CreateMemCpy(
3439                          DstArg, Align(1), StrArg, Align(1),
3440               ConstantInt::get(DL.getIntPtrType(CI->getContext()), NCopy)));
3441 
3442   if (N > Str.size())
3443     // Return early when the whole format string, including the final nul,
3444     // has been copied.
3445     return StrLen;
3446 
3447   // Otherwise, when truncating the string append a terminating nul.
3448   Type *Int8Ty = B.getInt8Ty();
3449   Value *NulOff = B.getIntN(IntBits, NCopy);
3450   Value *DstEnd = B.CreateInBoundsGEP(Int8Ty, DstArg, NulOff, "endptr");
3451   B.CreateStore(ConstantInt::get(Int8Ty, 0), DstEnd);
3452   return StrLen;
3453 }
3454 
optimizeSnPrintFString(CallInst * CI,IRBuilderBase & B)3455 Value *LibCallSimplifier::optimizeSnPrintFString(CallInst *CI,
3456                                                  IRBuilderBase &B) {
3457   // Check for size
3458   ConstantInt *Size = dyn_cast<ConstantInt>(CI->getArgOperand(1));
3459   if (!Size)
3460     return nullptr;
3461 
3462   uint64_t N = Size->getZExtValue();
3463   uint64_t IntMax = maxIntN(TLI->getIntSize());
3464   if (N > IntMax)
3465     // Bail if the bound exceeds INT_MAX.  POSIX requires implementations
3466     // to set errno to EOVERFLOW in this case.
3467     return nullptr;
3468 
3469   Value *DstArg = CI->getArgOperand(0);
3470   Value *FmtArg = CI->getArgOperand(2);
3471 
3472   // Check for a fixed format string.
3473   StringRef FormatStr;
3474   if (!getConstantStringInfo(FmtArg, FormatStr))
3475     return nullptr;
3476 
3477   // If we just have a format string (nothing else crazy) transform it.
3478   if (CI->arg_size() == 3) {
3479     if (FormatStr.contains('%'))
3480       // Bail if the format string contains a directive and there are
3481       // no arguments.  We could handle "%%" in the future.
3482       return nullptr;
3483 
3484     return emitSnPrintfMemCpy(CI, FmtArg, FormatStr, N, B);
3485   }
3486 
3487   // The remaining optimizations require the format string to be "%s" or "%c"
3488   // and have an extra operand.
3489   if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() != 4)
3490     return nullptr;
3491 
3492   // Decode the second character of the format string.
3493   if (FormatStr[1] == 'c') {
3494     if (N <= 1) {
3495       // Use an arbitary string of length 1 to transform the call into
3496       // either a nul store (N == 1) or a no-op (N == 0) and fold it
3497       // to one.
3498       StringRef CharStr("*");
3499       return emitSnPrintfMemCpy(CI, nullptr, CharStr, N, B);
3500     }
3501 
3502     // snprintf(dst, size, "%c", chr) --> *(i8*)dst = chr; *((i8*)dst+1) = 0
3503     if (!CI->getArgOperand(3)->getType()->isIntegerTy())
3504       return nullptr;
3505     Value *V = B.CreateTrunc(CI->getArgOperand(3), B.getInt8Ty(), "char");
3506     Value *Ptr = DstArg;
3507     B.CreateStore(V, Ptr);
3508     Ptr = B.CreateInBoundsGEP(B.getInt8Ty(), Ptr, B.getInt32(1), "nul");
3509     B.CreateStore(B.getInt8(0), Ptr);
3510     return ConstantInt::get(CI->getType(), 1);
3511   }
3512 
3513   if (FormatStr[1] != 's')
3514     return nullptr;
3515 
3516   Value *StrArg = CI->getArgOperand(3);
3517   // snprintf(dest, size, "%s", str) to llvm.memcpy(dest, str, len+1, 1)
3518   StringRef Str;
3519   if (!getConstantStringInfo(StrArg, Str))
3520     return nullptr;
3521 
3522   return emitSnPrintfMemCpy(CI, StrArg, Str, N, B);
3523 }
3524 
optimizeSnPrintF(CallInst * CI,IRBuilderBase & B)3525 Value *LibCallSimplifier::optimizeSnPrintF(CallInst *CI, IRBuilderBase &B) {
3526   if (Value *V = optimizeSnPrintFString(CI, B)) {
3527     return V;
3528   }
3529 
3530   if (isKnownNonZero(CI->getOperand(1), DL))
3531     annotateNonNullNoUndefBasedOnAccess(CI, 0);
3532   return nullptr;
3533 }
3534 
optimizeFPrintFString(CallInst * CI,IRBuilderBase & B)3535 Value *LibCallSimplifier::optimizeFPrintFString(CallInst *CI,
3536                                                 IRBuilderBase &B) {
3537   optimizeErrorReporting(CI, B, 0);
3538 
3539   // All the optimizations depend on the format string.
3540   StringRef FormatStr;
3541   if (!getConstantStringInfo(CI->getArgOperand(1), FormatStr))
3542     return nullptr;
3543 
3544   // Do not do any of the following transformations if the fprintf return
3545   // value is used, in general the fprintf return value is not compatible
3546   // with fwrite(), fputc() or fputs().
3547   if (!CI->use_empty())
3548     return nullptr;
3549 
3550   // fprintf(F, "foo") --> fwrite("foo", 3, 1, F)
3551   if (CI->arg_size() == 2) {
3552     // Could handle %% -> % if we cared.
3553     if (FormatStr.contains('%'))
3554       return nullptr; // We found a format specifier.
3555 
3556     unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
3557     Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
3558     return copyFlags(
3559         *CI, emitFWrite(CI->getArgOperand(1),
3560                         ConstantInt::get(SizeTTy, FormatStr.size()),
3561                         CI->getArgOperand(0), B, DL, TLI));
3562   }
3563 
3564   // The remaining optimizations require the format string to be "%s" or "%c"
3565   // and have an extra operand.
3566   if (FormatStr.size() != 2 || FormatStr[0] != '%' || CI->arg_size() < 3)
3567     return nullptr;
3568 
3569   // Decode the second character of the format string.
3570   if (FormatStr[1] == 'c') {
3571     // fprintf(F, "%c", chr) --> fputc((int)chr, F)
3572     if (!CI->getArgOperand(2)->getType()->isIntegerTy())
3573       return nullptr;
3574     Type *IntTy = B.getIntNTy(TLI->getIntSize());
3575     Value *V = B.CreateIntCast(CI->getArgOperand(2), IntTy, /*isSigned*/ true,
3576                                "chari");
3577     return copyFlags(*CI, emitFPutC(V, CI->getArgOperand(0), B, TLI));
3578   }
3579 
3580   if (FormatStr[1] == 's') {
3581     // fprintf(F, "%s", str) --> fputs(str, F)
3582     if (!CI->getArgOperand(2)->getType()->isPointerTy())
3583       return nullptr;
3584     return copyFlags(
3585         *CI, emitFPutS(CI->getArgOperand(2), CI->getArgOperand(0), B, TLI));
3586   }
3587   return nullptr;
3588 }
3589 
optimizeFPrintF(CallInst * CI,IRBuilderBase & B)3590 Value *LibCallSimplifier::optimizeFPrintF(CallInst *CI, IRBuilderBase &B) {
3591   Module *M = CI->getModule();
3592   Function *Callee = CI->getCalledFunction();
3593   FunctionType *FT = Callee->getFunctionType();
3594   if (Value *V = optimizeFPrintFString(CI, B)) {
3595     return V;
3596   }
3597 
3598   // fprintf(stream, format, ...) -> fiprintf(stream, format, ...) if no
3599   // floating point arguments.
3600   if (isLibFuncEmittable(M, TLI, LibFunc_fiprintf) &&
3601       !callHasFloatingPointArgument(CI)) {
3602     FunctionCallee FIPrintFFn = getOrInsertLibFunc(M, *TLI, LibFunc_fiprintf,
3603                                                    FT, Callee->getAttributes());
3604     CallInst *New = cast<CallInst>(CI->clone());
3605     New->setCalledFunction(FIPrintFFn);
3606     B.Insert(New);
3607     return New;
3608   }
3609 
3610   // fprintf(stream, format, ...) -> __small_fprintf(stream, format, ...) if no
3611   // 128-bit floating point arguments.
3612   if (isLibFuncEmittable(M, TLI, LibFunc_small_fprintf) &&
3613       !callHasFP128Argument(CI)) {
3614     auto SmallFPrintFFn =
3615         getOrInsertLibFunc(M, *TLI, LibFunc_small_fprintf, FT,
3616                            Callee->getAttributes());
3617     CallInst *New = cast<CallInst>(CI->clone());
3618     New->setCalledFunction(SmallFPrintFFn);
3619     B.Insert(New);
3620     return New;
3621   }
3622 
3623   return nullptr;
3624 }
3625 
optimizeFWrite(CallInst * CI,IRBuilderBase & B)3626 Value *LibCallSimplifier::optimizeFWrite(CallInst *CI, IRBuilderBase &B) {
3627   optimizeErrorReporting(CI, B, 3);
3628 
3629   // Get the element size and count.
3630   ConstantInt *SizeC = dyn_cast<ConstantInt>(CI->getArgOperand(1));
3631   ConstantInt *CountC = dyn_cast<ConstantInt>(CI->getArgOperand(2));
3632   if (SizeC && CountC) {
3633     uint64_t Bytes = SizeC->getZExtValue() * CountC->getZExtValue();
3634 
3635     // If this is writing zero records, remove the call (it's a noop).
3636     if (Bytes == 0)
3637       return ConstantInt::get(CI->getType(), 0);
3638 
3639     // If this is writing one byte, turn it into fputc.
3640     // This optimisation is only valid, if the return value is unused.
3641     if (Bytes == 1 && CI->use_empty()) { // fwrite(S,1,1,F) -> fputc(S[0],F)
3642       Value *Char = B.CreateLoad(B.getInt8Ty(), CI->getArgOperand(0), "char");
3643       Type *IntTy = B.getIntNTy(TLI->getIntSize());
3644       Value *Cast = B.CreateIntCast(Char, IntTy, /*isSigned*/ true, "chari");
3645       Value *NewCI = emitFPutC(Cast, CI->getArgOperand(3), B, TLI);
3646       return NewCI ? ConstantInt::get(CI->getType(), 1) : nullptr;
3647     }
3648   }
3649 
3650   return nullptr;
3651 }
3652 
optimizeFPuts(CallInst * CI,IRBuilderBase & B)3653 Value *LibCallSimplifier::optimizeFPuts(CallInst *CI, IRBuilderBase &B) {
3654   optimizeErrorReporting(CI, B, 1);
3655 
3656   // Don't rewrite fputs to fwrite when optimising for size because fwrite
3657   // requires more arguments and thus extra MOVs are required.
3658   bool OptForSize = CI->getFunction()->hasOptSize() ||
3659                     llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI,
3660                                                 PGSOQueryType::IRPass);
3661   if (OptForSize)
3662     return nullptr;
3663 
3664   // We can't optimize if return value is used.
3665   if (!CI->use_empty())
3666     return nullptr;
3667 
3668   // fputs(s,F) --> fwrite(s,strlen(s),1,F)
3669   uint64_t Len = GetStringLength(CI->getArgOperand(0));
3670   if (!Len)
3671     return nullptr;
3672 
3673   // Known to have no uses (see above).
3674   unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
3675   Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
3676   return copyFlags(
3677       *CI,
3678       emitFWrite(CI->getArgOperand(0),
3679                  ConstantInt::get(SizeTTy, Len - 1),
3680                  CI->getArgOperand(1), B, DL, TLI));
3681 }
3682 
optimizePuts(CallInst * CI,IRBuilderBase & B)3683 Value *LibCallSimplifier::optimizePuts(CallInst *CI, IRBuilderBase &B) {
3684   annotateNonNullNoUndefBasedOnAccess(CI, 0);
3685   if (!CI->use_empty())
3686     return nullptr;
3687 
3688   // Check for a constant string.
3689   // puts("") -> putchar('\n')
3690   StringRef Str;
3691   if (getConstantStringInfo(CI->getArgOperand(0), Str) && Str.empty()) {
3692     // putchar takes an argument of the same type as puts returns, i.e.,
3693     // int, which need not be 32 bits wide.
3694     Type *IntTy = CI->getType();
3695     return copyFlags(*CI, emitPutChar(ConstantInt::get(IntTy, '\n'), B, TLI));
3696   }
3697 
3698   return nullptr;
3699 }
3700 
optimizeBCopy(CallInst * CI,IRBuilderBase & B)3701 Value *LibCallSimplifier::optimizeBCopy(CallInst *CI, IRBuilderBase &B) {
3702   // bcopy(src, dst, n) -> llvm.memmove(dst, src, n)
3703   return copyFlags(*CI, B.CreateMemMove(CI->getArgOperand(1), Align(1),
3704                                         CI->getArgOperand(0), Align(1),
3705                                         CI->getArgOperand(2)));
3706 }
3707 
hasFloatVersion(const Module * M,StringRef FuncName)3708 bool LibCallSimplifier::hasFloatVersion(const Module *M, StringRef FuncName) {
3709   SmallString<20> FloatFuncName = FuncName;
3710   FloatFuncName += 'f';
3711   return isLibFuncEmittable(M, TLI, FloatFuncName);
3712 }
3713 
optimizeStringMemoryLibCall(CallInst * CI,IRBuilderBase & Builder)3714 Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI,
3715                                                       IRBuilderBase &Builder) {
3716   Module *M = CI->getModule();
3717   LibFunc Func;
3718   Function *Callee = CI->getCalledFunction();
3719   // Check for string/memory library functions.
3720   if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) {
3721     // Make sure we never change the calling convention.
3722     assert(
3723         (ignoreCallingConv(Func) ||
3724          TargetLibraryInfoImpl::isCallingConvCCompatible(CI)) &&
3725         "Optimizing string/memory libcall would change the calling convention");
3726     switch (Func) {
3727     case LibFunc_strcat:
3728       return optimizeStrCat(CI, Builder);
3729     case LibFunc_strncat:
3730       return optimizeStrNCat(CI, Builder);
3731     case LibFunc_strchr:
3732       return optimizeStrChr(CI, Builder);
3733     case LibFunc_strrchr:
3734       return optimizeStrRChr(CI, Builder);
3735     case LibFunc_strcmp:
3736       return optimizeStrCmp(CI, Builder);
3737     case LibFunc_strncmp:
3738       return optimizeStrNCmp(CI, Builder);
3739     case LibFunc_strcpy:
3740       return optimizeStrCpy(CI, Builder);
3741     case LibFunc_stpcpy:
3742       return optimizeStpCpy(CI, Builder);
3743     case LibFunc_strlcpy:
3744       return optimizeStrLCpy(CI, Builder);
3745     case LibFunc_stpncpy:
3746       return optimizeStringNCpy(CI, /*RetEnd=*/true, Builder);
3747     case LibFunc_strncpy:
3748       return optimizeStringNCpy(CI, /*RetEnd=*/false, Builder);
3749     case LibFunc_strlen:
3750       return optimizeStrLen(CI, Builder);
3751     case LibFunc_strnlen:
3752       return optimizeStrNLen(CI, Builder);
3753     case LibFunc_strpbrk:
3754       return optimizeStrPBrk(CI, Builder);
3755     case LibFunc_strndup:
3756       return optimizeStrNDup(CI, Builder);
3757     case LibFunc_strtol:
3758     case LibFunc_strtod:
3759     case LibFunc_strtof:
3760     case LibFunc_strtoul:
3761     case LibFunc_strtoll:
3762     case LibFunc_strtold:
3763     case LibFunc_strtoull:
3764       return optimizeStrTo(CI, Builder);
3765     case LibFunc_strspn:
3766       return optimizeStrSpn(CI, Builder);
3767     case LibFunc_strcspn:
3768       return optimizeStrCSpn(CI, Builder);
3769     case LibFunc_strstr:
3770       return optimizeStrStr(CI, Builder);
3771     case LibFunc_memchr:
3772       return optimizeMemChr(CI, Builder);
3773     case LibFunc_memrchr:
3774       return optimizeMemRChr(CI, Builder);
3775     case LibFunc_bcmp:
3776       return optimizeBCmp(CI, Builder);
3777     case LibFunc_memcmp:
3778       return optimizeMemCmp(CI, Builder);
3779     case LibFunc_memcpy:
3780       return optimizeMemCpy(CI, Builder);
3781     case LibFunc_memccpy:
3782       return optimizeMemCCpy(CI, Builder);
3783     case LibFunc_mempcpy:
3784       return optimizeMemPCpy(CI, Builder);
3785     case LibFunc_memmove:
3786       return optimizeMemMove(CI, Builder);
3787     case LibFunc_memset:
3788       return optimizeMemSet(CI, Builder);
3789     case LibFunc_realloc:
3790       return optimizeRealloc(CI, Builder);
3791     case LibFunc_wcslen:
3792       return optimizeWcslen(CI, Builder);
3793     case LibFunc_bcopy:
3794       return optimizeBCopy(CI, Builder);
3795     case LibFunc_Znwm:
3796     case LibFunc_ZnwmRKSt9nothrow_t:
3797     case LibFunc_ZnwmSt11align_val_t:
3798     case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t:
3799     case LibFunc_Znam:
3800     case LibFunc_ZnamRKSt9nothrow_t:
3801     case LibFunc_ZnamSt11align_val_t:
3802     case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t:
3803     case LibFunc_Znwm12__hot_cold_t:
3804     case LibFunc_ZnwmRKSt9nothrow_t12__hot_cold_t:
3805     case LibFunc_ZnwmSt11align_val_t12__hot_cold_t:
3806     case LibFunc_ZnwmSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
3807     case LibFunc_Znam12__hot_cold_t:
3808     case LibFunc_ZnamRKSt9nothrow_t12__hot_cold_t:
3809     case LibFunc_ZnamSt11align_val_t12__hot_cold_t:
3810     case LibFunc_ZnamSt11align_val_tRKSt9nothrow_t12__hot_cold_t:
3811       return optimizeNew(CI, Builder, Func);
3812     default:
3813       break;
3814     }
3815   }
3816   return nullptr;
3817 }
3818 
optimizeFloatingPointLibCall(CallInst * CI,LibFunc Func,IRBuilderBase & Builder)3819 Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI,
3820                                                        LibFunc Func,
3821                                                        IRBuilderBase &Builder) {
3822   const Module *M = CI->getModule();
3823 
3824   // Don't optimize calls that require strict floating point semantics.
3825   if (CI->isStrictFP())
3826     return nullptr;
3827 
3828   if (Value *V = optimizeSymmetric(CI, Func, Builder))
3829     return V;
3830 
3831   switch (Func) {
3832   case LibFunc_sinpif:
3833   case LibFunc_sinpi:
3834     return optimizeSinCosPi(CI, /*IsSin*/true, Builder);
3835   case LibFunc_cospif:
3836   case LibFunc_cospi:
3837     return optimizeSinCosPi(CI, /*IsSin*/false, Builder);
3838   case LibFunc_powf:
3839   case LibFunc_pow:
3840   case LibFunc_powl:
3841     return optimizePow(CI, Builder);
3842   case LibFunc_exp2l:
3843   case LibFunc_exp2:
3844   case LibFunc_exp2f:
3845     return optimizeExp2(CI, Builder);
3846   case LibFunc_fabsf:
3847   case LibFunc_fabs:
3848   case LibFunc_fabsl:
3849     return replaceUnaryCall(CI, Builder, Intrinsic::fabs);
3850   case LibFunc_sqrtf:
3851   case LibFunc_sqrt:
3852   case LibFunc_sqrtl:
3853     return optimizeSqrt(CI, Builder);
3854   case LibFunc_logf:
3855   case LibFunc_log:
3856   case LibFunc_logl:
3857   case LibFunc_log10f:
3858   case LibFunc_log10:
3859   case LibFunc_log10l:
3860   case LibFunc_log1pf:
3861   case LibFunc_log1p:
3862   case LibFunc_log1pl:
3863   case LibFunc_log2f:
3864   case LibFunc_log2:
3865   case LibFunc_log2l:
3866   case LibFunc_logbf:
3867   case LibFunc_logb:
3868   case LibFunc_logbl:
3869     return optimizeLog(CI, Builder);
3870   case LibFunc_tan:
3871   case LibFunc_tanf:
3872   case LibFunc_tanl:
3873   case LibFunc_sinh:
3874   case LibFunc_sinhf:
3875   case LibFunc_sinhl:
3876   case LibFunc_asinh:
3877   case LibFunc_asinhf:
3878   case LibFunc_asinhl:
3879   case LibFunc_cosh:
3880   case LibFunc_coshf:
3881   case LibFunc_coshl:
3882   case LibFunc_atanh:
3883   case LibFunc_atanhf:
3884   case LibFunc_atanhl:
3885     return optimizeTrigInversionPairs(CI, Builder);
3886   case LibFunc_ceil:
3887     return replaceUnaryCall(CI, Builder, Intrinsic::ceil);
3888   case LibFunc_floor:
3889     return replaceUnaryCall(CI, Builder, Intrinsic::floor);
3890   case LibFunc_round:
3891     return replaceUnaryCall(CI, Builder, Intrinsic::round);
3892   case LibFunc_roundeven:
3893     return replaceUnaryCall(CI, Builder, Intrinsic::roundeven);
3894   case LibFunc_nearbyint:
3895     return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint);
3896   case LibFunc_rint:
3897     return replaceUnaryCall(CI, Builder, Intrinsic::rint);
3898   case LibFunc_trunc:
3899     return replaceUnaryCall(CI, Builder, Intrinsic::trunc);
3900   case LibFunc_acos:
3901   case LibFunc_acosh:
3902   case LibFunc_asin:
3903   case LibFunc_atan:
3904   case LibFunc_cbrt:
3905   case LibFunc_exp:
3906   case LibFunc_exp10:
3907   case LibFunc_expm1:
3908   case LibFunc_cos:
3909   case LibFunc_sin:
3910   case LibFunc_tanh:
3911     if (UnsafeFPShrink && hasFloatVersion(M, CI->getCalledFunction()->getName()))
3912       return optimizeUnaryDoubleFP(CI, Builder, TLI, true);
3913     return nullptr;
3914   case LibFunc_copysign:
3915     if (hasFloatVersion(M, CI->getCalledFunction()->getName()))
3916       return optimizeBinaryDoubleFP(CI, Builder, TLI);
3917     return nullptr;
3918   case LibFunc_fminf:
3919   case LibFunc_fmin:
3920   case LibFunc_fminl:
3921   case LibFunc_fmaxf:
3922   case LibFunc_fmax:
3923   case LibFunc_fmaxl:
3924     return optimizeFMinFMax(CI, Builder);
3925   case LibFunc_cabs:
3926   case LibFunc_cabsf:
3927   case LibFunc_cabsl:
3928     return optimizeCAbs(CI, Builder);
3929   default:
3930     return nullptr;
3931   }
3932 }
3933 
optimizeCall(CallInst * CI,IRBuilderBase & Builder)3934 Value *LibCallSimplifier::optimizeCall(CallInst *CI, IRBuilderBase &Builder) {
3935   Module *M = CI->getModule();
3936   assert(!CI->isMustTailCall() && "These transforms aren't musttail safe.");
3937 
3938   // TODO: Split out the code below that operates on FP calls so that
3939   //       we can all non-FP calls with the StrictFP attribute to be
3940   //       optimized.
3941   if (CI->isNoBuiltin())
3942     return nullptr;
3943 
3944   LibFunc Func;
3945   Function *Callee = CI->getCalledFunction();
3946   bool IsCallingConvC = TargetLibraryInfoImpl::isCallingConvCCompatible(CI);
3947 
3948   SmallVector<OperandBundleDef, 2> OpBundles;
3949   CI->getOperandBundlesAsDefs(OpBundles);
3950 
3951   IRBuilderBase::OperandBundlesGuard Guard(Builder);
3952   Builder.setDefaultOperandBundles(OpBundles);
3953 
3954   // Command-line parameter overrides instruction attribute.
3955   // This can't be moved to optimizeFloatingPointLibCall() because it may be
3956   // used by the intrinsic optimizations.
3957   if (EnableUnsafeFPShrink.getNumOccurrences() > 0)
3958     UnsafeFPShrink = EnableUnsafeFPShrink;
3959   else if (isa<FPMathOperator>(CI) && CI->isFast())
3960     UnsafeFPShrink = true;
3961 
3962   // First, check for intrinsics.
3963   if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
3964     if (!IsCallingConvC)
3965       return nullptr;
3966     // The FP intrinsics have corresponding constrained versions so we don't
3967     // need to check for the StrictFP attribute here.
3968     switch (II->getIntrinsicID()) {
3969     case Intrinsic::pow:
3970       return optimizePow(CI, Builder);
3971     case Intrinsic::exp2:
3972       return optimizeExp2(CI, Builder);
3973     case Intrinsic::log:
3974     case Intrinsic::log2:
3975     case Intrinsic::log10:
3976       return optimizeLog(CI, Builder);
3977     case Intrinsic::sqrt:
3978       return optimizeSqrt(CI, Builder);
3979     case Intrinsic::memset:
3980       return optimizeMemSet(CI, Builder);
3981     case Intrinsic::memcpy:
3982       return optimizeMemCpy(CI, Builder);
3983     case Intrinsic::memmove:
3984       return optimizeMemMove(CI, Builder);
3985     default:
3986       return nullptr;
3987     }
3988   }
3989 
3990   // Also try to simplify calls to fortified library functions.
3991   if (Value *SimplifiedFortifiedCI =
3992           FortifiedSimplifier.optimizeCall(CI, Builder))
3993     return SimplifiedFortifiedCI;
3994 
3995   // Then check for known library functions.
3996   if (TLI->getLibFunc(*Callee, Func) && isLibFuncEmittable(M, TLI, Func)) {
3997     // We never change the calling convention.
3998     if (!ignoreCallingConv(Func) && !IsCallingConvC)
3999       return nullptr;
4000     if (Value *V = optimizeStringMemoryLibCall(CI, Builder))
4001       return V;
4002     if (Value *V = optimizeFloatingPointLibCall(CI, Func, Builder))
4003       return V;
4004     switch (Func) {
4005     case LibFunc_ffs:
4006     case LibFunc_ffsl:
4007     case LibFunc_ffsll:
4008       return optimizeFFS(CI, Builder);
4009     case LibFunc_fls:
4010     case LibFunc_flsl:
4011     case LibFunc_flsll:
4012       return optimizeFls(CI, Builder);
4013     case LibFunc_abs:
4014     case LibFunc_labs:
4015     case LibFunc_llabs:
4016       return optimizeAbs(CI, Builder);
4017     case LibFunc_isdigit:
4018       return optimizeIsDigit(CI, Builder);
4019     case LibFunc_isascii:
4020       return optimizeIsAscii(CI, Builder);
4021     case LibFunc_toascii:
4022       return optimizeToAscii(CI, Builder);
4023     case LibFunc_atoi:
4024     case LibFunc_atol:
4025     case LibFunc_atoll:
4026       return optimizeAtoi(CI, Builder);
4027     case LibFunc_strtol:
4028     case LibFunc_strtoll:
4029       return optimizeStrToInt(CI, Builder, /*AsSigned=*/true);
4030     case LibFunc_strtoul:
4031     case LibFunc_strtoull:
4032       return optimizeStrToInt(CI, Builder, /*AsSigned=*/false);
4033     case LibFunc_printf:
4034       return optimizePrintF(CI, Builder);
4035     case LibFunc_sprintf:
4036       return optimizeSPrintF(CI, Builder);
4037     case LibFunc_snprintf:
4038       return optimizeSnPrintF(CI, Builder);
4039     case LibFunc_fprintf:
4040       return optimizeFPrintF(CI, Builder);
4041     case LibFunc_fwrite:
4042       return optimizeFWrite(CI, Builder);
4043     case LibFunc_fputs:
4044       return optimizeFPuts(CI, Builder);
4045     case LibFunc_puts:
4046       return optimizePuts(CI, Builder);
4047     case LibFunc_perror:
4048       return optimizeErrorReporting(CI, Builder);
4049     case LibFunc_vfprintf:
4050     case LibFunc_fiprintf:
4051       return optimizeErrorReporting(CI, Builder, 0);
4052     default:
4053       return nullptr;
4054     }
4055   }
4056   return nullptr;
4057 }
4058 
LibCallSimplifier(const DataLayout & DL,const TargetLibraryInfo * TLI,AssumptionCache * AC,OptimizationRemarkEmitter & ORE,BlockFrequencyInfo * BFI,ProfileSummaryInfo * PSI,function_ref<void (Instruction *,Value *)> Replacer,function_ref<void (Instruction *)> Eraser)4059 LibCallSimplifier::LibCallSimplifier(
4060     const DataLayout &DL, const TargetLibraryInfo *TLI, AssumptionCache *AC,
4061     OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
4062     ProfileSummaryInfo *PSI,
4063     function_ref<void(Instruction *, Value *)> Replacer,
4064     function_ref<void(Instruction *)> Eraser)
4065     : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), AC(AC), ORE(ORE), BFI(BFI),
4066       PSI(PSI), Replacer(Replacer), Eraser(Eraser) {}
4067 
replaceAllUsesWith(Instruction * I,Value * With)4068 void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) {
4069   // Indirect through the replacer used in this instance.
4070   Replacer(I, With);
4071 }
4072 
eraseFromParent(Instruction * I)4073 void LibCallSimplifier::eraseFromParent(Instruction *I) {
4074   Eraser(I);
4075 }
4076 
// TODO:
//   Additional cases that we need to add to this file:
//
// cbrt:
//   * cbrt(expN(x))  -> expN(x/3)
//   * cbrt(sqrt(x))  -> pow(x,1/6)
//   * cbrt(cbrt(x))  -> pow(x,1/9)
//
// exp, expf, expl:
//   * exp(log(x))  -> x
//
// log, logf, logl:
//   * log(exp(x))   -> x
//   * log(exp(y))   -> y*log(e)
//   * log(exp10(y)) -> y*log(10)
//   * log(sqrt(x))  -> 0.5*log(x)
//
// pow, powf, powl:
//   * pow(sqrt(x),y) -> pow(x,y*0.5)
//   * pow(pow(x,y),z)-> pow(x,y*z)
//
// signbit:
//   * signbit(cnst) -> cnst'
//   * signbit(nncst) -> 0 (if nncst is a non-negative constant)
//
// sqrt, sqrtf, sqrtl:
//   * sqrt(expN(x))  -> expN(x*0.5)
//   * sqrt(Nroot(x)) -> pow(x,1/(2*N))
//   * sqrt(pow(x,y)) -> pow(|x|,y*0.5)
//
4107 
4108 //===----------------------------------------------------------------------===//
4109 // Fortified Library Call Optimizations
4110 //===----------------------------------------------------------------------===//
4111 
isFortifiedCallFoldable(CallInst * CI,unsigned ObjSizeOp,std::optional<unsigned> SizeOp,std::optional<unsigned> StrOp,std::optional<unsigned> FlagOp)4112 bool FortifiedLibCallSimplifier::isFortifiedCallFoldable(
4113     CallInst *CI, unsigned ObjSizeOp, std::optional<unsigned> SizeOp,
4114     std::optional<unsigned> StrOp, std::optional<unsigned> FlagOp) {
4115   // If this function takes a flag argument, the implementation may use it to
4116   // perform extra checks. Don't fold into the non-checking variant.
4117   if (FlagOp) {
4118     ConstantInt *Flag = dyn_cast<ConstantInt>(CI->getArgOperand(*FlagOp));
4119     if (!Flag || !Flag->isZero())
4120       return false;
4121   }
4122 
4123   if (SizeOp && CI->getArgOperand(ObjSizeOp) == CI->getArgOperand(*SizeOp))
4124     return true;
4125 
4126   if (ConstantInt *ObjSizeCI =
4127           dyn_cast<ConstantInt>(CI->getArgOperand(ObjSizeOp))) {
4128     if (ObjSizeCI->isMinusOne())
4129       return true;
4130     // If the object size wasn't -1 (unknown), bail out if we were asked to.
4131     if (OnlyLowerUnknownSize)
4132       return false;
4133     if (StrOp) {
4134       uint64_t Len = GetStringLength(CI->getArgOperand(*StrOp));
4135       // If the length is 0 we don't know how long it is and so we can't
4136       // remove the check.
4137       if (Len)
4138         annotateDereferenceableBytes(CI, *StrOp, Len);
4139       else
4140         return false;
4141       return ObjSizeCI->getZExtValue() >= Len;
4142     }
4143 
4144     if (SizeOp) {
4145       if (ConstantInt *SizeCI =
4146               dyn_cast<ConstantInt>(CI->getArgOperand(*SizeOp)))
4147         return ObjSizeCI->getZExtValue() >= SizeCI->getZExtValue();
4148     }
4149   }
4150   return false;
4151 }
4152 
optimizeMemCpyChk(CallInst * CI,IRBuilderBase & B)4153 Value *FortifiedLibCallSimplifier::optimizeMemCpyChk(CallInst *CI,
4154                                                      IRBuilderBase &B) {
4155   if (isFortifiedCallFoldable(CI, 3, 2)) {
4156     CallInst *NewCI =
4157         B.CreateMemCpy(CI->getArgOperand(0), Align(1), CI->getArgOperand(1),
4158                        Align(1), CI->getArgOperand(2));
4159     mergeAttributesAndFlags(NewCI, *CI);
4160     return CI->getArgOperand(0);
4161   }
4162   return nullptr;
4163 }
4164 
optimizeMemMoveChk(CallInst * CI,IRBuilderBase & B)4165 Value *FortifiedLibCallSimplifier::optimizeMemMoveChk(CallInst *CI,
4166                                                       IRBuilderBase &B) {
4167   if (isFortifiedCallFoldable(CI, 3, 2)) {
4168     CallInst *NewCI =
4169         B.CreateMemMove(CI->getArgOperand(0), Align(1), CI->getArgOperand(1),
4170                         Align(1), CI->getArgOperand(2));
4171     mergeAttributesAndFlags(NewCI, *CI);
4172     return CI->getArgOperand(0);
4173   }
4174   return nullptr;
4175 }
4176 
optimizeMemSetChk(CallInst * CI,IRBuilderBase & B)4177 Value *FortifiedLibCallSimplifier::optimizeMemSetChk(CallInst *CI,
4178                                                      IRBuilderBase &B) {
4179   if (isFortifiedCallFoldable(CI, 3, 2)) {
4180     Value *Val = B.CreateIntCast(CI->getArgOperand(1), B.getInt8Ty(), false);
4181     CallInst *NewCI = B.CreateMemSet(CI->getArgOperand(0), Val,
4182                                      CI->getArgOperand(2), Align(1));
4183     mergeAttributesAndFlags(NewCI, *CI);
4184     return CI->getArgOperand(0);
4185   }
4186   return nullptr;
4187 }
4188 
optimizeMemPCpyChk(CallInst * CI,IRBuilderBase & B)4189 Value *FortifiedLibCallSimplifier::optimizeMemPCpyChk(CallInst *CI,
4190                                                       IRBuilderBase &B) {
4191   const DataLayout &DL = CI->getDataLayout();
4192   if (isFortifiedCallFoldable(CI, 3, 2))
4193     if (Value *Call = emitMemPCpy(CI->getArgOperand(0), CI->getArgOperand(1),
4194                                   CI->getArgOperand(2), B, DL, TLI)) {
4195       return mergeAttributesAndFlags(cast<CallInst>(Call), *CI);
4196     }
4197   return nullptr;
4198 }
4199 
optimizeStrpCpyChk(CallInst * CI,IRBuilderBase & B,LibFunc Func)4200 Value *FortifiedLibCallSimplifier::optimizeStrpCpyChk(CallInst *CI,
4201                                                       IRBuilderBase &B,
4202                                                       LibFunc Func) {
4203   const DataLayout &DL = CI->getDataLayout();
4204   Value *Dst = CI->getArgOperand(0), *Src = CI->getArgOperand(1),
4205         *ObjSize = CI->getArgOperand(2);
4206 
4207   // __stpcpy_chk(x,x,...)  -> x+strlen(x)
4208   if (Func == LibFunc_stpcpy_chk && !OnlyLowerUnknownSize && Dst == Src) {
4209     Value *StrLen = emitStrLen(Src, B, DL, TLI);
4210     return StrLen ? B.CreateInBoundsGEP(B.getInt8Ty(), Dst, StrLen) : nullptr;
4211   }
4212 
4213   // If a) we don't have any length information, or b) we know this will
4214   // fit then just lower to a plain st[rp]cpy. Otherwise we'll keep our
4215   // st[rp]cpy_chk call which may fail at runtime if the size is too long.
4216   // TODO: It might be nice to get a maximum length out of the possible
4217   // string lengths for varying.
4218   if (isFortifiedCallFoldable(CI, 2, std::nullopt, 1)) {
4219     if (Func == LibFunc_strcpy_chk)
4220       return copyFlags(*CI, emitStrCpy(Dst, Src, B, TLI));
4221     else
4222       return copyFlags(*CI, emitStpCpy(Dst, Src, B, TLI));
4223   }
4224 
4225   if (OnlyLowerUnknownSize)
4226     return nullptr;
4227 
4228   // Maybe we can stil fold __st[rp]cpy_chk to __memcpy_chk.
4229   uint64_t Len = GetStringLength(Src);
4230   if (Len)
4231     annotateDereferenceableBytes(CI, 1, Len);
4232   else
4233     return nullptr;
4234 
4235   unsigned SizeTBits = TLI->getSizeTSize(*CI->getModule());
4236   Type *SizeTTy = IntegerType::get(CI->getContext(), SizeTBits);
4237   Value *LenV = ConstantInt::get(SizeTTy, Len);
4238   Value *Ret = emitMemCpyChk(Dst, Src, LenV, ObjSize, B, DL, TLI);
4239   // If the function was an __stpcpy_chk, and we were able to fold it into
4240   // a __memcpy_chk, we still need to return the correct end pointer.
4241   if (Ret && Func == LibFunc_stpcpy_chk)
4242     return B.CreateInBoundsGEP(B.getInt8Ty(), Dst,
4243                                ConstantInt::get(SizeTTy, Len - 1));
4244   return copyFlags(*CI, cast<CallInst>(Ret));
4245 }
4246 
optimizeStrLenChk(CallInst * CI,IRBuilderBase & B)4247 Value *FortifiedLibCallSimplifier::optimizeStrLenChk(CallInst *CI,
4248                                                      IRBuilderBase &B) {
4249   if (isFortifiedCallFoldable(CI, 1, std::nullopt, 0))
4250     return copyFlags(*CI, emitStrLen(CI->getArgOperand(0), B,
4251                                      CI->getDataLayout(), TLI));
4252   return nullptr;
4253 }
4254 
optimizeStrpNCpyChk(CallInst * CI,IRBuilderBase & B,LibFunc Func)4255 Value *FortifiedLibCallSimplifier::optimizeStrpNCpyChk(CallInst *CI,
4256                                                        IRBuilderBase &B,
4257                                                        LibFunc Func) {
4258   if (isFortifiedCallFoldable(CI, 3, 2)) {
4259     if (Func == LibFunc_strncpy_chk)
4260       return copyFlags(*CI,
4261                        emitStrNCpy(CI->getArgOperand(0), CI->getArgOperand(1),
4262                                    CI->getArgOperand(2), B, TLI));
4263     else
4264       return copyFlags(*CI,
4265                        emitStpNCpy(CI->getArgOperand(0), CI->getArgOperand(1),
4266                                    CI->getArgOperand(2), B, TLI));
4267   }
4268 
4269   return nullptr;
4270 }
4271 
optimizeMemCCpyChk(CallInst * CI,IRBuilderBase & B)4272 Value *FortifiedLibCallSimplifier::optimizeMemCCpyChk(CallInst *CI,
4273                                                       IRBuilderBase &B) {
4274   if (isFortifiedCallFoldable(CI, 4, 3))
4275     return copyFlags(
4276         *CI, emitMemCCpy(CI->getArgOperand(0), CI->getArgOperand(1),
4277                          CI->getArgOperand(2), CI->getArgOperand(3), B, TLI));
4278 
4279   return nullptr;
4280 }
4281 
optimizeSNPrintfChk(CallInst * CI,IRBuilderBase & B)4282 Value *FortifiedLibCallSimplifier::optimizeSNPrintfChk(CallInst *CI,
4283                                                        IRBuilderBase &B) {
4284   if (isFortifiedCallFoldable(CI, 3, 1, std::nullopt, 2)) {
4285     SmallVector<Value *, 8> VariadicArgs(drop_begin(CI->args(), 5));
4286     return copyFlags(*CI,
4287                      emitSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1),
4288                                   CI->getArgOperand(4), VariadicArgs, B, TLI));
4289   }
4290 
4291   return nullptr;
4292 }
4293 
optimizeSPrintfChk(CallInst * CI,IRBuilderBase & B)4294 Value *FortifiedLibCallSimplifier::optimizeSPrintfChk(CallInst *CI,
4295                                                       IRBuilderBase &B) {
4296   if (isFortifiedCallFoldable(CI, 2, std::nullopt, std::nullopt, 1)) {
4297     SmallVector<Value *, 8> VariadicArgs(drop_begin(CI->args(), 4));
4298     return copyFlags(*CI,
4299                      emitSPrintf(CI->getArgOperand(0), CI->getArgOperand(3),
4300                                  VariadicArgs, B, TLI));
4301   }
4302 
4303   return nullptr;
4304 }
4305 
optimizeStrCatChk(CallInst * CI,IRBuilderBase & B)4306 Value *FortifiedLibCallSimplifier::optimizeStrCatChk(CallInst *CI,
4307                                                      IRBuilderBase &B) {
4308   if (isFortifiedCallFoldable(CI, 2))
4309     return copyFlags(
4310         *CI, emitStrCat(CI->getArgOperand(0), CI->getArgOperand(1), B, TLI));
4311 
4312   return nullptr;
4313 }
4314 
optimizeStrLCat(CallInst * CI,IRBuilderBase & B)4315 Value *FortifiedLibCallSimplifier::optimizeStrLCat(CallInst *CI,
4316                                                    IRBuilderBase &B) {
4317   if (isFortifiedCallFoldable(CI, 3))
4318     return copyFlags(*CI,
4319                      emitStrLCat(CI->getArgOperand(0), CI->getArgOperand(1),
4320                                  CI->getArgOperand(2), B, TLI));
4321 
4322   return nullptr;
4323 }
4324 
optimizeStrNCatChk(CallInst * CI,IRBuilderBase & B)4325 Value *FortifiedLibCallSimplifier::optimizeStrNCatChk(CallInst *CI,
4326                                                       IRBuilderBase &B) {
4327   if (isFortifiedCallFoldable(CI, 3))
4328     return copyFlags(*CI,
4329                      emitStrNCat(CI->getArgOperand(0), CI->getArgOperand(1),
4330                                  CI->getArgOperand(2), B, TLI));
4331 
4332   return nullptr;
4333 }
4334 
optimizeStrLCpyChk(CallInst * CI,IRBuilderBase & B)4335 Value *FortifiedLibCallSimplifier::optimizeStrLCpyChk(CallInst *CI,
4336                                                       IRBuilderBase &B) {
4337   if (isFortifiedCallFoldable(CI, 3))
4338     return copyFlags(*CI,
4339                      emitStrLCpy(CI->getArgOperand(0), CI->getArgOperand(1),
4340                                  CI->getArgOperand(2), B, TLI));
4341 
4342   return nullptr;
4343 }
4344 
optimizeVSNPrintfChk(CallInst * CI,IRBuilderBase & B)4345 Value *FortifiedLibCallSimplifier::optimizeVSNPrintfChk(CallInst *CI,
4346                                                         IRBuilderBase &B) {
4347   if (isFortifiedCallFoldable(CI, 3, 1, std::nullopt, 2))
4348     return copyFlags(
4349         *CI, emitVSNPrintf(CI->getArgOperand(0), CI->getArgOperand(1),
4350                            CI->getArgOperand(4), CI->getArgOperand(5), B, TLI));
4351 
4352   return nullptr;
4353 }
4354 
optimizeVSPrintfChk(CallInst * CI,IRBuilderBase & B)4355 Value *FortifiedLibCallSimplifier::optimizeVSPrintfChk(CallInst *CI,
4356                                                        IRBuilderBase &B) {
4357   if (isFortifiedCallFoldable(CI, 2, std::nullopt, std::nullopt, 1))
4358     return copyFlags(*CI,
4359                      emitVSPrintf(CI->getArgOperand(0), CI->getArgOperand(3),
4360                                   CI->getArgOperand(4), B, TLI));
4361 
4362   return nullptr;
4363 }
4364 
optimizeCall(CallInst * CI,IRBuilderBase & Builder)4365 Value *FortifiedLibCallSimplifier::optimizeCall(CallInst *CI,
4366                                                 IRBuilderBase &Builder) {
4367   // FIXME: We shouldn't be changing "nobuiltin" or TLI unavailable calls here.
4368   // Some clang users checked for _chk libcall availability using:
4369   //   __has_builtin(__builtin___memcpy_chk)
4370   // When compiling with -fno-builtin, this is always true.
4371   // When passing -ffreestanding/-mkernel, which both imply -fno-builtin, we
4372   // end up with fortified libcalls, which isn't acceptable in a freestanding
4373   // environment which only provides their non-fortified counterparts.
4374   //
4375   // Until we change clang and/or teach external users to check for availability
4376   // differently, disregard the "nobuiltin" attribute and TLI::has.
4377   //
4378   // PR23093.
4379 
4380   LibFunc Func;
4381   Function *Callee = CI->getCalledFunction();
4382   bool IsCallingConvC = TargetLibraryInfoImpl::isCallingConvCCompatible(CI);
4383 
4384   SmallVector<OperandBundleDef, 2> OpBundles;
4385   CI->getOperandBundlesAsDefs(OpBundles);
4386 
4387   IRBuilderBase::OperandBundlesGuard Guard(Builder);
4388   Builder.setDefaultOperandBundles(OpBundles);
4389 
4390   // First, check that this is a known library functions and that the prototype
4391   // is correct.
4392   if (!TLI->getLibFunc(*Callee, Func))
4393     return nullptr;
4394 
4395   // We never change the calling convention.
4396   if (!ignoreCallingConv(Func) && !IsCallingConvC)
4397     return nullptr;
4398 
4399   switch (Func) {
4400   case LibFunc_memcpy_chk:
4401     return optimizeMemCpyChk(CI, Builder);
4402   case LibFunc_mempcpy_chk:
4403     return optimizeMemPCpyChk(CI, Builder);
4404   case LibFunc_memmove_chk:
4405     return optimizeMemMoveChk(CI, Builder);
4406   case LibFunc_memset_chk:
4407     return optimizeMemSetChk(CI, Builder);
4408   case LibFunc_stpcpy_chk:
4409   case LibFunc_strcpy_chk:
4410     return optimizeStrpCpyChk(CI, Builder, Func);
4411   case LibFunc_strlen_chk:
4412     return optimizeStrLenChk(CI, Builder);
4413   case LibFunc_stpncpy_chk:
4414   case LibFunc_strncpy_chk:
4415     return optimizeStrpNCpyChk(CI, Builder, Func);
4416   case LibFunc_memccpy_chk:
4417     return optimizeMemCCpyChk(CI, Builder);
4418   case LibFunc_snprintf_chk:
4419     return optimizeSNPrintfChk(CI, Builder);
4420   case LibFunc_sprintf_chk:
4421     return optimizeSPrintfChk(CI, Builder);
4422   case LibFunc_strcat_chk:
4423     return optimizeStrCatChk(CI, Builder);
4424   case LibFunc_strlcat_chk:
4425     return optimizeStrLCat(CI, Builder);
4426   case LibFunc_strncat_chk:
4427     return optimizeStrNCatChk(CI, Builder);
4428   case LibFunc_strlcpy_chk:
4429     return optimizeStrLCpyChk(CI, Builder);
4430   case LibFunc_vsnprintf_chk:
4431     return optimizeVSNPrintfChk(CI, Builder);
4432   case LibFunc_vsprintf_chk:
4433     return optimizeVSPrintfChk(CI, Builder);
4434   default:
4435     break;
4436   }
4437   return nullptr;
4438 }
4439 
FortifiedLibCallSimplifier(const TargetLibraryInfo * TLI,bool OnlyLowerUnknownSize)4440 FortifiedLibCallSimplifier::FortifiedLibCallSimplifier(
4441     const TargetLibraryInfo *TLI, bool OnlyLowerUnknownSize)
4442     : TLI(TLI), OnlyLowerUnknownSize(OnlyLowerUnknownSize) {}
4443