//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SizeOpts.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize = 0;
  uint64_t NumLoadsNonOneByte = 0;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  DomTreeUpdater *DTU;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
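  // Illustrative decomposition (assuming, for the sake of example, a target
  // whose widest allowed load is 8 bytes): a 15-byte memcmp greedily becomes
  // [{8, 0}, {4, 8}, {2, 12}, {1, 14}]; with overlapping loads allowed it
  // shrinks to [{8, 0}, {8, 7}], re-reading byte 7 to save two loads. The
  // chosen sequence is stored in LoadSequence below.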
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  struct LoadPair {
    Value *Lhs = nullptr;
    Value *Rhs = nullptr;
  };
  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
                       unsigned OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads,
                            unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
                  DomTreeUpdater *DTU);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}

MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load; this is already handled by
  // the greedy approach.
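  // E.g. Size == 16 with MaxLoadSize == 8 leaves Size == 0 here: the greedy
  // sequence [{8, 0}, {8, 8}] already covers every byte without overlap.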
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block
//    to return from.
// 3. ResultBlock, block to branch to for early exit when a LoadCmpBlock finds
//    a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
    DomTreeUpdater *DTU)
    : CI(CI), Size(Size), NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), DTU(DTU),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes,
                                           Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If we allow overlapping loads and the load sequence is not already
  // optimal, use overlapping loads.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
}

unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}
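// Illustration for getNumBlocks(): a zero-equality expansion with 5 loads and
// NumLoadsPerBlockForZeroCmp == 2 needs ceil(5 / 2) == 3 blocks; a relational
// (non-zero) expansion always uses one block per load.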
void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
                                                       bool NeedsBSwap,
                                                       Type *CmpSizeType,
                                                       unsigned OffsetBytes) {
  // Get the memory source at offset `OffsetBytes`.
  Value *LhsSource = CI->getArgOperand(0);
  Value *RhsSource = CI->getArgOperand(1);
  Align LhsAlign = LhsSource->getPointerAlignment(DL);
  Align RhsAlign = RhsSource->getPointerAlignment(DL);
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    LhsSource = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
        OffsetBytes);
    RhsSource = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
        OffsetBytes);
    LhsAlign = commonAlignment(LhsAlign, OffsetBytes);
    RhsAlign = commonAlignment(RhsAlign, OffsetBytes);
  }
  LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
  RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());

  // Create a constant or a load from the source.
  Value *Lhs = nullptr;
  if (auto *C = dyn_cast<Constant>(LhsSource))
    Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Lhs)
    Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);

  Value *Rhs = nullptr;
  if (auto *C = dyn_cast<Constant>(RhsSource))
    Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Rhs)
    Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);

  // Swap bytes if required.
  if (NeedsBSwap) {
    Function *Bswap = Intrinsic::getDeclaration(
        CI->getModule(), Intrinsic::bswap, LoadSizeType);
    Lhs = Builder.CreateCall(Bswap, Lhs);
    Rhs = Builder.CreateCall(Bswap, Rhs);
  }

  // Zero extend if required.
  if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
    Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
    Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
  }
  return {Lhs, Rhs};
}
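// Note on the bswap above: memcmp compares most-significant byte first. On a
// little-endian target, an i16 load of the bytes {0x01, 0x02} yields 0x0201;
// byte-swapping restores 0x0102 so that an unsigned integer comparison agrees
// with the byte-wise order. (The endianness check itself is done at the call
// sites of getLoadPair.)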
// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// OffsetBytes. It then subtracts the two loaded values and adds this result
// to the final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  BasicBlock *BB = LoadCmpBlocks[BlockIndex];
  Builder.SetInsertPoint(BB);
  const LoadPair Loads =
      getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
  Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

  PhiRes->addIncoming(Diff, BB);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue
    // to next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
    if (DTU)
      DTU->applyUpdates(
          {{DominatorTree::Insert, BB, EndBlock},
           {DominatorTree::Insert, BB, LoadCmpBlocks[BlockIndex + 1]}});
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, BB, EndBlock}});
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest
  // load type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
    const LoadPair Loads = getLoadPair(
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
        /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);

    if (NumLoads != 1) {
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded
      // values.
      Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
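// Sketch of the xor/or reduction above for a block with three loads (value
// names hypothetical): x0 = Lhs0 ^ Rhs0, x1 = Lhs1 ^ Rhs1, x2 = Lhs2 ^ Rhs2
// reduce pairwise to (x0 | x1), then to ((x0 | x1) | x2); a single
// "icmp ne <or result>, 0" then tells whether any load pair differed.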
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, ResBlock.BB},
                       {DominatorTree::Insert, BB, NextBB}});

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
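// Per-block CFG produced above (sketch):
//   loadbb_i --(any difference)--> res_block
//      |
//      +-------(all equal)------> loadbb_{i+1}, or endblock for the last
//                                 block (with a 0 incoming to phi.res).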
// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then compares the two loaded
// values to see if there was a difference. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock. The special handling
// can simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  const LoadPair Loads =
      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
                  CurLoadEntry.Offset);

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, NextBB},
                       {DominatorTree::Insert, BB, ResBlock.BB}});

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does
  // not need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  PhiRes->addIncoming(Res, ResBlock.BB);
  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}
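// Illustrative result for the one-block zero-equality case, assuming a single
// 8-byte load pair: "%ne = icmp ne i64 %lhs, %rhs; %res = zext i1 %ne to
// i32". The caller's comparison of %res against zero is expected to fold onto
// %ne in later simplification passes.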
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  bool NeedsBSwap = DL.isLittleEndian() && Size != 1;

  // The i8 and i16 cases don't need compares. We zext the loaded values and
  // subtract them to get the suitable negative, zero, or positive i32 result.
  if (Size < 4) {
    const LoadPair Loads =
        getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
                    /*Offset*/ 0);
    return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
  }

  const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
                                     /*Offset*/ 0);
  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to
  // math) may not be possible in the DAG because the selects got converted
  // into branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
  Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
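// Worked example of the sub(zext(ugt), zext(ult)) idiom: Lhs == 5, Rhs == 9
// gives ugt == 0 and ult == 1, so the result is 0 - 1 == -1; equal operands
// give 0; Lhs > Rhs gives 1 - 0 == 1, matching memcmp's sign contract.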
// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = SplitBlock(StartBlock, CI, DTU, /*LI=*/nullptr,
                          /*MSSAU=*/nullptr, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp)
      setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by SplitBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, StartBlock, LoadCmpBlocks[0]},
                         {DominatorTree::Delete, StartBlock, EndBlock}});
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
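// Overall CFG for a multi-block expansion (sketch): StartBlock branches to
// loadbb; each loadbb_i either falls through to loadbb_{i+1} (the last one to
// endblock) or exits early to res_block, which computes -1/1 and branches to
// endblock, where phi.res selects the final i32 result. One-byte blocks exit
// straight to endblock with their computed difference instead.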
// This function checks to see if an expansion of memcmp can be generated.
// It checks for a constant compare size that is less than the max inline
// size. If an expansion cannot occur, returns false to leave as a library
// call. Otherwise, the library call is replaced with a new IR instruction
// sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                        ; preds = %loadbb2,
/// %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                          ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                          ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                          ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                         ; preds = %res_block,
/// %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL,
                         ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
                         DomTreeUpdater *DTU) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }
  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  bool OptForSize = CI->getFunction()->hasOptSize() ||
                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
  auto Options = TTI->enableMemCmpExpansion(OptForSize, IsUsedForZeroCmp);
  if (!Options)
    return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (OptForSize && MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL, DTU);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}
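// End-to-end illustration (hypothetical and target-dependent): on a typical
// 64-bit little-endian target that allows 8-byte loads,
//   if (memcmp(a, b, 16) == 0) { ... }
// would expand to two 8-byte load pairs compared for equality, with no
// libcall and no bswaps (a pure equality test does not need byte order), and
// a single test once the per-load results are combined.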
class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
    auto *BFI = (PSI && PSI->hasProfileSummary())
                    ? &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI()
                    : nullptr;
    DominatorTree *DT = nullptr;
    if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
      DT = &DTWP->getDomTree();
    auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI, DT);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<ProfileSummaryInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering *TL, ProfileSummaryInfo *PSI,
                            BlockFrequencyInfo *BFI, DominatorTree *DT);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering *TL,
                  const DataLayout &DL, ProfileSummaryInfo *PSI,
                  BlockFrequencyInfo *BFI, DomTreeUpdater *DTU);
};

bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                                  const TargetTransformInfo *TTI,
                                  const TargetLowering *TL,
                                  const DataLayout &DL, ProfileSummaryInfo *PSI,
                                  BlockFrequencyInfo *BFI,
                                  DomTreeUpdater *DTU) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(*CI, Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL, PSI, BFI, DTU)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses
ExpandMemCmpPass::runImpl(Function &F, const TargetLibraryInfo *TLI,
                          const TargetTransformInfo *TTI,
                          const TargetLowering *TL, ProfileSummaryInfo *PSI,
                          BlockFrequencyInfo *BFI, DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI,
                   DTU.hasValue() ? DTU.getPointer() : nullptr)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  if (MadeChanges)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);
  if (!MadeChanges)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}