//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via !type metadata) that the list of callees is fixed. This
// includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class and
//   each list of constant arguments: evaluate the function, store the return
//   value alongside the virtual table, and rewrite each virtual call as a load
//   from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value, replace
//   each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
// This pass is intended to be used during the regular and thin LTO pipelines:
//
// During regular LTO, the pass determines the best optimization for each
// virtual call and applies the resolutions directly to virtual calls that are
// eligible for virtual call optimization (i.e. calls that use either of the
// llvm.assume(llvm.type.test) or llvm.type.checked.load intrinsics).
//
// During hybrid Regular/ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over a single merged module
//   that contains all vtables with !type metadata that participate in the link.
//   The pass computes a resolution for each virtual call and stores it in the
//   type identifier summary.
// - Import phase: this is run during the thin backends over the individual
//   modules. The pass applies the resolutions previously computed during the
//   export phase to each eligible virtual call.
//
// During ThinLTO, the pass operates in two phases:
// - Export phase: this is run during the thin link over the index which
//   contains a summary of all vtables with !type metadata that participate in
//   the link. It computes a resolution for each virtual call and stores it in
//   the type identifier summary. Only single implementation devirtualization
//   is supported.
// - Import phase: (same as with hybrid case above).
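//
// As an illustrative sketch of the first and third cases above (the class and
// function names here are hypothetical, not taken from any particular
// program): given
//
//   struct Base { virtual int f(); };
//   struct Only : Base { int f() override { return 1; } };
//   int call(Base *B) { return B->f(); }
//
// if whole program analysis proves Only::f is the sole possible callee,
// single implementation devirtualization turns B->f() into a direct call to
// Only::f; if additionally every possible callee returns the same constant
// (here 1), the uniform return value optimization replaces the call with the
// constant itself.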
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseMapInfo.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BasicAliasAnalysis.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugLoc.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/GlobPattern.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/FunctionAttrs.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include <algorithm>
#include <cstddef>
#include <map>
#include <set>
#include <string>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

STATISTIC(NumDevirtTargets, "Number of whole program devirtualization targets");
STATISTIC(NumSingleImpl, "Number of single implementation devirtualizations");
STATISTIC(NumBranchFunnel, "Number of branch funnels");
STATISTIC(NumUniformRetVal, "Number of uniform return value optimizations");
STATISTIC(NumUniqueRetVal, "Number of unique return value optimizations");
STATISTIC(NumVirtConstProp1Bit,
          "Number of 1 bit virtual constant propagations");
STATISTIC(NumVirtConstProp, "Number of virtual constant propagations");

static cl::opt<PassSummaryAction> ClSummaryAction(
    "wholeprogramdevirt-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "wholeprogramdevirt-read-summary",
    cl::desc(
        "Read summary from given bitcode or YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "wholeprogramdevirt-write-summary",
    cl::desc("Write summary to given bitcode or YAML file after running pass. "
             "Output file format is deduced from extension: *.bc means writing "
             "bitcode, otherwise YAML"),
    cl::Hidden);

static cl::opt<unsigned>
    ClThreshold("wholeprogramdevirt-branch-funnel-threshold", cl::Hidden,
                cl::init(10),
                cl::desc("Maximum number of call targets per "
                         "call site to enable branch funnels"));

static cl::opt<bool>
    PrintSummaryDevirt("wholeprogramdevirt-print-index-based", cl::Hidden,
                       cl::desc("Print index-based devirtualization messages"));

/// Provide a way to force enable whole program visibility in tests.
/// This is needed to support legacy tests that don't contain
/// !vcall_visibility metadata (the mere presence of type tests
/// previously implied hidden visibility).
static cl::opt<bool>
    WholeProgramVisibility("whole-program-visibility", cl::Hidden,
                           cl::desc("Enable whole program visibility"));

/// Provide a way to force disable whole program visibility for debugging or
/// workarounds, when it was enabled via the linker.
static cl::opt<bool> DisableWholeProgramVisibility(
    "disable-whole-program-visibility", cl::Hidden,
    cl::desc("Disable whole program visibility (overrides enabling options)"));

/// Provide a way to prevent certain functions from being devirtualized.
static cl::list<std::string>
    SkipFunctionNames("wholeprogramdevirt-skip",
                      cl::desc("Prevent function(s) from being devirtualized"),
                      cl::Hidden, cl::CommaSeparated);

/// Mechanism to add runtime checking of devirtualization decisions, optionally
/// trapping or falling back to an indirect call on any that are not correct.
/// Trapping mode is useful for debugging undefined behavior leading to failures
/// with WPD. Fallback mode is useful for ensuring safety when whole program
/// visibility may be compromised.
enum WPDCheckMode { None, Trap, Fallback };
static cl::opt<WPDCheckMode> DevirtCheckMode(
    "wholeprogramdevirt-check", cl::Hidden,
    cl::desc("Type of checking for incorrect devirtualizations"),
    cl::values(clEnumValN(WPDCheckMode::None, "none", "No checking"),
               clEnumValN(WPDCheckMode::Trap, "trap", "Trap when incorrect"),
               clEnumValN(WPDCheckMode::Fallback, "fallback",
                          "Fallback to indirect when incorrect")));

namespace {
struct PatternList {
  std::vector<GlobPattern> Patterns;
  template <class T> void init(const T &StringList) {
    for (const auto &S : StringList)
      if (Expected<GlobPattern> Pat = GlobPattern::create(S))
        Patterns.push_back(std::move(*Pat));
  }
  bool match(StringRef S) {
    for (const GlobPattern &P : Patterns)
      if (P.match(S))
        return true;
    return false;
  }
};
} // namespace

// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object; otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of the
  // used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte. Effectively,
  // this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                     Offset(A)
  //                     |       |
  //                             |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //             |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 + llvm::countr_zero(uint8_t(~BitsUsed));
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
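    // For instance (illustrative values, not drawn from a real layout): with
    // Size == 16 and Used == {{0xff, 0x00, 0x00}}, I == 0 fails because byte
    // 0 is in use, while I == 1 finds two free bytes, so we would return
    // (MinByte + 1) * 8.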
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(GlobalValue *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()),
      WasDevirt(false) {}

namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

} // end anonymous namespace

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS, const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

template <> struct DenseMapInfo<VTableSlotSummary> {
  static VTableSlotSummary getEmptyKey() {
    return {DenseMapInfo<StringRef>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlotSummary getTombstoneKey() {
    return {DenseMapInfo<StringRef>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlotSummary &I) {
    return DenseMapInfo<StringRef>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlotSummary &LHS,
                      const VTableSlotSummary &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

} // end namespace llvm

namespace {

// Returns true if the function must be unreachable based on ValueInfo.
//
// In particular, identifies a function as unreachable if all of the following
// conditions hold:
// 1) All summaries are live.
// 2) All function summaries indicate it's unreachable.
// 3) There is no non-function with the same GUID (which is rare).
bool mustBeUnreachableFunction(ValueInfo TheFnVI) {
  if ((!TheFnVI) || TheFnVI.getSummaryList().empty()) {
    // Returns false if ValueInfo is absent, or the summary list is empty
    // (e.g., function declarations).
    return false;
  }

  for (const auto &Summary : TheFnVI.getSummaryList()) {
    // Conservatively returns false if any non-live functions are seen.
    // In general either all summaries should be live or all should be dead.
    if (!Summary->isLive())
      return false;
    if (auto *FS = dyn_cast<FunctionSummary>(Summary->getBaseObject())) {
      if (!FS->fflags().MustBeUnreachable)
        return false;
    }
    // Be conservative if a non-function has the same GUID (which is rare).
    else
      return false;
  }
  // All function summaries are live and all of them agree that the function is
  // unreachable.
  return true;
}

// A virtual call site. VTable is the loaded virtual table pointer, and CB is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable = nullptr;
  CallBase &CB;

  // If non-null, this field points to the associated unsafe use count stored in
  // the DevirtModule::NumUnsafeUsesForTypeTest map below. See the description
  // of that field for details.
  unsigned *NumUnsafeUses = nullptr;

  void
  emitRemark(const StringRef OptName, const StringRef TargetName,
             function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
    Function *F = CB.getCaller();
    DebugLoc DLoc = CB.getDebugLoc();
    BasicBlock *Block = CB.getParent();

    using namespace ore;
    OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block)
                      << NV("Optimization", OptName)
                      << ": devirtualized a call to "
                      << NV("FunctionName", TargetName));
  }

  void replaceAndErase(
      const StringRef OptName, const StringRef TargetName, bool RemarksEnabled,
      function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
      Value *New) {
    if (RemarksEnabled)
      emitRemark(OptName, TargetName, OREGetter);
    CB.replaceAllUsesWith(New);
    if (auto *II = dyn_cast<InvokeInst>(&CB)) {
      BranchInst::Create(II->getNormalDest(), &CB);
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CB.eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};

// Call site information collected for a specific VTableSlot and possibly a list
// of constant integer arguments. The grouping by arguments is handled by the
// VTableSlotInfo class.
struct CallSiteInfo {
  /// The set of call sites for this slot. Used during regular LTO and the
  /// import phase of ThinLTO (as well as the export phase of ThinLTO for any
  /// call sites that appear in the merged module itself); in each of these
  /// cases we are directly operating on the call sites at the IR level.
  std::vector<VirtualCallSite> CallSites;

  /// Whether all call sites represented by this CallSiteInfo, including those
  /// in summaries, have been devirtualized. This starts off as true because a
  /// default constructed CallSiteInfo represents no call sites.
  bool AllCallSitesDevirted = true;

  // These fields are used during the export phase of ThinLTO and reflect
  // information collected from function summaries.

  /// Whether any function summary contains an llvm.assume(llvm.type.test) for
  /// this slot.
  bool SummaryHasTypeTestAssumeUsers = false;

  /// CFI-specific: a vector containing the list of function summaries that use
  /// the llvm.type.checked.load intrinsic and therefore will require
  /// resolutions for llvm.type.test in order to implement CFI checks if
  /// devirtualization was unsuccessful. If devirtualization was successful, the
  /// pass will clear this vector by calling markDevirt(). If at the end of the
  /// pass the vector is non-empty, we will need to add a use of llvm.type.test
  /// to each of the function summaries in the vector.
  std::vector<FunctionSummary *> SummaryTypeCheckedLoadUsers;
  std::vector<FunctionSummary *> SummaryTypeTestAssumeUsers;

  bool isExported() const {
    return SummaryHasTypeTestAssumeUsers ||
           !SummaryTypeCheckedLoadUsers.empty();
  }

  void addSummaryTypeCheckedLoadUser(FunctionSummary *FS) {
    SummaryTypeCheckedLoadUsers.push_back(FS);
    AllCallSitesDevirted = false;
  }

  void addSummaryTypeTestAssumeUser(FunctionSummary *FS) {
    SummaryTypeTestAssumeUsers.push_back(FS);
    SummaryHasTypeTestAssumeUsers = true;
    AllCallSitesDevirted = false;
  }

  void markDevirt() {
    AllCallSitesDevirted = true;

    // As explained in the comment for SummaryTypeCheckedLoadUsers.
    SummaryTypeCheckedLoadUsers.clear();
  }
};

// Call site information collected for a specific VTableSlot.
struct VTableSlotInfo {
  // The set of call sites which do not have all constant integer arguments
  // (excluding "this").
  CallSiteInfo CSInfo;

  // The set of call sites with all constant integer arguments (excluding
  // "this"), grouped by argument list.
  std::map<std::vector<uint64_t>, CallSiteInfo> ConstCSInfo;

  void addCallSite(Value *VTable, CallBase &CB, unsigned *NumUnsafeUses);

private:
  CallSiteInfo &findCallSiteInfo(CallBase &CB);
};

CallSiteInfo &VTableSlotInfo::findCallSiteInfo(CallBase &CB) {
  std::vector<uint64_t> Args;
  auto *CBType = dyn_cast<IntegerType>(CB.getType());
  if (!CBType || CBType->getBitWidth() > 64 || CB.arg_empty())
    return CSInfo;
  for (auto &&Arg : drop_begin(CB.args())) {
    auto *CI = dyn_cast<ConstantInt>(Arg);
    if (!CI || CI->getBitWidth() > 64)
      return CSInfo;
    Args.push_back(CI->getZExtValue());
  }
  return ConstCSInfo[Args];
}

void VTableSlotInfo::addCallSite(Value *VTable, CallBase &CB,
                                 unsigned *NumUnsafeUses) {
  auto &CSI = findCallSiteInfo(CB);
  CSI.AllCallSitesDevirted = false;
  CSI.CallSites.push_back({VTable, CB, NumUnsafeUses});
}

struct DevirtModule {
  Module &M;
  function_ref<AAResults &(Function &)> AARGetter;
  function_ref<DominatorTree &(Function &)> LookupDomTree;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;

  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;
  IntegerType *Int64Ty;
  IntegerType *IntPtrTy;
  /// Sizeless array type, used for imported vtables. This provides a signal
  /// to analyzers that these imports may alias, as they do for example
  /// when multiple unique return values occur in the same vtable.
  ArrayType *Int8Arr0Ty;

  bool RemarksEnabled;
  function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter;

  MapVector<VTableSlot, VTableSlotInfo> CallSlots;

  // Calls that have already been optimized. We may add a call to multiple
  // VTableSlotInfos if vtable loads are coalesced, and we need to make sure
  // not to optimize a call more than once.
  SmallPtrSet<CallBase *, 8> OptimizedCalls;

  // Store calls that had their ptrauth bundle removed. They are to be deleted
  // at the end of the optimization.
  SmallVector<CallBase *, 8> CallsWithPtrAuthBundleRemoved;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call with
  // true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;
  PatternList FunctionsToSkip;

  DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
               function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
               function_ref<DominatorTree &(Function &)> LookupDomTree,
               ModuleSummaryIndex *ExportSummary,
               const ModuleSummaryIndex *ImportSummary)
      : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
        Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())),
        Int64Ty(Type::getInt64Ty(M.getContext())),
        IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)),
        Int8Arr0Ty(ArrayType::get(Type::getInt8Ty(M.getContext()), 0)),
        RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) {
    assert(!(ExportSummary && ImportSummary));
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool areRemarksEnabled();

  void
  scanTypeTestUsers(Function *TypeTestFunc,
                    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);

  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset,
                            ModuleSummaryIndex *ExportSummary);

  void applySingleImplDevirt(VTableSlotInfo &SlotInfo, Constant *TheFn,
                             bool &IsExported);
  bool trySingleImplDevirt(ModuleSummaryIndex *ExportSummary,
                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res);

  void applyICallBranchFunnel(VTableSlotInfo &SlotInfo, Constant *JT,
                              bool &IsExported);
  void tryICallBranchFunnel(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            VTableSlotInfo &SlotInfo,
                            WholeProgramDevirtResolution *Res, VTableSlot Slot);

  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<uint64_t> Args);

  void applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                             uint64_t TheRetVal);
  bool
  tryUniformRetValOpt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                      CallSiteInfo &CSInfo,
                      WholeProgramDevirtResolution::ByArg *Res);

  // Returns the global symbol name that is used to export information about the
  // given vtable slot and list of arguments.
  std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args,
                            StringRef Name);

  bool shouldExportConstantsAsAbsoluteSymbols();

  // This function is called during the export phase to create a symbol
  // definition containing information about the given vtable slot and list of
  // arguments.
  void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                    Constant *C);
  void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name,
                      uint32_t Const, uint32_t &Storage);

  // This function is called during the import phase to create a reference to
  // the symbol definition created during the export phase.
  Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                         StringRef Name);
  Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                           StringRef Name, IntegerType *IntTy,
                           uint32_t Storage);

  Constant *getMemberAddr(const TypeMemberInfo *M);

  void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne,
                            Constant *UniqueMemberAddr);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                          CallSiteInfo &CSInfo,
                          WholeProgramDevirtResolution::ByArg *Res,
                          VTableSlot Slot, ArrayRef<uint64_t> Args);

  void applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                             Constant *Byte, Constant *Bit);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res, VTableSlot Slot);

  void rebuildGlobal(VTableBits &B);

  // Apply the summary resolution for Slot to all virtual calls in SlotInfo.
  void importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo);

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the associated type tests by replacing them with true.
  void removeRedundantTypeTests();

  bool run();

  // Look up the corresponding ValueInfo entry of `TheFn` in `ExportSummary`.
  //
  // Caller guarantees that `ExportSummary` is not nullptr.
  static ValueInfo lookUpFunctionValueInfo(Function *TheFn,
                                           ModuleSummaryIndex *ExportSummary);

  // Returns true if the function definition must be unreachable.
  //
  // Note that if this helper function returns true, `F` is guaranteed
  // to be unreachable; if it returns false, `F` might still
  // be unreachable but not covered by this helper function.
  //
  // Implementation-wise, if the function definition is present, the IR is
  // analyzed; if not, the function flags from ExportSummary are used as a
  // fallback.
  static bool mustBeUnreachableFunction(Function *const F,
                                        ModuleSummaryIndex *ExportSummary);

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
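  //
  // E.g. (illustrative invocation; the file names here are hypothetical), the
  // export phase can be exercised from opt roughly as:
  //
  //   opt -passes=wholeprogramdevirt \
  //       -wholeprogramdevirt-summary-action=export \
  //       -wholeprogramdevirt-read-summary=in.yaml \
  //       -wholeprogramdevirt-write-summary=out.yaml module.ll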
  static bool
  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
                function_ref<DominatorTree &(Function &)> LookupDomTree);
};

struct DevirtIndex {
  ModuleSummaryIndex &ExportSummary;
  // The set in which to record GUIDs exported from their module by
  // devirtualization, used by the client to ensure they are not internalized.
  std::set<GlobalValue::GUID> &ExportedGUIDs;
  // A map in which to record the information necessary to locate the WPD
  // resolution for local targets in case they are exported by cross module
  // importing.
  std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap;

  MapVector<VTableSlotSummary, VTableSlotInfo> CallSlots;

  PatternList FunctionsToSkip;

  DevirtIndex(
      ModuleSummaryIndex &ExportSummary,
      std::set<GlobalValue::GUID> &ExportedGUIDs,
      std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap)
      : ExportSummary(ExportSummary), ExportedGUIDs(ExportedGUIDs),
        LocalWPDTargetsMap(LocalWPDTargetsMap) {
    FunctionsToSkip.init(SkipFunctionNames);
  }

  bool tryFindVirtualCallTargets(std::vector<ValueInfo> &TargetsForSlot,
                                 const TypeIdCompatibleVtableInfo TIdInfo,
                                 uint64_t ByteOffset);

  bool trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                           VTableSlotSummary &SlotSummary,
                           VTableSlotInfo &SlotInfo,
                           WholeProgramDevirtResolution *Res,
                           std::set<ValueInfo> &DevirtTargets);

  void run();
};
} // end anonymous namespace

PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
                                              ModuleAnalysisManager &AM) {
  auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  auto AARGetter = [&](Function &F) -> AAResults & {
    return FAM.getResult<AAManager>(F);
  };
  auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };
  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
    return FAM.getResult<DominatorTreeAnalysis>(F);
  };
  if (UseCommandLine) {
    if (!DevirtModule::runForTesting(M, AARGetter, OREGetter, LookupDomTree))
      return PreservedAnalyses::all();
    return PreservedAnalyses::none();
  }
  if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
                    ImportSummary)
           .run())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

namespace llvm {
// Enable whole program visibility if enabled by the client (e.g. linker) or
// internal option, and not force disabled.
bool hasWholeProgramVisibility(bool WholeProgramVisibilityEnabledInLTO) {
  return (WholeProgramVisibilityEnabledInLTO || WholeProgramVisibility) &&
         !DisableWholeProgramVisibility;
}

/// If whole program visibility is asserted, then upgrade all public vcall
/// visibility metadata on vtable definitions to linkage unit visibility in
/// Module IR (for regular or hybrid LTO).
void updateVCallVisibilityInModule(
    Module &M, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (GlobalVariable &GV : M.globals()) {
    // Add linkage unit visibility to any variable with type metadata, which are
    // the vtable definitions. We won't have an existing vcall_visibility
    // metadata on vtable definitions with public visibility.
    if (GV.hasMetadata(LLVMContext::MD_type) &&
        GV.getVCallVisibility() == GlobalObject::VCallVisibilityPublic &&
        // Don't upgrade the visibility for symbols exported to the dynamic
        // linker, as we have no information on their eventual use.
        !DynamicExportSymbols.count(GV.getGUID()))
      GV.setVCallVisibilityMetadata(GlobalObject::VCallVisibilityLinkageUnit);
  }
}

void updatePublicTypeTestCalls(Module &M,
                               bool WholeProgramVisibilityEnabledInLTO) {
  Function *PublicTypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::public_type_test));
  if (!PublicTypeTestFunc)
    return;
  if (hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO)) {
    Function *TypeTestFunc =
        Intrinsic::getDeclaration(&M, Intrinsic::type_test);
    for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
      auto *CI = cast<CallInst>(U.getUser());
      auto *NewCI = CallInst::Create(
          TypeTestFunc, {CI->getArgOperand(0), CI->getArgOperand(1)},
          std::nullopt, "", CI);
      CI->replaceAllUsesWith(NewCI);
      CI->eraseFromParent();
    }
  } else {
    auto *True = ConstantInt::getTrue(M.getContext());
    for (Use &U : make_early_inc_range(PublicTypeTestFunc->uses())) {
      auto *CI = cast<CallInst>(U.getUser());
      CI->replaceAllUsesWith(True);
      CI->eraseFromParent();
    }
  }
}

/// If whole program visibility is asserted, then upgrade all public vcall
/// visibility metadata on vtable definition summaries to linkage unit
/// visibility in the Module summary index (for ThinLTO).
void updateVCallVisibilityInIndex(
    ModuleSummaryIndex &Index, bool WholeProgramVisibilityEnabledInLTO,
    const DenseSet<GlobalValue::GUID> &DynamicExportSymbols) {
  if (!hasWholeProgramVisibility(WholeProgramVisibilityEnabledInLTO))
    return;
  for (auto &P : Index) {
    // Don't upgrade the visibility for symbols exported to the dynamic
    // linker, as we have no information on their eventual use.
    if (DynamicExportSymbols.count(P.first))
      continue;
    for (auto &S : P.second.SummaryList) {
      auto *GVar = dyn_cast<GlobalVarSummary>(S.get());
      if (!GVar ||
          GVar->getVCallVisibility() != GlobalObject::VCallVisibilityPublic)
        continue;
      GVar->setVCallVisibility(GlobalObject::VCallVisibilityLinkageUnit);
    }
  }
}

void runWholeProgramDevirtOnIndex(
    ModuleSummaryIndex &Summary, std::set<GlobalValue::GUID> &ExportedGUIDs,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  DevirtIndex(Summary, ExportedGUIDs, LocalWPDTargetsMap).run();
}

void updateIndexWPDForExports(
    ModuleSummaryIndex &Summary,
    function_ref<bool(StringRef, ValueInfo)> isExported,
    std::map<ValueInfo, std::vector<VTableSlotSummary>> &LocalWPDTargetsMap) {
  for (auto &T : LocalWPDTargetsMap) {
    auto &VI = T.first;
    // This was enforced earlier during trySingleImplDevirt.
    assert(VI.getSummaryList().size() == 1 &&
           "Devirt of local target has more than one copy");
    auto &S = VI.getSummaryList()[0];
    if (!isExported(S->modulePath(), VI))
      continue;

    // It's been exported by a cross module import.
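    // E.g. (illustrative): a devirtualized local target `f` defined in a
    // module with hash <hash> is recorded below under a promoted name of the
    // form `f.llvm.<hash>`, matching what getGlobalNameForLocal produces.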
    for (auto &SlotSummary : T.second) {
      auto *TIdSum = Summary.getTypeIdSummary(SlotSummary.TypeID);
      assert(TIdSum);
      auto WPDRes = TIdSum->WPDRes.find(SlotSummary.ByteOffset);
      assert(WPDRes != TIdSum->WPDRes.end());
      WPDRes->second.SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          WPDRes->second.SingleImplName,
          Summary.getModuleHash(S->modulePath()));
    }
  }
}

} // end namespace llvm

static Error checkCombinedSummaryForTesting(ModuleSummaryIndex *Summary) {
  // Check that the summary index contains a regular LTO module when performing
  // an export, to prevent occasional use of an index from a pure ThinLTO
  // compilation (-fno-split-lto-module). This kind of summary index is passed
  // to DevirtIndex::run, not to DevirtModule::run used by opt/runForTesting.
  const auto &ModPaths = Summary->modulePaths();
  if (ClSummaryAction != PassSummaryAction::Import &&
      !ModPaths.contains(ModuleSummaryIndex::getRegularLTOModuleName()))
    return createStringError(
        errc::invalid_argument,
        "combined summary should contain Regular LTO module");
  return ErrorSuccess();
}

bool DevirtModule::runForTesting(
    Module &M, function_ref<AAResults &(Function &)> AARGetter,
    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
    function_ref<DominatorTree &(Function &)> LookupDomTree) {
  std::unique_ptr<ModuleSummaryIndex> Summary =
      std::make_unique<ModuleSummaryIndex>(/*HaveGVs=*/false);

  // Handle the command-line summary arguments. This code is for testing
  // purposes only, so we handle errors directly.
  if (!ClReadSummary.empty()) {
    ExitOnError ExitOnErr("-wholeprogramdevirt-read-summary: " + ClReadSummary +
                          ": ");
    auto ReadSummaryFile =
        ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
    if (Expected<std::unique_ptr<ModuleSummaryIndex>> SummaryOrErr =
            getModuleSummaryIndex(*ReadSummaryFile)) {
      Summary = std::move(*SummaryOrErr);
      ExitOnErr(checkCombinedSummaryForTesting(Summary.get()));
    } else {
      // Try YAML if we've failed with bitcode.
      consumeError(SummaryOrErr.takeError());
      yaml::Input In(ReadSummaryFile->getBuffer());
      In >> *Summary;
      ExitOnErr(errorCodeToError(In.error()));
    }
  }

  bool Changed =
      DevirtModule(M, AARGetter, OREGetter, LookupDomTree,
                   ClSummaryAction == PassSummaryAction::Export ? Summary.get()
                                                                : nullptr,
                   ClSummaryAction == PassSummaryAction::Import ? Summary.get()
                                                                : nullptr)
          .run();

  if (!ClWriteSummary.empty()) {
    ExitOnError ExitOnErr(
        "-wholeprogramdevirt-write-summary: " + ClWriteSummary + ": ");
    std::error_code EC;
    if (StringRef(ClWriteSummary).endswith(".bc")) {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_None);
      ExitOnErr(errorCodeToError(EC));
      writeIndexToFile(*Summary, OS);
    } else {
      raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
      ExitOnErr(errorCodeToError(EC));
      yaml::Output Out(OS);
      Out << *Summary;
    }
  }

  return Changed;
}

void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(M.global_size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(LLVMContext::MD_type, Types);
    if (GV.isDeclaration() || Types.empty())
      continue;

    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto TypeID = Type->getOperand(1).get();

      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert({BitsPtr, Offset});
    }
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset,
    ModuleSummaryIndex *ExportSummary) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    if (!TM.Bits->GV->isConstant())
      return false;

    // We cannot perform whole program devirtualization analysis on a vtable
    // with public LTO visibility.
    if (TM.Bits->GV->getVCallVisibility() ==
        GlobalObject::VCallVisibilityPublic)
      return false;

    Constant *Ptr = getPointerAtOffset(TM.Bits->GV->getInitializer(),
                                       TM.Offset + ByteOffset, M, TM.Bits->GV);
    if (!Ptr)
      return false;

    auto C = Ptr->stripPointerCasts();
    // Make sure this is a function or an alias to a function.
    auto Fn = dyn_cast<Function>(C);
    auto A = dyn_cast<GlobalAlias>(C);
    if (!Fn && A)
      Fn = dyn_cast<Function>(A->getAliasee());

    if (!Fn)
      return false;

    if (FunctionsToSkip.match(Fn->getName()))
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    // We can disregard unreachable functions as possible call targets, as
    // unreachable functions shouldn't be called.
    if (mustBeUnreachableFunction(Fn, ExportSummary))
      continue;

    // Save the symbol used in the vtable to use as the devirtualization
    // target.
    auto GV = dyn_cast<GlobalValue>(C);
    assert(GV);
    TargetsForSlot.push_back({GV, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtIndex::tryFindVirtualCallTargets(
    std::vector<ValueInfo> &TargetsForSlot,
    const TypeIdCompatibleVtableInfo TIdInfo, uint64_t ByteOffset) {
  for (const TypeIdOffsetVtableInfo &P : TIdInfo) {
    // Find a representative copy of the vtable initializer.
    // We can have multiple available_externally, linkonce_odr and weak_odr
    // vtable initializers. We can also have multiple external vtable
    // initializers in the case of comdats, which we cannot check here.
    // The linker should give an error in this case.
    //
    // Also, handle the case of same-named local vtables with the same path
    // and therefore the same GUID. This can happen if there isn't enough
    // distinguishing path when compiling the source file. In that case we
    // conservatively return false early.
    const GlobalVarSummary *VS = nullptr;
    bool LocalFound = false;
    for (const auto &S : P.VTableVI.getSummaryList()) {
      if (GlobalValue::isLocalLinkage(S->linkage())) {
        if (LocalFound)
          return false;
        LocalFound = true;
      }
      auto *CurVS = cast<GlobalVarSummary>(S->getBaseObject());
      if (!CurVS->vTableFuncs().empty() ||
          // Previously clang did not attach the necessary type metadata to
          // available_externally vtables, in which case there would not
          // be any vtable functions listed in the summary and we need
          // to treat this case conservatively (in case the bitcode is old).
          // However, we will also not have any vtable functions in the
          // case of a pure virtual base class. In that case we do want
          // to set VS to avoid treating it conservatively.
          !GlobalValue::isAvailableExternallyLinkage(S->linkage())) {
        VS = CurVS;
        // We cannot perform whole program devirtualization analysis on a
        // vtable with public LTO visibility.
        if (VS->getVCallVisibility() == GlobalObject::VCallVisibilityPublic)
          return false;
      }
    }
    // There will be no VS if all copies are available_externally having no
    // type metadata. In that case we can't safely perform WPD.
    if (!VS)
      return false;
    if (!VS->isLive())
      continue;
    for (auto VTP : VS->vTableFuncs()) {
      if (VTP.VTableOffset != P.AddressPointOffset + ByteOffset)
        continue;

      if (mustBeUnreachableFunction(VTP.FuncVI))
        continue;

      TargetsForSlot.push_back(VTP.FuncVI);
    }
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo,
                                         Constant *TheFn, bool &IsExported) {
  // Don't devirtualize the function if we're told to skip it
  // in -wholeprogramdevirt-skip.
  if (FunctionsToSkip.match(TheFn->stripPointerCasts()->getName()))
    return;
  auto Apply = [&](CallSiteInfo &CSInfo) {
    for (auto &&VCallSite : CSInfo.CallSites) {
      if (!OptimizedCalls.insert(&VCallSite.CB).second)
        continue;

      if (RemarksEnabled)
        VCallSite.emitRemark("single-impl",
                             TheFn->stripPointerCasts()->getName(), OREGetter);
      NumSingleImpl++;
      auto &CB = VCallSite.CB;
      assert(!CB.getCalledFunction() && "devirtualizing direct call?");
      IRBuilder<> Builder(&CB);
      Value *Callee =
          Builder.CreateBitCast(TheFn, CB.getCalledOperand()->getType());

      // If trap checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // perform a debug trap.
      if (DevirtCheckMode == WPDCheckMode::Trap) {
        auto *Cond = Builder.CreateICmpNE(CB.getCalledOperand(), Callee);
        Instruction *ThenTerm =
            SplitBlockAndInsertIfThen(Cond, &CB, /*Unreachable=*/false);
        Builder.SetInsertPoint(ThenTerm);
        Function *TrapFn = Intrinsic::getDeclaration(&M, Intrinsic::debugtrap);
        auto *CallTrap = Builder.CreateCall(TrapFn);
        CallTrap->setDebugLoc(CB.getDebugLoc());
      }

      // If fallback checking is enabled, add support to compare the virtual
      // function pointer to the devirtualized target. In case of a mismatch,
      // fall back to an indirect call.
      if (DevirtCheckMode == WPDCheckMode::Fallback) {
        MDNode *Weights =
            MDBuilder(M.getContext()).createBranchWeights((1U << 20) - 1, 1);
        // Version the indirect call site. If the called value is equal to the
        // given callee, 'NewInst' will be executed, otherwise the original call
        // site will be executed.
        CallBase &NewInst = versionCallSite(CB, Callee, Weights);
        NewInst.setCalledOperand(Callee);
        // Since the new call site is direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        NewInst.setMetadata(LLVMContext::MD_prof, nullptr);
        NewInst.setMetadata(LLVMContext::MD_callees, nullptr);
        // Additionally, we should remove them from the fallback indirect call,
        // so that we don't attempt to perform indirect call promotion later.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
      }

      // In either trapping or non-checking mode, devirtualize the original
      // call.
      else {
        // Devirtualize unconditionally.
        CB.setCalledOperand(Callee);
        // Since the call site is now direct, we must clear metadata that
        // is only appropriate for indirect calls. This includes !prof and
        // !callees metadata.
        CB.setMetadata(LLVMContext::MD_prof, nullptr);
        CB.setMetadata(LLVMContext::MD_callees, nullptr);
        if (CB.getCalledOperand() &&
            CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
          auto *NewCS =
              CallBase::removeOperandBundle(&CB, LLVMContext::OB_ptrauth, &CB);
          CB.replaceAllUsesWith(NewCS);
          // Schedule for deletion at the end of the pass run.
          CallsWithPtrAuthBundleRemoved.push_back(&CB);
        }
      }

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    if (CSInfo.isExported())
      IsExported = true;
    CSInfo.markDevirt();
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

static bool AddCalls(VTableSlotInfo &SlotInfo, const ValueInfo &Callee) {
  // We can't add calls if we haven't seen a definition.
  if (Callee.getSummaryList().empty())
    return false;

  // Insert calls into the summary index so that the devirtualized targets
  // are eligible for import.
  // FIXME: Annotate type tests with hotness. For now, mark these as hot
  // to better ensure we have the opportunity to inline them.
  bool IsExported = false;
  auto &S = Callee.getSummaryList()[0];
  CalleeInfo CI(CalleeInfo::HotnessType::Hot, /* RelBF = */ 0);
  auto AddCalls = [&](CallSiteInfo &CSInfo) {
    for (auto *FS : CSInfo.SummaryTypeCheckedLoadUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
    for (auto *FS : CSInfo.SummaryTypeTestAssumeUsers) {
      FS->addCall({Callee, CI});
      IsExported |= S->modulePath() != FS->modulePath();
    }
  };
  AddCalls(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    AddCalls(P.second);
  return IsExported;
}

bool DevirtModule::trySingleImplDevirt(
    ModuleSummaryIndex *ExportSummary,
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  if (RemarksEnabled || AreStatisticsEnabled())
    TargetsForSlot[0].WasDevirt = true;

  bool IsExported = false;
  applySingleImplDevirt(SlotInfo, TheFn, IsExported);
  if (!IsExported)
    return false;

  // If the only implementation has local linkage, we must promote to external
  // to make it visible to thin LTO objects. We can only get here during the
  // ThinLTO export phase.
  if (TheFn->hasLocalLinkage()) {
    std::string NewName = (TheFn->getName() + ".llvm.merged").str();

    // Since we are renaming the function, any comdats with the same name must
    // also be renamed. This is required when targeting COFF, as the comdat name
    // must match one of the names of the symbols in the comdat.
    if (Comdat *C = TheFn->getComdat()) {
      if (C->getName() == TheFn->getName()) {
        Comdat *NewC = M.getOrInsertComdat(NewName);
        NewC->setSelectionKind(C->getSelectionKind());
        for (GlobalObject &GO : M.global_objects())
          if (GO.getComdat() == C)
            GO.setComdat(NewC);
      }
    }

    TheFn->setLinkage(GlobalValue::ExternalLinkage);
    TheFn->setVisibility(GlobalValue::HiddenVisibility);
    TheFn->setName(NewName);
  }
  if (ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFn->getGUID()))
    // Any needed promotion of 'TheFn' has already been done during
    // LTO unit splitting, so we can ignore the return value of AddCalls.
    AddCalls(SlotInfo, TheFnVI);

  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  Res->SingleImplName = std::string(TheFn->getName());

  return true;
}

bool DevirtIndex::trySingleImplDevirt(MutableArrayRef<ValueInfo> TargetsForSlot,
                                      VTableSlotSummary &SlotSummary,
                                      VTableSlotInfo &SlotInfo,
                                      WholeProgramDevirtResolution *Res,
                                      std::set<ValueInfo> &DevirtTargets) {
  // See if the program contains a single implementation of this virtual
  // function.
  auto TheFn = TargetsForSlot[0];
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target)
      return false;

  // Don't devirtualize if we don't have a target definition.
  auto Size = TheFn.getSummaryList().size();
  if (!Size)
    return false;

  // Don't devirtualize the function if we're told to skip it
  // in -wholeprogramdevirt-skip.
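  // E.g. (illustrative): passing -wholeprogramdevirt-skip=_ZN1A1nEi would
  // keep any matching implementation from being devirtualized; the option
  // accepts glob patterns via the PatternList helper above.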
  if (FunctionsToSkip.match(TheFn.name()))
    return false;

  // If the summary list contains multiple summaries where at least one is
  // a local, give up, as we won't know which (possibly promoted) name to use.
  for (const auto &S : TheFn.getSummaryList())
    if (GlobalValue::isLocalLinkage(S->linkage()) && Size > 1)
      return false;

  // Collect functions devirtualized at least for one call site for stats.
  if (PrintSummaryDevirt || AreStatisticsEnabled())
    DevirtTargets.insert(TheFn);

  auto &S = TheFn.getSummaryList()[0];
  bool IsExported = AddCalls(SlotInfo, TheFn);
  if (IsExported)
    ExportedGUIDs.insert(TheFn.getGUID());

  // Record in summary for use in devirtualization during the ThinLTO import
  // step.
  Res->TheKind = WholeProgramDevirtResolution::SingleImpl;
  if (GlobalValue::isLocalLinkage(S->linkage())) {
    if (IsExported)
      // If the target is a local function and we are exporting it by
      // devirtualizing a call in another module, we need to record the
      // promoted name.
      Res->SingleImplName = ModuleSummaryIndex::getGlobalNameForLocal(
          TheFn.name(), ExportSummary.getModuleHash(S->modulePath()));
    else {
      LocalWPDTargetsMap[TheFn].push_back(SlotSummary);
      Res->SingleImplName = std::string(TheFn.name());
    }
  } else
    Res->SingleImplName = std::string(TheFn.name());

  // The name will be empty if this thin link is driven off of a serialized
  // combined index (e.g. llvm-lto). However, WPD is not supported/invoked for
  // the legacy LTO API anyway.
  assert(!Res->SingleImplName.empty());

  return true;
}

void DevirtModule::tryICallBranchFunnel(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  Triple T(M.getTargetTriple());
  if (T.getArch() != Triple::x86_64)
    return;

  if (TargetsForSlot.size() > ClThreshold)
    return;

  bool HasNonDevirt = !SlotInfo.CSInfo.AllCallSitesDevirted;
  if (!HasNonDevirt)
    for (auto &P : SlotInfo.ConstCSInfo)
      if (!P.second.AllCallSitesDevirted) {
        HasNonDevirt = true;
        break;
      }

  if (!HasNonDevirt)
    return;

  FunctionType *FT =
      FunctionType::get(Type::getVoidTy(M.getContext()), {Int8PtrTy}, true);
  Function *JT;
  if (isa<MDString>(Slot.TypeID)) {
    JT = Function::Create(FT, Function::ExternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          getGlobalName(Slot, {}, "branch_funnel"), &M);
    JT->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    JT = Function::Create(FT, Function::InternalLinkage,
                          M.getDataLayout().getProgramAddressSpace(),
                          "branch_funnel", &M);
  }
  JT->addParamAttr(0, Attribute::Nest);

  std::vector<Value *> JTArgs;
  JTArgs.push_back(JT->arg_begin());
  for (auto &T : TargetsForSlot) {
    JTArgs.push_back(getMemberAddr(T.TM));
    JTArgs.push_back(T.Fn);
  }

  BasicBlock *BB = BasicBlock::Create(M.getContext(), "", JT, nullptr);
  Function *Intr =
      Intrinsic::getDeclaration(&M, llvm::Intrinsic::icall_branch_funnel, {});

  auto *CI = CallInst::Create(Intr, JTArgs, "", BB);
  CI->setTailCallKind(CallInst::TCK_MustTail);
  ReturnInst::Create(M.getContext(), nullptr, BB);

  bool IsExported = false;
  applyICallBranchFunnel(SlotInfo, JT, IsExported);
  if (IsExported)
    Res->TheKind = WholeProgramDevirtResolution::BranchFunnel;
}

void DevirtModule::applyICallBranchFunnel(VTableSlotInfo &SlotInfo,
                                          Constant *JT, bool &IsExported) {
  auto Apply = [&](CallSiteInfo &CSInfo) {
    if (CSInfo.isExported())
      IsExported = true;
    if (CSInfo.AllCallSitesDevirted)
      return;

    std::map<CallBase *, CallBase *> CallBases;
    for (auto &&VCallSite : CSInfo.CallSites) {
      CallBase &CB = VCallSite.CB;

      if (CallBases.find(&CB) != CallBases.end()) {
        // When finding devirtualizable calls, it's possible to find the same
        // vtable passed to multiple llvm.type.test or llvm.type.checked.load
        // calls, which can cause duplicate call sites to be recorded in
        // [Const]CallSites. If we've already found one of these
        // call instances, just ignore it. It will be replaced later.
        continue;
      }

      // Jump tables are only profitable if the retpoline mitigation is
      // enabled.
      Attribute FSAttr = CB.getCaller()->getFnAttribute("target-features");
      if (!FSAttr.isValid() ||
          !FSAttr.getValueAsString().contains("+retpoline"))
        continue;

      NumBranchFunnel++;
      if (RemarksEnabled)
        VCallSite.emitRemark("branch-funnel",
                             JT->stripPointerCasts()->getName(), OREGetter);

      // Pass the address of the vtable in the nest register, which is r10 on
      // x86_64.
      std::vector<Type *> NewArgs;
      NewArgs.push_back(Int8PtrTy);
      append_range(NewArgs, CB.getFunctionType()->params());
      FunctionType *NewFT =
          FunctionType::get(CB.getFunctionType()->getReturnType(), NewArgs,
                            CB.getFunctionType()->isVarArg());
      PointerType *NewFTPtr = PointerType::getUnqual(NewFT);

      IRBuilder<> IRB(&CB);
      std::vector<Value *> Args;
      Args.push_back(IRB.CreateBitCast(VCallSite.VTable, Int8PtrTy));
      llvm::append_range(Args, CB.args());

      CallBase *NewCS = nullptr;
      if (isa<CallInst>(CB))
        NewCS = IRB.CreateCall(NewFT, IRB.CreateBitCast(JT, NewFTPtr), Args);
      else
        NewCS = IRB.CreateInvoke(NewFT, IRB.CreateBitCast(JT, NewFTPtr),
                                 cast<InvokeInst>(CB).getNormalDest(),
                                 cast<InvokeInst>(CB).getUnwindDest(), Args);
      NewCS->setCallingConv(CB.getCallingConv());

      AttributeList Attrs = CB.getAttributes();
      std::vector<AttributeSet> NewArgAttrs;
      NewArgAttrs.push_back(AttributeSet::get(
          M.getContext(), ArrayRef<Attribute>{Attribute::get(
                              M.getContext(), Attribute::Nest)}));
      for (unsigned I = 0; I + 2 < Attrs.getNumAttrSets(); ++I)
        NewArgAttrs.push_back(Attrs.getParamAttrs(I));
      NewCS->setAttributes(
          AttributeList::get(M.getContext(), Attrs.getFnAttrs(),
                             Attrs.getRetAttrs(), NewArgAttrs));

      CallBases[&CB] = NewCS;

      // This use is no longer unsafe.
      if (VCallSite.NumUnsafeUses)
        --*VCallSite.NumUnsafeUses;
    }
    // Don't mark as devirtualized because there may be callers compiled
    // without retpoline mitigation, which would mean that they are lowered to
    // llvm.type.test and therefore require an llvm.type.test resolution for
    // the type identifier.

    for (auto &[OldCB, NewCB] : CallBases) {
      OldCB->replaceAllUsesWith(NewCB);
      OldCB->eraseFromParent();
    }
  };
  Apply(SlotInfo.CSInfo);
  for (auto &P : SlotInfo.ConstCSInfo)
    Apply(P.second);
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<uint64_t> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    // TODO: Skip for now if the vtable symbol was an alias to a function,
    // need to evaluate whether it would be correct to analyze the aliasee
    // function for this optimization.
    auto Fn = dyn_cast<Function>(Target.Fn);
    if (!Fn)
      return false;

    if (Fn->arg_size() != Args.size() + 1)
      return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Fn->getFunctionType()->getParamType(0)));
    for (unsigned I = 0; I != Args.size(); ++I) {
      auto *ArgTy =
          dyn_cast<IntegerType>(Fn->getFunctionType()->getParamType(I + 1));
      if (!ArgTy)
        return false;
      EvalArgs.push_back(ConstantInt::get(ArgTy, Args[I]));
    }

    Constant *RetVal;
    if (!Eval.EvaluateFunction(Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                         uint64_t TheRetVal) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    NumUniformRetVal++;
    Call.replaceAndErase(
        "uniform-ret-val", FnName, RemarksEnabled, OREGetter,
        ConstantInt::get(cast<IntegerType>(Call.CB.getType()), TheRetVal));
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryUniformRetValOpt(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, CallSiteInfo &CSInfo,
    WholeProgramDevirtResolution::ByArg *Res) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
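  //
  // Illustrative sketch (hypothetical IR): if every target evaluates to 42
  // for this constant argument list, a virtual call such as
  //   %ret = call i32 %fptr(ptr %obj, i32 1)
  // is replaced outright by the constant i32 42.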
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  if (CSInfo.isExported()) {
    Res->TheKind = WholeProgramDevirtResolution::ByArg::UniformRetVal;
    Res->Info = TheRetVal;
  }

  applyUniformRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), TheRetVal);
  if (RemarksEnabled || AreStatisticsEnabled())
    for (auto &&Target : TargetsForSlot)
      Target.WasDevirt = true;
  return true;
}

std::string DevirtModule::getGlobalName(VTableSlot Slot,
                                        ArrayRef<uint64_t> Args,
                                        StringRef Name) {
  std::string FullName = "__typeid_";
  raw_string_ostream OS(FullName);
  OS << cast<MDString>(Slot.TypeID)->getString() << '_' << Slot.ByteOffset;
  for (uint64_t Arg : Args)
    OS << '_' << Arg;
  OS << '_' << Name;
  return OS.str();
}

bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() {
  Triple T(M.getTargetTriple());
  return T.isX86() && T.getObjectFormat() == Triple::ELF;
}

void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                StringRef Name, Constant *C) {
  GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
                                        getGlobalName(Slot, Args, Name), C, &M);
  GA->setVisibility(GlobalValue::HiddenVisibility);
}

void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                  StringRef Name, uint32_t Const,
                                  uint32_t &Storage) {
  if (shouldExportConstantsAsAbsoluteSymbols()) {
    exportGlobal(
        Slot, Args, Name,
        ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy));
    return;
  }

  Storage = Const;
}

Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                     StringRef Name) {
  Constant *C =
      M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Arr0Ty);
  auto *GV = dyn_cast<GlobalVariable>(C);
  if (GV)
    GV->setVisibility(GlobalValue::HiddenVisibility);
  return C;
}

Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args,
                                       StringRef Name, IntegerType *IntTy,
                                       uint32_t Storage) {
  if (!shouldExportConstantsAsAbsoluteSymbols())
    return ConstantInt::get(IntTy, Storage);

  Constant *C = importGlobal(Slot, Args, Name);
  auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
  C = ConstantExpr::getPtrToInt(C, IntTy);

  // We only need to set metadata if the global is newly created, in which
  // case it would not yet carry the absolute_symbol metadata.
  if (GV->hasMetadata(LLVMContext::MD_absolute_symbol))
    return C;

  auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
    auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
    auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
    GV->setMetadata(LLVMContext::MD_absolute_symbol,
                    MDNode::get(M.getContext(), {MinC, MaxC}));
  };
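
  // The !absolute_symbol metadata encodes a half-open range [Min, Max) of
  // possible symbol values; the all-ones pair used below is the special
  // encoding for the full set.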
  unsigned AbsWidth = IntTy->getBitWidth();
  if (AbsWidth == IntPtrTy->getBitWidth())
    SetAbsRange(~0ull, ~0ull); // Full set.
  else
    SetAbsRange(0, 1ull << AbsWidth);
  return C;
}

void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName,
                                        bool IsOne,
                                        Constant *UniqueMemberAddr) {
  for (auto &&Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    IRBuilder<> B(&Call.CB);
    Value *Cmp =
        B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, Call.VTable,
                     B.CreateBitCast(UniqueMemberAddr, Call.VTable->getType()));
    Cmp = B.CreateZExt(Cmp, Call.CB.getType());
    NumUniqueRetVal++;
    Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter,
                         Cmp);
  }
  CSInfo.markDevirt();
}

Constant *DevirtModule::getMemberAddr(const TypeMemberInfo *M) {
  Constant *C = ConstantExpr::getBitCast(M->Bits->GV, Int8PtrTy);
  return ConstantExpr::getGetElementPtr(Int8Ty, C,
                                        ConstantInt::get(Int64Ty, M->Offset));
}

bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    CallSiteInfo &CSInfo, WholeProgramDevirtResolution::ByArg *Res,
    VTableSlot Slot, ArrayRef<uint64_t> Args) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    Constant *UniqueMemberAddr = getMemberAddr(UniqueMember);
    if (CSInfo.isExported()) {
      Res->TheKind = WholeProgramDevirtResolution::ByArg::UniqueRetVal;
      Res->Info = IsOne;

      exportGlobal(Slot, Args, "unique_member", UniqueMemberAddr);
    }

    // Replace each call with the comparison.
    applyUniqueRetValOpt(CSInfo, TargetsForSlot[0].Fn->getName(), IsOne,
                         UniqueMemberAddr);
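    //
    // Illustratively (hypothetical IR), with IsOne == true each call becomes
    // a pointer comparison against the unique member's vtable address:
    //   %ret = icmp eq ptr %vtable, getelementptr (i8, ptr @vt, i64 <offset>)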

    // Update devirtualization statistics for targets.
    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName,
                                         Constant *Byte, Constant *Bit) {
  for (auto Call : CSInfo.CallSites) {
    if (!OptimizedCalls.insert(&Call.CB).second)
      continue;
    auto *RetType = cast<IntegerType>(Call.CB.getType());
    IRBuilder<> B(&Call.CB);
    Value *Addr =
        B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte);
    if (RetType->getBitWidth() == 1) {
      Value *Bits = B.CreateLoad(Int8Ty, Addr);
      Value *BitsAndBit = B.CreateAnd(Bits, Bit);
      auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
      NumVirtConstProp1Bit++;
      Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled,
                           OREGetter, IsBitSet);
    } else {
      Value *Val = B.CreateLoad(RetType, Addr);
      NumVirtConstProp++;
      Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled,
                           OREGetter, Val);
    }
  }
  CSInfo.markDevirt();
}

bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot, VTableSlotInfo &SlotInfo,
    WholeProgramDevirtResolution *Res, VTableSlot Slot) {
  // TODO: Skip for now if the vtable symbol was an alias to a function,
  // need to evaluate whether it would be correct to analyze the aliasee
  // function for this optimization.
  auto Fn = dyn_cast<Function>(TargetsForSlot[0].Fn);
  if (!Fn)
    return false;
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function is defined, does not access memory, takes at
  // least one argument, does not use its first argument (which we assume is
  // 'this'), and has the same return type.
  //
  // Note that we test whether this copy of the function is readnone, rather
  // than testing function attributes, which must hold for any copy of the
  // function, even a less optimized version substituted at link time. This is
  // sound because the virtual constant propagation optimizations effectively
  // inline all implementations of the virtual function into each call site,
  // rather than using function attributes to perform local optimization.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    // TODO: Skip for now if the vtable symbol was an alias to a function,
    // need to evaluate whether it would be correct to analyze the aliasee
    // function for this optimization.
    auto Fn = dyn_cast<Function>(Target.Fn);
    if (!Fn)
      return false;

    if (Fn->isDeclaration() ||
        !computeFunctionBodyMemoryAccess(*Fn, AARGetter(*Fn))
             .doesNotAccessMemory() ||
        Fn->arg_empty() || !Fn->arg_begin()->use_empty() ||
        Fn->getReturnType() != RetType)
      return false;
  }

  for (auto &&CSByConstantArg : SlotInfo.ConstCSInfo) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    WholeProgramDevirtResolution::ByArg *ResByArg = nullptr;
    if (Res)
      ResByArg = &Res->ResByArg[CSByConstantArg.first];

    if (tryUniformRetValOpt(TargetsForSlot, CSByConstantArg.second, ResByArg))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second,
                           ResByArg, Slot, CSByConstantArg.first))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    if (RemarksEnabled || AreStatisticsEnabled())
      for (auto &&Target : TargetsForSlot)
        Target.WasDevirt = true;

    if (CSByConstantArg.second.isExported()) {
      ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp;
      exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte,
                     ResByArg->Byte);
      exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit,
                     ResByArg->Bit);
    }

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte);
    Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
    applyVirtualConstProp(CSByConstantArg.second,
                          TargetsForSlot[0].Fn->getName(), ByteConst, BitConst);
  }
  return true;
}

void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;
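
  // The rebuilt global has the following layout (illustrative):
  //
  //   { [ before bytes, padded and stored reversed ],
  //     <original vtable initializer>,
  //     [ after bytes ] }
  //
  // and an alias bearing the original vtable's name is pointed at the middle
  // element so that existing references keep working.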

  // Align the before byte array to the global's minimum alignment so that we
  // don't break any alignment requirements on the global.
  Align Alignment = M.getDataLayout().getValueOrABITypeAlignment(
      B.GV->getAlign(), B.GV->getValueType());
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), Alignment));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());
  NewGV->setAlignment(B.GV->getAlign());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

bool DevirtModule::areRemarksEnabled() {
  const auto &FL = M.getFunctionList();
  for (const Function &Fn : FL) {
    if (Fn.empty())
      continue;
    auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &Fn.front());
    return DI.isEnabled();
  }
  return false;
}

void DevirtModule::scanTypeTestUsers(
    Function *TypeTestFunc,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);

    Metadata *TypeId =
        cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
    // If we found any, add them to CallSlots.
    if (!Assumes.empty()) {
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      for (DevirtCallSite Call : DevirtCalls)
        CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB, nullptr);
    }

    auto RemoveTypeTestAssumes = [&]() {
      // We no longer need the assumes or the type test.
      for (auto *Assume : Assumes)
        Assume->eraseFromParent();
      // We can't use RecursivelyDeleteTriviallyDeadInstructions here because
      // we may use the vtable argument later.
      if (CI->use_empty())
        CI->eraseFromParent();
    };

    // At this point we could remove all type test assume sequences, as they
    // were originally inserted for WPD. However, we can keep these in the
    // code stream for later analysis (e.g. to help drive more efficient ICP
    // sequences). They will eventually be removed by a second LowerTypeTests
    // invocation that cleans them up. In order to do this correctly, the
    // first LowerTypeTests invocation needs to know that they have "Unknown"
    // type test resolution, so that they aren't treated as Unsat and lowered
    // to False, which will break any uses of the assumes. Below we remove any
    // type test assumes that will not be treated as Unknown by LTT.

    // The type test assumes will be treated by LTT as Unsat if the type id is
    // not used on a global (in which case it has no entry in the TypeIdMap).
    if (!TypeIdMap.count(TypeId))
      RemoveTypeTestAssumes();

    // For ThinLTO importing, we need to remove the type test assumes if this
    // is an MDString type id without a corresponding TypeIdSummary. Any
    // non-MDString type ids are ignored and treated as Unknown by LTT, so
    // their type test assumes can be kept. If the MDString type id is missing
    // a TypeIdSummary (e.g. because there was no use on a vcall, preventing
    // the exporting phase of WPD from analyzing it), then it would be treated
    // as Unsat by LTT and we need to remove its type test assumes here. If
    // not used on a vcall we don't need them for later optimization use in
    // any case.
    else if (ImportSummary && isa<MDString>(TypeId)) {
      const TypeIdSummary *TidSummary =
          ImportSummary->getTypeIdSummary(cast<MDString>(TypeId)->getString());
      if (!TidSummary)
        RemoveTypeTestAssumes();
      else
        // If one was created it should not be Unsat, because if we reached
        // here the type id was used on a global.
        assert(TidSummary->TTRes.TheKind != TypeTestResolution::Unsat);
    }
  }
}

void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (Use &U : llvm::make_early_inc_range(TypeCheckedLoadFunc->uses())) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI)
      continue;

    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    auto &DT = LookupDomTree(*CI->getFunction());
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI, DT);

    // Start by generating "pessimistic" code that explicitly loads the
    // function pointer from the vtable and performs the type check. If
    // possible, we will eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
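    //
    // Illustrative sketch (hypothetical IR) of the pessimistic expansion in
    // the non-relative case, for a load at vtable offset 8:
    //   %fptr.addr = getelementptr i8, ptr %vtable, i32 8
    //   %fptr = load ptr, ptr %fptr.addr
    //   %ok = call i1 @llvm.type.test(ptr %vtable, metadata !"typeid")
    // Uses of the original intrinsic's two results are then rewritten to use
    // %fptr and %ok respectively.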
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);

    Value *LoadedValue = nullptr;
    if (TypeCheckedLoadFunc->getIntrinsicID() ==
        Intrinsic::type_checked_load_relative) {
      Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
      Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int32Ty));
      LoadedValue = LoadB.CreateLoad(Int32Ty, GEPPtr);
      LoadedValue = LoadB.CreateSExt(LoadedValue, IntPtrTy);
      GEP = LoadB.CreatePtrToInt(GEP, IntPtrTy);
      LoadedValue = LoadB.CreateAdd(GEP, LoadedValue);
      LoadedValue = LoadB.CreateIntToPtr(LoadedValue, Int8PtrTy);
    } else {
      Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
      Value *GEPPtr =
          LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
      LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);
    }

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = PoisonValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the
    // type check, as one of those users may eventually call the pointer.
    // Increment the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CB,
                                                   &NumUnsafeUses);
    }

    CI->eraseFromParent();
  }
}

void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) {
  auto *TypeId = dyn_cast<MDString>(Slot.TypeID);
  if (!TypeId)
    return;
  const TypeIdSummary *TidSummary =
      ImportSummary->getTypeIdSummary(TypeId->getString());
  if (!TidSummary)
    return;
  auto ResI = TidSummary->WPDRes.find(Slot.ByteOffset);
  if (ResI == TidSummary->WPDRes.end())
    return;
  const WholeProgramDevirtResolution &Res = ResI->second;

  if (Res.TheKind == WholeProgramDevirtResolution::SingleImpl) {
    assert(!Res.SingleImplName.empty());
    // The type of the function in the declaration is irrelevant because every
    // call site will cast it to the correct type.
    Constant *SingleImpl =
        cast<Constant>(M.getOrInsertFunction(Res.SingleImplName,
                                             Type::getVoidTy(M.getContext()))
                           .getCallee());
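
    // Illustratively (hypothetical IR), applySingleImplDevirt turns each
    // eligible indirect call
    //   %r = call i32 %fptr(ptr %obj)
    // into a direct call to the recorded single implementation:
    //   %r = call i32 @_ZN1A1fEv(ptr %obj)
    // (the callee name here is made up; the real name comes from
    // Res.SingleImplName).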

    // This is the import phase so we should not be exporting anything.
    bool IsExported = false;
    applySingleImplDevirt(SlotInfo, SingleImpl, IsExported);
    assert(!IsExported);
  }

  for (auto &CSByConstantArg : SlotInfo.ConstCSInfo) {
    auto I = Res.ResByArg.find(CSByConstantArg.first);
    if (I == Res.ResByArg.end())
      continue;
    auto &ResByArg = I->second;
    // FIXME: We should figure out what to do about the "function name"
    // argument to the apply* functions, as the function names are unavailable
    // during the importing phase. For now we just pass the empty string. This
    // does not impact correctness because the function names are just used
    // for remarks.
    switch (ResByArg.TheKind) {
    case WholeProgramDevirtResolution::ByArg::UniformRetVal:
      applyUniformRetValOpt(CSByConstantArg.second, "", ResByArg.Info);
      break;
    case WholeProgramDevirtResolution::ByArg::UniqueRetVal: {
      Constant *UniqueMemberAddr =
          importGlobal(Slot, CSByConstantArg.first, "unique_member");
      applyUniqueRetValOpt(CSByConstantArg.second, "", ResByArg.Info,
                           UniqueMemberAddr);
      break;
    }
    case WholeProgramDevirtResolution::ByArg::VirtualConstProp: {
      Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte",
                                      Int32Ty, ResByArg.Byte);
      Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty,
                                     ResByArg.Bit);
      applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit);
      break;
    }
    default:
      break;
    }
  }

  if (Res.TheKind == WholeProgramDevirtResolution::BranchFunnel) {
    // The type of the function is irrelevant, because it's bitcast at calls
    // anyhow.
    Constant *JT = cast<Constant>(
        M.getOrInsertFunction(getGlobalName(Slot, {}, "branch_funnel"),
                              Type::getVoidTy(M.getContext()))
            .getCallee());
    bool IsExported = false;
    applyICallBranchFunnel(SlotInfo, JT, IsExported);
    assert(!IsExported);
  }
}

void DevirtModule::removeRedundantTypeTests() {
  auto True = ConstantInt::getTrue(M.getContext());
  for (auto &&U : NumUnsafeUsesForTypeTest) {
    if (U.second == 0) {
      U.first->replaceAllUsesWith(True);
      U.first->eraseFromParent();
    }
  }
}

ValueInfo
DevirtModule::lookUpFunctionValueInfo(Function *TheFn,
                                      ModuleSummaryIndex *ExportSummary) {
  assert((ExportSummary != nullptr) &&
         "Caller guarantees ExportSummary is not nullptr");

  const auto TheFnGUID = TheFn->getGUID();
  const auto TheFnGUIDWithExportedName = GlobalValue::getGUID(TheFn->getName());
  // Look up the ValueInfo with the GUID in the current linkage.
  ValueInfo TheFnVI = ExportSummary->getValueInfo(TheFnGUID);
  // If no entry is found and the GUID is different from the GUID computed
  // using the exported name, look up the ValueInfo with the exported name
  // unconditionally. This is a fallback.
  //
  // The reason to have a fallback:
  // 1. LTO could enable global value internalization via
  //    `enable-lto-internalization`.
  // 2. The GUID in ExportSummary is computed using the exported name.
  if ((!TheFnVI) && (TheFnGUID != TheFnGUIDWithExportedName)) {
    TheFnVI = ExportSummary->getValueInfo(TheFnGUIDWithExportedName);
  }
  return TheFnVI;
}

bool DevirtModule::mustBeUnreachableFunction(
    Function *const F, ModuleSummaryIndex *ExportSummary) {
  // First, learn unreachability by analyzing function IR.
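  // (Illustrative example: a definition whose body was optimized down to a
  // lone 'unreachable', e.g. because it unconditionally calls a noreturn
  // function.)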
  if (!F->isDeclaration()) {
    // A function must be unreachable if its entry block ends with an
    // 'unreachable'.
    return isa<UnreachableInst>(F->getEntryBlock().getTerminator());
  }
  // Learn unreachability from ExportSummary if ExportSummary is present.
  return ExportSummary &&
         ::mustBeUnreachableFunction(
             DevirtModule::lookUpFunctionValueInfo(F, ExportSummary));
}

bool DevirtModule::run() {
  // If only some of the modules were split, we cannot correctly perform
  // this transformation. We already checked for the presence of type tests
  // with partially split modules during the thin link, and would have emitted
  // an error if any were found, so here we can simply return.
  if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
      (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
    return false;

  Function *TypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
  Function *TypeCheckedLoadFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
  Function *TypeCheckedLoadRelativeFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load_relative));
  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));

  // Normally if there are no users of the devirtualization intrinsics in the
  // module, this pass has nothing to do. But if we are exporting, we also need
  // to handle any users that appear only in the function summaries.
  if (!ExportSummary &&
      (!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()) &&
      (!TypeCheckedLoadRelativeFunc ||
       TypeCheckedLoadRelativeFunc->use_empty()))
    return false;

  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, TypeIdMap);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  if (TypeCheckedLoadRelativeFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadRelativeFunc);

  if (ImportSummary) {
    for (auto &S : CallSlots)
      importResolution(S.first, S.second);

    removeRedundantTypeTests();

    // We have lowered or deleted the type intrinsics, so we will no longer
    // have enough information to reason about the liveness of virtual
    // function pointers in GlobalDCE.
    for (GlobalVariable &GV : M.globals())
      GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

    // The rest of the code is only necessary when exporting or during regular
    // LTO, so we are done.
    return true;
  }

  if (TypeIdMap.empty())
    return true;

  // Collect information from summary about which calls to try to devirtualize.
  if (ExportSummary) {
    DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
    for (auto &P : TypeIdMap) {
      if (auto *TypeId = dyn_cast<MDString>(P.first))
        MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
            TypeId);
    }

    for (auto &P : *ExportSummary) {
      for (auto &S : P.second.SummaryList) {
        auto *FS = dyn_cast<FunctionSummary>(S.get());
        if (!FS)
          continue;
        // FIXME: Only add live functions.
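        // Record each summarized virtual call, with and without constant
        // arguments, against its (type id, offset) slot.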
        for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VF.GUID]) {
            CallSlots[{MD, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_test_assume_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeTestAssumeUser(FS);
          }
        }
        for (const FunctionSummary::ConstVCall &VC :
             FS->type_checked_load_const_vcalls()) {
          for (Metadata *MD : MetadataByGUID[VC.VFunc.GUID]) {
            CallSlots[{MD, VC.VFunc.Offset}]
                .ConstCSInfo[VC.Args]
                .addSummaryTypeCheckedLoadUser(FS);
          }
        }
      }
    }
  }

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  std::map<std::string, GlobalValue *> DevirtTargets;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    WholeProgramDevirtResolution *Res = nullptr;
    const std::set<TypeMemberInfo> &TypeMemberInfos = TypeIdMap[S.first.TypeID];
    if (ExportSummary && isa<MDString>(S.first.TypeID) &&
        TypeMemberInfos.size())
      // For any type id used on a global's type metadata, create the type id
      // summary resolution regardless of whether we can devirtualize, so that
      // the LowerTypeTests pass knows the type id is not Unsat. If it was not
      // used on a global's type metadata, the TypeIdMap entry set will be
      // empty, and we don't want to create an entry (with the default Unknown
      // type resolution), which can prevent detection of the Unsat.
      Res = &ExportSummary
                 ->getOrInsertTypeIdSummary(
                     cast<MDString>(S.first.TypeID)->getString())
                 .WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, TypeMemberInfos,
                                  S.first.ByteOffset, ExportSummary)) {
      if (!trySingleImplDevirt(ExportSummary, TargetsForSlot, S.second, Res)) {
        DidVirtualConstProp |=
            tryVirtualConstProp(TargetsForSlot, S.second, Res, S.first);

        tryICallBranchFunnel(TargetsForSlot, S.second, Res, S.first);
      }

      // Collect functions devirtualized for at least one call site, for
      // stats.
      if (RemarksEnabled || AreStatisticsEnabled())
        for (const auto &T : TargetsForSlot)
          if (T.WasDevirt)
            DevirtTargets[std::string(T.Fn->getName())] = T.Fn;
    }

    // CFI-specific: if we are exporting and any llvm.type.checked.load
    // intrinsics were *not* devirtualized, we need to add the resulting
    // llvm.type.test intrinsics to the function summaries so that the
    // LowerTypeTests pass will export them.
    if (ExportSummary && isa<MDString>(S.first.TypeID)) {
      auto GUID =
          GlobalValue::getGUID(cast<MDString>(S.first.TypeID)->getString());
      for (auto *FS : S.second.CSInfo.SummaryTypeCheckedLoadUsers)
        FS->addTypeTest(GUID);
      for (auto &CCS : S.second.ConstCSInfo)
        for (auto *FS : CCS.second.SummaryTypeCheckedLoadUsers)
          FS->addTypeTest(GUID);
    }
  }

  if (RemarksEnabled) {
    // Generate remarks for each devirtualized function.
    for (const auto &DT : DevirtTargets) {
      GlobalValue *GV = DT.second;
      auto F = dyn_cast<Function>(GV);
      if (!F) {
        auto A = dyn_cast<GlobalAlias>(GV);
        assert(A && isa<Function>(A->getAliasee()));
        F = dyn_cast<Function>(A->getAliasee());
        assert(F);
      }

      using namespace ore;
      OREGetter(F).emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F)
                        << "devirtualized "
                        << NV("FunctionName", DT.first));
    }
  }

  NumDevirtTargets += DevirtTargets.size();

  removeRedundantTypeTests();

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  // We have lowered or deleted the type intrinsics, so we will no longer have
  // enough information to reason about the liveness of virtual function
  // pointers in GlobalDCE.
  for (GlobalVariable &GV : M.globals())
    GV.eraseMetadata(LLVMContext::MD_vcall_visibility);

  for (auto *CI : CallsWithPtrAuthBundleRemoved)
    CI->eraseFromParent();

  return true;
}

void DevirtIndex::run() {
  if (ExportSummary.typeIdCompatibleVtableMap().empty())
    return;

  DenseMap<GlobalValue::GUID, std::vector<StringRef>> NameByGUID;
  for (const auto &P : ExportSummary.typeIdCompatibleVtableMap()) {
    NameByGUID[GlobalValue::getGUID(P.first)].push_back(P.first);
    // Create the type id summary resolution regardless of whether we can
    // devirtualize, so that the LowerTypeTests pass knows the type id is used
    // on a global and not Unsat. We do this here rather than in the loop over
    // the CallSlots, since that handling will only see type tests that
    // directly feed assumes, and we would miss any that aren't currently
    // handled by WPD (such as type tests that feed assumes via phis).
    ExportSummary.getOrInsertTypeIdSummary(P.first);
  }

  // Collect information from summary about which calls to try to devirtualize.
  for (auto &P : ExportSummary) {
    for (auto &S : P.second.SummaryList) {
      auto *FS = dyn_cast<FunctionSummary>(S.get());
      if (!FS)
        continue;
      // FIXME: Only add live functions.
      for (FunctionSummary::VFuncId VF : FS->type_test_assume_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (FunctionSummary::VFuncId VF : FS->type_checked_load_vcalls()) {
        for (StringRef Name : NameByGUID[VF.GUID]) {
          CallSlots[{Name, VF.Offset}].CSInfo.addSummaryTypeCheckedLoadUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_test_assume_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeTestAssumeUser(FS);
        }
      }
      for (const FunctionSummary::ConstVCall &VC :
           FS->type_checked_load_const_vcalls()) {
        for (StringRef Name : NameByGUID[VC.VFunc.GUID]) {
          CallSlots[{Name, VC.VFunc.Offset}]
              .ConstCSInfo[VC.Args]
              .addSummaryTypeCheckedLoadUser(FS);
        }
      }
    }
  }

  std::set<ValueInfo> DevirtTargets;
  // For each (type, offset) pair:
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<ValueInfo> TargetsForSlot;
    auto TidSummary =
        ExportSummary.getTypeIdCompatibleVtableSummary(S.first.TypeID);
    assert(TidSummary);
    // The type id summary would have been created while building the
    // NameByGUID map earlier.
    WholeProgramDevirtResolution *Res =
        &ExportSummary.getTypeIdSummary(S.first.TypeID)
             ->WPDRes[S.first.ByteOffset];
    if (tryFindVirtualCallTargets(TargetsForSlot, *TidSummary,
                                  S.first.ByteOffset)) {
      if (!trySingleImplDevirt(TargetsForSlot, S.first, S.second, Res,
                               DevirtTargets))
        continue;
    }
  }

  // Optionally have the thin link print a message for each devirtualized
  // function.
  if (PrintSummaryDevirt)
    for (const auto &DT : DevirtTargets)
      errs() << "Devirtualized call to " << DT << "\n";

  NumDevirtTargets += DevirtTargets.size();
}