//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/SizeOpts.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  DomTreeUpdater *DTU;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+SSE can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  struct LoadPair {
    Value *Lhs = nullptr;
    Value *Rhs = nullptr;
  };
  LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType,
                       unsigned OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads,
                            unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
                  DomTreeUpdater *DTU);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}

MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load; this is already handled by
  // the greedy approach.
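  // For example (illustrative values derived from the logic here, not from
  // any particular target): Size == 16 with MaxLoadSize == 8 splits into two
  // exact 8-byte loads and is left to the greedy lowering, whereas Size == 15
  // yields one non-overlapping 8-byte load at offset 0 plus a second 8-byte
  // load at offset 15 - 8 == 7, i.e. [{8, 0}, {8, 7}]. The two loads overlap
  // by one byte, which is harmless for a comparison and cheaper than the
  // greedy 8+4+2+1 sequence.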
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
//    return from.
// 3. ResultBlock, block to branch to for early exit when a
//    LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout,
    DomTreeUpdater *DTU)
    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), DTU(DTU),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If we allow overlapping loads and the load sequence is not already
  // optimal, use overlapping loads.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
}

unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType,
                                                       bool NeedsBSwap,
                                                       Type *CmpSizeType,
                                                       unsigned OffsetBytes) {
  // Get the memory source at offset `OffsetBytes`.
  Value *LhsSource = CI->getArgOperand(0);
  Value *RhsSource = CI->getArgOperand(1);
  Align LhsAlign = LhsSource->getPointerAlignment(DL);
  Align RhsAlign = RhsSource->getPointerAlignment(DL);
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    LhsSource = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()),
        OffsetBytes);
    RhsSource = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()),
        OffsetBytes);
    LhsAlign = commonAlignment(LhsAlign, OffsetBytes);
    RhsAlign = commonAlignment(RhsAlign, OffsetBytes);
  }
  LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo());
  RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo());

  // Create a constant or a load from the source.
  Value *Lhs = nullptr;
  if (auto *C = dyn_cast<Constant>(LhsSource))
    Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Lhs)
    Lhs = Builder.CreateAlignedLoad(LoadSizeType, LhsSource, LhsAlign);

  Value *Rhs = nullptr;
  if (auto *C = dyn_cast<Constant>(RhsSource))
    Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL);
  if (!Rhs)
    Rhs = Builder.CreateAlignedLoad(LoadSizeType, RhsSource, RhsAlign);

  // Swap bytes if required.
  if (NeedsBSwap) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    Lhs = Builder.CreateCall(Bswap, Lhs);
    Rhs = Builder.CreateCall(Bswap, Rhs);
  }

  // Zero extend if required.
  if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) {
    Lhs = Builder.CreateZExt(Lhs, CmpSizeType);
    Rhs = Builder.CreateZExt(Rhs, CmpSizeType);
  }
  return {Lhs, Rhs};
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// OffsetBytes. It then subtracts the two loaded values and adds this result to
// the final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  BasicBlock *BB = LoadCmpBlocks[BlockIndex];
  Builder.SetInsertPoint(BB);
  const LoadPair Loads =
      getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false,
                  Type::getInt32Ty(CI->getContext()), OffsetBytes);
  Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs);

  PhiRes->addIncoming(Diff, BB);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock.
    // Otherwise, continue to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    if (DTU)
      DTU->applyUpdates(
          {{DominatorTree::Insert, BB, EndBlock},
           {DominatorTree::Insert, BB, LoadCmpBlocks[BlockIndex + 1]}});
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, BB, EndBlock}});
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];
    const LoadPair Loads = getLoadPair(
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8),
        /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset);

    if (NumLoads != 1) {
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, ResBlock.BB},
                       {DominatorTree::Insert, BB, NextBB}});

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then does a subtract to see if
// there was a difference in the loaded values. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock. The special handling
// can simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  const LoadPair Loads =
      getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType,
                  CurLoadEntry.Offset);

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BasicBlock *BB = Builder.GetInsertBlock();
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, BB, NextBB},
                       {DominatorTree::Insert, BB, ResBlock.BB}});

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  PhiRes->addIncoming(Res, ResBlock.BB);
  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  if (DTU)
    DTU->applyUpdates({{DominatorTree::Insert, ResBlock.BB, EndBlock}});
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  bool NeedsBSwap = DL.isLittleEndian() && Size != 1;

  // The i8 and i16 cases don't need compares.
  // We zext the loaded values and subtract them to get the suitable negative,
  // zero, or positive i32 result.
  if (Size < 4) {
    const LoadPair Loads =
        getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(),
                    /*Offset*/ 0);
    return Builder.CreateSub(Loads.Lhs, Loads.Rhs);
  }

  const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType,
                                     /*Offset*/ 0);
  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs);
  Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = SplitBlock(StartBlock, CI, DTU, /*LI=*/nullptr,
                          /*MSSAU=*/nullptr, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp)
      setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by SplitBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
    if (DTU)
      DTU->applyUpdates({{DominatorTree::Insert, StartBlock, LoadCmpBlocks[0]},
                         {DominatorTree::Delete, StartBlock, EndBlock}});
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
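///
/// For illustration, assume a target whose available load sizes are
/// {8, 4, 2, 1} bytes and which does not use overlapping loads; a 15-byte
/// compare then decomposes greedily into one load/compare block per load
/// size (8, then 4, then 2, then 1 bytes), as in the IR below.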
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                      ; preds = %loadbb2,
///                                                   %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                        ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                        ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                        ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                       ; preds = %res_block,
///                                                   %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL,
                         ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
                         DomTreeUpdater *DTU) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }
  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  bool OptForSize = CI->getFunction()->hasOptSize() ||
                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
  auto Options = TTI->enableMemCmpExpansion(OptForSize, IsUsedForZeroCmp);
  if (!Options)
    return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (OptForSize && MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL, DTU);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}

class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
    auto *BFI = (PSI && PSI->hasProfileSummary())
                    ? &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI()
                    : nullptr;
    DominatorTree *DT = nullptr;
    if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
      DT = &DTWP->getDomTree();
    auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI, DT);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<ProfileSummaryInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering *TL, ProfileSummaryInfo *PSI,
                            BlockFrequencyInfo *BFI, DominatorTree *DT);

  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering *TL,
                  const DataLayout &DL, ProfileSummaryInfo *PSI,
                  BlockFrequencyInfo *BFI, DomTreeUpdater *DTU);
};

bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                                  const TargetTransformInfo *TTI,
                                  const TargetLowering *TL,
                                  const DataLayout &DL, ProfileSummaryInfo *PSI,
                                  BlockFrequencyInfo *BFI,
                                  DomTreeUpdater *DTU) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(*CI, Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL, PSI, BFI, DTU)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses
ExpandMemCmpPass::runImpl(Function &F, const TargetLibraryInfo *TLI,
                          const TargetTransformInfo *TTI,
                          const TargetLowering *TL, ProfileSummaryInfo *PSI,
                          BlockFrequencyInfo *BFI, DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI,
                   DTU.hasValue() ? DTU.getPointer() : nullptr)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  if (MadeChanges)
    for (BasicBlock &BB : F)
      SimplifyInstructionsInBlock(&BB);
  if (!MadeChanges)
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}