//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/LazyBlockFrequencyInfo.h"
#include "llvm/Analysis/ProfileSummaryInfo.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"
#include "llvm/Transforms/Utils/SizeOpts.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
    }

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
                                 uint64_t OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

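// A worked example of the greedy decomposition below, assuming a target whose
// allowed load sizes are {8, 4, 2, 1} bytes: for Size == 15 it produces
// [{8, 0}, {4, 8}, {2, 12}, {1, 14}], i.e. one i64, one i32, one i16, and one
// i8 load, matching the example IR in the comment above expandMemCmp().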
MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}

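// Illustration of the overlap trick, assuming MaxLoadSize == 8: for
// Size == 13 the greedy sequence needs three loads ({8, 0}, {4, 8} and
// {1, 12}), while the sequence built below needs only two 8-byte loads,
// {8, 0} and {8, 5}, the second of which re-reads bytes 5 to 7.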
MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load; this case is already handled
  // by the greedy approach.
  if (Size == 0)
    return {};
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If we allow overlapping loads and the load sequence is not already
  // optimal, use overlapping loads.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
}

unsigned MemCmpExpansion::getNumBlocks() {
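  // For the zero-equality form, several loads may share one basic block, so
  // this is a ceiling division: e.g. 5 loads at 2 loads per block yield 3
  // blocks. Otherwise there is exactly one load per block.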
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

/// Return a pointer to an element of type `LoadSizeType` at offset
/// `OffsetBytes`.
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
                                                Type *LoadSizeType,
                                                uint64_t OffsetBytes) {
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    Source = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
        OffsetBytes);
  }
  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// OffsetBytes. It then subtracts the two loaded values and adds the result to
// the final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  Value *Source1 =
      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
  Value *Source2 =
      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit to EndBlock if a difference was found; otherwise, continue
    // to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
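/// For example, with two loads per block this emits
///   or (xor a1, b1), (xor a2, b2)
/// followed by a single icmp ne against zero, so one branch checks both
/// load pairs at once.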
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                             CurLoadEntry.Offset);
    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                             CurLoadEntry.Offset);

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then does a subtract to see if
// there was a difference in the loaded values. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock, which simply subtracts
// the loaded values and adds the difference to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                           CurLoadEntry.Offset);
  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                           CurLoadEntry.Offset);

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

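  // memcmp compares the buffers byte by byte, which corresponds to a
  // big-endian ordering of the wide values. On a little-endian target the
  // bytes of each load arrive reversed (e.g. bytes 01 02 03 04 load as the
  // i32 0x04030201), so byte-swap both values before comparing them.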
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

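  // memcmp treats the buffers as arrays of unsigned char, and the loaded
  // values were byte-swapped to big-endian order above, so an unsigned
  // comparison of the wide values orders the buffers exactly as memcmp does.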
  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32
    // result.
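    // Both operands fit in 16 bits, so after the zext the i32 subtraction
    // cannot wrap; its sign is exactly the sign memcmp has to return.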
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
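  // Concretely: equal values give 0 - 0 = 0, src1 > src2 gives 1 - 0 = 1,
  // and src1 < src2 gives 0 - 1 = -1.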
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block: ; preds = %loadbb2, %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1: ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2: ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3: ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock: ; preds = %res_block, %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL,
                         ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }
  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
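  // A "zero equality" user compares the memcmp result only against zero
  // (e.g. `memcmp(a, b, n) == 0`), where the relative order of the buffers
  // is irrelevant and the cheaper xor/or expansion applies.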
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  bool OptForSize = CI->getFunction()->hasOptSize() ||
                    llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI);
  auto Options = TTI->enableMemCmpExpansion(OptForSize,
                                            IsUsedForZeroCmp);
  if (!Options) return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (OptForSize &&
      MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}

class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI();
    auto *BFI = (PSI && PSI->hasProfileSummary()) ?
        &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() :
        nullptr;
    auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addRequired<ProfileSummaryInfoWrapperPass>();
    LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU);
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL,
                            ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL, ProfileSummaryInfo *PSI,
                  BlockFrequencyInfo *BFI);
};

bool ExpandMemCmpPass::runOnBlock(
    BasicBlock &BB, const TargetLibraryInfo *TLI,
    const TargetTransformInfo *TTI, const TargetLowering* TL,
    const DataLayout& DL, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) {
  for (Instruction& I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
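    // Note that bcmp is only specified to return zero or non-zero, so it can
    // safely reuse the memcmp expansion unchanged.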
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL, PSI, BFI)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses ExpandMemCmpPass::runImpl(
    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
    const TargetLowering* TL, ProfileSummaryInfo *PSI,
    BlockFrequencyInfo *BFI) {
  const DataLayout& DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass)
INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}