//===- LowerTypeTests.cpp - type metadata lowering pass -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass lowers type metadata and calls to the llvm.type.test intrinsic.
// It also ensures that globals are properly laid out for the
// llvm.icall.branch.funnel intrinsic.
// See http://llvm.org/docs/TypeMetadata.html for more information.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/LowerTypeTests.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/EquivalenceClasses.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TinyPtrVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalAlias.h"
#include "llvm/IR/GlobalObject.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InlineAsm.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/ModuleSummaryIndex.h"
#include "llvm/IR/ModuleSummaryIndexYAML.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/TrailingObjects.h"
#include "llvm/Support/YAMLTraits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <set>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

using namespace llvm;
using namespace lowertypetests;

#define DEBUG_TYPE "lowertypetests"

STATISTIC(ByteArraySizeBits, "Byte array size in bits");
STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
STATISTIC(NumTypeTestCallsLowered, "Number of type test calls lowered");
STATISTIC(NumTypeIdDisjointSets, "Number of disjoint sets of type identifiers");

static cl::opt<bool> AvoidReuse(
    "lowertypetests-avoid-reuse",
    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
    cl::Hidden, cl::init(true));

static cl::opt<PassSummaryAction> ClSummaryAction(
    "lowertypetests-summary-action",
    cl::desc("What to do with the summary when running this pass"),
    cl::values(clEnumValN(PassSummaryAction::None, "none", "Do nothing"),
               clEnumValN(PassSummaryAction::Import, "import",
                          "Import typeid resolutions from summary and globals"),
               clEnumValN(PassSummaryAction::Export, "export",
                          "Export typeid resolutions to summary and globals")),
    cl::Hidden);

static cl::opt<std::string> ClReadSummary(
    "lowertypetests-read-summary",
    cl::desc("Read summary from given YAML file before running pass"),
    cl::Hidden);

static cl::opt<std::string> ClWriteSummary(
    "lowertypetests-write-summary",
    cl::desc("Write summary to given YAML file after running pass"),
    cl::Hidden);

static cl::opt<bool>
    ClDropTypeTests("lowertypetests-drop-type-tests",
                    cl::desc("Simply drop type test assume sequences"),
                    cl::Hidden, cl::init(false));

bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
  if (Offset < ByteOffset)
    return false;

  if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
    return false;

  uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
  if (BitOffset >= BitSize)
    return false;

  return Bits.count(BitOffset);
}

void BitSetInfo::print(raw_ostream &OS) const {
  OS << "offset " << ByteOffset << " size " << BitSize << " align "
     << (1 << AlignLog2);

  if (isAllOnes()) {
    OS << " all-ones\n";
    return;
  }

  OS << " { ";
  for (uint64_t B : Bits)
    OS << B << ' ';
  OS << "}\n";
}

BitSetInfo BitSetBuilder::build() {
  if (Min > Max)
    Min = 0;

  // Normalize each offset against the minimum observed offset, and compute
  // the bitwise OR of each of the offsets. The number of trailing zeros
  // in the mask gives us the log2 of the alignment of all offsets, which
  // allows us to compress the bitset by only storing one bit per aligned
  // address.
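  //
  // Illustrative example (not from the source): offsets {8, 24, 40} with
  // Min = 8 normalize to {0, 16, 32}, so Mask = 0x30 and AlignLog2 = 4;
  // the compressed bitset then stores bits {0, 1, 2} with BitSize 3.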
  uint64_t Mask = 0;
  for (uint64_t &Offset : Offsets) {
    Offset -= Min;
    Mask |= Offset;
  }

  BitSetInfo BSI;
  BSI.ByteOffset = Min;

  BSI.AlignLog2 = 0;
  if (Mask != 0)
    BSI.AlignLog2 = llvm::countr_zero(Mask);

  // Build the compressed bitset while normalizing the offsets against the
  // computed alignment.
  BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
  for (uint64_t Offset : Offsets) {
    Offset >>= BSI.AlignLog2;
    BSI.Bits.insert(Offset);
  }

  return BSI;
}

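// A hypothetical walk-through (not from the source): if fragment 1 already
// holds object indices {3, 4} and addFragment is called with F = {4, 5},
// visiting 4 pulls the old fragment's contents into the new fragment and
// clears the old one; visiting 5 appends it, yielding {3, 4, 5}, and the
// final loop remaps all three indices to the new fragment.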
void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
  // Create a new fragment to hold the layout for F.
  Fragments.emplace_back();
  std::vector<uint64_t> &Fragment = Fragments.back();
  uint64_t FragmentIndex = Fragments.size() - 1;

  for (auto ObjIndex : F) {
    uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
    if (OldFragmentIndex == 0) {
      // We haven't seen this object index before, so just add it to the current
      // fragment.
      Fragment.push_back(ObjIndex);
    } else {
      // This index belongs to an existing fragment. Copy the elements of the
      // old fragment into this one and clear the old fragment. We don't update
      // the fragment map just yet, this ensures that any further references to
      // indices from the old fragment in this fragment do not insert any more
      // indices.
      std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
      llvm::append_range(Fragment, OldFragment);
      OldFragment.clear();
    }
  }

  // Update the fragment map to point our object indices to this fragment.
  for (uint64_t ObjIndex : Fragment)
    FragmentMap[ObjIndex] = FragmentIndex;
}

void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
                                uint64_t BitSize, uint64_t &AllocByteOffset,
                                uint8_t &AllocMask) {
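  // Sketch of the scheme (illustrative, not from the source): each byte of
  // the array provides eight independent bit lanes, so up to eight bitsets
  // can overlap in the same bytes. If lane 2 currently has the fewest bytes
  // allocated, the new bitset is placed there and AllocMask becomes
  // 1 << 2 == 4.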
  // Find the smallest current allocation.
  unsigned Bit = 0;
  for (unsigned I = 1; I != BitsPerByte; ++I)
    if (BitAllocs[I] < BitAllocs[Bit])
      Bit = I;

  AllocByteOffset = BitAllocs[Bit];

  // Add our size to it.
  unsigned ReqSize = AllocByteOffset + BitSize;
  BitAllocs[Bit] = ReqSize;
  if (Bytes.size() < ReqSize)
    Bytes.resize(ReqSize);

  // Set our bits.
  AllocMask = 1 << Bit;
  for (uint64_t B : Bits)
    Bytes[AllocByteOffset + B] |= AllocMask;
}

bool lowertypetests::isJumpTableCanonical(Function *F) {
  if (F->isDeclarationForLinker())
    return false;
  auto *CI = mdconst::extract_or_null<ConstantInt>(
      F->getParent()->getModuleFlag("CFI Canonical Jump Tables"));
  if (!CI || !CI->isZero())
    return true;
  return F->hasFnAttribute("cfi-canonical-jump-table");
}

namespace {

struct ByteArrayInfo {
  std::set<uint64_t> Bits;
  uint64_t BitSize;
  GlobalVariable *ByteArray;
  GlobalVariable *MaskGlobal;
  uint8_t *MaskPtr = nullptr;
};

/// A POD-like structure that we use to store a global reference together with
/// its metadata types. In this pass we frequently need to query the set of
/// metadata types referenced by a global, which at the IR level is an expensive
/// operation involving a map lookup; this data structure helps to reduce the
/// number of times we need to do this lookup.
class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> {
  friend TrailingObjects;

  GlobalObject *GO;
  size_t NTypes;

  // For functions: true if the jump table is canonical. This essentially means
  // whether the canonical address (i.e. the symbol table entry) of the function
  // is provided by the local jump table. This is normally the same as whether
  // the function is defined locally, but if canonical jump tables are disabled
  // by the user then the jump table never provides a canonical definition.
  bool IsJumpTableCanonical;

  // For functions: true if this function is either defined or used in a thinlto
  // module and its jumptable entry needs to be exported to thinlto backends.
  bool IsExported;

  size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; }

public:
  static GlobalTypeMember *create(BumpPtrAllocator &Alloc, GlobalObject *GO,
                                  bool IsJumpTableCanonical, bool IsExported,
                                  ArrayRef<MDNode *> Types) {
    auto *GTM = static_cast<GlobalTypeMember *>(Alloc.Allocate(
        totalSizeToAlloc<MDNode *>(Types.size()), alignof(GlobalTypeMember)));
    GTM->GO = GO;
    GTM->NTypes = Types.size();
    GTM->IsJumpTableCanonical = IsJumpTableCanonical;
    GTM->IsExported = IsExported;
    std::uninitialized_copy(Types.begin(), Types.end(),
                            GTM->getTrailingObjects<MDNode *>());
    return GTM;
  }

  GlobalObject *getGlobal() const {
    return GO;
  }

  bool isJumpTableCanonical() const {
    return IsJumpTableCanonical;
  }

  bool isExported() const {
    return IsExported;
  }

  ArrayRef<MDNode *> types() const {
    return ArrayRef(getTrailingObjects<MDNode *>(), NTypes);
  }
};

struct ICallBranchFunnel final
    : TrailingObjects<ICallBranchFunnel, GlobalTypeMember *> {
  static ICallBranchFunnel *create(BumpPtrAllocator &Alloc, CallInst *CI,
                                   ArrayRef<GlobalTypeMember *> Targets,
                                   unsigned UniqueId) {
    auto *Call = static_cast<ICallBranchFunnel *>(
        Alloc.Allocate(totalSizeToAlloc<GlobalTypeMember *>(Targets.size()),
                       alignof(ICallBranchFunnel)));
    Call->CI = CI;
    Call->UniqueId = UniqueId;
    Call->NTargets = Targets.size();
    std::uninitialized_copy(Targets.begin(), Targets.end(),
                            Call->getTrailingObjects<GlobalTypeMember *>());
    return Call;
  }

  CallInst *CI;
  ArrayRef<GlobalTypeMember *> targets() const {
    return ArrayRef(getTrailingObjects<GlobalTypeMember *>(), NTargets);
  }

  unsigned UniqueId;

private:
  size_t NTargets;
};

struct ScopedSaveAliaseesAndUsed {
  Module &M;
  SmallVector<GlobalValue *, 4> Used, CompilerUsed;
  std::vector<std::pair<GlobalAlias *, Function *>> FunctionAliases;
  std::vector<std::pair<GlobalIFunc *, Function *>> ResolverIFuncs;

  ScopedSaveAliaseesAndUsed(Module &M) : M(M) {
    // The users of this class want to replace all function references except
    // for aliases and llvm.used/llvm.compiler.used with references to a jump
    // table. We avoid replacing aliases in order to avoid introducing a double
    // indirection (or an alias pointing to a declaration in ThinLTO mode), and
    // we avoid replacing llvm.used/llvm.compiler.used because these global
    // variables describe properties of the global, not the jump table (besides,
    // offset references to the jump table in llvm.used are invalid).
    // Unfortunately, LLVM doesn't have a "RAUW except for these (possibly
    // indirect) users", so what we do is save the list of globals referenced by
    // llvm.used/llvm.compiler.used and aliases, erase the used lists, let RAUW
    // replace the aliasees and then set them back to their original values at
    // the end.
    if (GlobalVariable *GV = collectUsedGlobalVariables(M, Used, false))
      GV->eraseFromParent();
    if (GlobalVariable *GV = collectUsedGlobalVariables(M, CompilerUsed, true))
      GV->eraseFromParent();

    for (auto &GA : M.aliases()) {
      // FIXME: This should look past all aliases not just interposable ones,
      // see discussion on D65118.
      if (auto *F = dyn_cast<Function>(GA.getAliasee()->stripPointerCasts()))
        FunctionAliases.push_back({&GA, F});
    }

    for (auto &GI : M.ifuncs())
      if (auto *F = dyn_cast<Function>(GI.getResolver()->stripPointerCasts()))
        ResolverIFuncs.push_back({&GI, F});
  }

  ~ScopedSaveAliaseesAndUsed() {
    appendToUsed(M, Used);
    appendToCompilerUsed(M, CompilerUsed);

    for (auto P : FunctionAliases)
      P.first->setAliasee(P.second);

    for (auto P : ResolverIFuncs) {
      // This does not preserve pointer casts that may have been stripped by the
      // constructor, but the resolver's type is different from that of the
      // ifunc anyway.
      P.first->setResolver(P.second);
    }
  }
};

class LowerTypeTestsModule {
  Module &M;

  ModuleSummaryIndex *ExportSummary;
  const ModuleSummaryIndex *ImportSummary;
  // Set when the client has invoked this to simply drop all type test assume
  // sequences.
  bool DropTypeTests;

  Triple::ArchType Arch;
  Triple::OSType OS;
  Triple::ObjectFormatType ObjectFormat;

  // Determines which kind of Thumb jump table we generate. If arch is
  // either 'arm' or 'thumb' we need to find this out, because
  // selectJumpTableArmEncoding may decide to use Thumb in either case.
  bool CanUseArmJumpTable = false, CanUseThumbBWJumpTable = false;

  // Cache variable used by hasBranchTargetEnforcement().
  int HasBranchTargetEnforcement = -1;

  // The jump table type we ended up deciding on. (Usually the same as
  // Arch, except that 'arm' and 'thumb' are often interchangeable.)
  Triple::ArchType JumpTableArch = Triple::UnknownArch;

  IntegerType *Int1Ty = Type::getInt1Ty(M.getContext());
  IntegerType *Int8Ty = Type::getInt8Ty(M.getContext());
  PointerType *Int8PtrTy = PointerType::getUnqual(M.getContext());
  ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0);
  IntegerType *Int32Ty = Type::getInt32Ty(M.getContext());
  PointerType *Int32PtrTy = PointerType::getUnqual(M.getContext());
  IntegerType *Int64Ty = Type::getInt64Ty(M.getContext());
  IntegerType *IntPtrTy = M.getDataLayout().getIntPtrType(M.getContext(), 0);

  // Indirect function call index assignment counter for WebAssembly
  uint64_t IndirectIndex = 1;

  // Mapping from type identifiers to the call sites that test them, as well as
  // whether the type identifier needs to be exported to ThinLTO backends as
  // part of the regular LTO phase of the ThinLTO pipeline (see exportTypeId).
  struct TypeIdUserInfo {
    std::vector<CallInst *> CallSites;
    bool IsExported = false;
  };
  DenseMap<Metadata *, TypeIdUserInfo> TypeIdUsers;

  /// This structure describes how to lower type tests for a particular type
  /// identifier. It is either built directly from the global analysis (during
  /// regular LTO or the regular LTO phase of ThinLTO), or indirectly using type
  /// identifier summaries and external symbol references (in ThinLTO backends).
  struct TypeIdLowering {
    TypeTestResolution::Kind TheKind = TypeTestResolution::Unsat;

    /// All except Unsat: the start address within the combined global.
    Constant *OffsetedGlobal;

    /// ByteArray, Inline, AllOnes: log2 of the required global alignment
    /// relative to the start address.
    Constant *AlignLog2;

    /// ByteArray, Inline, AllOnes: one less than the size of the memory region
    /// covering members of this type identifier as a multiple of 2^AlignLog2.
    Constant *SizeM1;

    /// ByteArray: the byte array to test the address against.
    Constant *TheByteArray;

    /// ByteArray: the bit mask to apply to bytes loaded from the byte array.
    Constant *BitMask;

    /// Inline: the bit mask to test the address against.
    Constant *InlineBits;
  };

  std::vector<ByteArrayInfo> ByteArrayInfos;

  Function *WeakInitializerFn = nullptr;

  GlobalVariable *GlobalAnnotation;
  DenseSet<Value *> FunctionAnnotations;

  bool shouldExportConstantsAsAbsoluteSymbols();
  uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL);
  TypeIdLowering importTypeId(StringRef TypeId);
  void importTypeTest(CallInst *CI);
  void importFunction(Function *F, bool isJumpTableCanonical,
                      std::vector<GlobalAlias *> &AliasesToErase);

  BitSetInfo
  buildBitSet(Metadata *TypeId,
              const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
  ByteArrayInfo *createByteArray(BitSetInfo &BSI);
  void allocateByteArrays();
  Value *createBitSetTest(IRBuilder<> &B, const TypeIdLowering &TIL,
                          Value *BitOffset);
  void lowerTypeTestCalls(
      ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
      const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout);
  Value *lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
                           const TypeIdLowering &TIL);

  void buildBitSetsFromGlobalVariables(ArrayRef<Metadata *> TypeIds,
                                       ArrayRef<GlobalTypeMember *> Globals);
  Triple::ArchType
  selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions);
  bool hasBranchTargetEnforcement();
  unsigned getJumpTableEntrySize();
  Type *getJumpTableEntryType();
  void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS,
                            Triple::ArchType JumpTableArch,
                            SmallVectorImpl<Value *> &AsmArgs, Function *Dest);
  void verifyTypeMDNode(GlobalObject *GO, MDNode *Type);
  void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds,
                                 ArrayRef<GlobalTypeMember *> Functions);
  void buildBitSetsFromFunctionsNative(ArrayRef<Metadata *> TypeIds,
                                       ArrayRef<GlobalTypeMember *> Functions);
  void buildBitSetsFromFunctionsWASM(ArrayRef<Metadata *> TypeIds,
                                     ArrayRef<GlobalTypeMember *> Functions);
  void
  buildBitSetsFromDisjointSet(ArrayRef<Metadata *> TypeIds,
                              ArrayRef<GlobalTypeMember *> Globals,
                              ArrayRef<ICallBranchFunnel *> ICallBranchFunnels);

  void replaceWeakDeclarationWithJumpTablePtr(Function *F, Constant *JT,
                                              bool IsJumpTableCanonical);
  void moveInitializerToModuleConstructor(GlobalVariable *GV);
  void findGlobalVariableUsersOf(Constant *C,
                                 SmallSetVector<GlobalVariable *, 8> &Out);

  void createJumpTable(Function *F, ArrayRef<GlobalTypeMember *> Functions);

  /// replaceCfiUses - Go through the uses list for this definition
  /// and make each use point to "V" instead of "this" when the use is outside
  /// the block. 'This's use list is expected to have at least one element.
  /// Unlike replaceAllUsesWith this function skips blockaddr and direct call
  /// uses.
  void replaceCfiUses(Function *Old, Value *New, bool IsJumpTableCanonical);

  /// replaceDirectCalls - Go through the uses list for this definition and
  /// replace each use, which is a direct function call.
  void replaceDirectCalls(Value *Old, Value *New);

  bool isFunctionAnnotation(Value *V) const {
    return FunctionAnnotations.contains(V);
  }

public:
  LowerTypeTestsModule(Module &M, ModuleAnalysisManager &AM,
                       ModuleSummaryIndex *ExportSummary,
                       const ModuleSummaryIndex *ImportSummary,
                       bool DropTypeTests);

  bool lower();

  // Lower the module using the action and summary passed as command line
  // arguments. For testing purposes only.
  static bool runForTesting(Module &M, ModuleAnalysisManager &AM);
};
} // end anonymous namespace

/// Build a bit set for TypeId using the object layouts in
/// GlobalLayout.
BitSetInfo LowerTypeTestsModule::buildBitSet(
    Metadata *TypeId,
    const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
  BitSetBuilder BSB;

  // Compute the byte offset of each address associated with this type
  // identifier.
  for (const auto &GlobalAndOffset : GlobalLayout) {
    for (MDNode *Type : GlobalAndOffset.first->types()) {
      if (Type->getOperand(1) != TypeId)
        continue;
      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();
      BSB.addOffset(GlobalAndOffset.second + Offset);
    }
  }

  return BSB.build();
}

/// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
/// Bits. This pattern matches to the bt instruction on x86.
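/// For instance (illustrative): with a 64-bit Bits constant, a BitOffset of
/// 70 is masked to bit index 6, so the emitted IR checks
/// (Bits & (1 << 6)) != 0.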
static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
                                  Value *BitOffset) {
  auto BitsType = cast<IntegerType>(Bits->getType());
  unsigned BitWidth = BitsType->getBitWidth();

  BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
  Value *BitIndex =
      B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
  Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
  Value *MaskedBits = B.CreateAnd(Bits, BitMask);
  return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
}

ByteArrayInfo *LowerTypeTestsModule::createByteArray(BitSetInfo &BSI) {
  // Create globals to stand in for byte arrays and masks. These never actually
  // get initialized, we RAUW and erase them later in allocateByteArrays() once
  // we know the offset and mask to use.
  auto ByteArrayGlobal = new GlobalVariable(
      M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
  auto MaskGlobal = new GlobalVariable(M, Int8Ty, /*isConstant=*/true,
                                       GlobalValue::PrivateLinkage, nullptr);

  ByteArrayInfos.emplace_back();
  ByteArrayInfo *BAI = &ByteArrayInfos.back();

  BAI->Bits = BSI.Bits;
  BAI->BitSize = BSI.BitSize;
  BAI->ByteArray = ByteArrayGlobal;
  BAI->MaskGlobal = MaskGlobal;
  return BAI;
}

void LowerTypeTestsModule::allocateByteArrays() {
  llvm::stable_sort(ByteArrayInfos,
                    [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
                      return BAI1.BitSize > BAI2.BitSize;
                    });

  std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());

  ByteArrayBuilder BAB;
  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    uint8_t Mask;
    BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);

    BAI->MaskGlobal->replaceAllUsesWith(
        ConstantExpr::getIntToPtr(ConstantInt::get(Int8Ty, Mask), Int8PtrTy));
    BAI->MaskGlobal->eraseFromParent();
    if (BAI->MaskPtr)
      *BAI->MaskPtr = Mask;
  }

  Constant *ByteArrayConst = ConstantDataArray::get(M.getContext(), BAB.Bytes);
  auto ByteArray =
      new GlobalVariable(M, ByteArrayConst->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, ByteArrayConst);

  for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
    ByteArrayInfo *BAI = &ByteArrayInfos[I];

    Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
                        ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
        ByteArrayConst->getType(), ByteArray, Idxs);

    // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
    // that the pc-relative displacement is folded into the lea instead of the
    // test instruction getting another displacement.
    GlobalAlias *Alias = GlobalAlias::create(
        Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, &M);
    BAI->ByteArray->replaceAllUsesWith(Alias);
    BAI->ByteArray->eraseFromParent();
  }

  ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
                      BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
                      BAB.BitAllocs[6] + BAB.BitAllocs[7];
  ByteArraySizeBytes = BAB.Bytes.size();
}

/// Build a test that bit BitOffset is set in the type identifier that was
/// lowered to TIL, which must be either an Inline or a ByteArray.
Value *LowerTypeTestsModule::createBitSetTest(IRBuilder<> &B,
                                              const TypeIdLowering &TIL,
                                              Value *BitOffset) {
  if (TIL.TheKind == TypeTestResolution::Inline) {
    // If the bit set is sufficiently small, we can avoid a load by bit testing
    // a constant.
    return createMaskedBitTest(B, TIL.InlineBits, BitOffset);
  } else {
    Constant *ByteArray = TIL.TheByteArray;
    if (AvoidReuse && !ImportSummary) {
      // Each use of the byte array uses a different alias. This makes the
      // backend less likely to reuse previously computed byte array addresses,
      // improving the security of the CFI mechanism based on this pass.
      // This won't work when importing because TheByteArray is external.
      ByteArray = GlobalAlias::create(Int8Ty, 0, GlobalValue::PrivateLinkage,
                                      "bits_use", ByteArray, &M);
    }

    Value *ByteAddr = B.CreateGEP(Int8Ty, ByteArray, BitOffset);
    Value *Byte = B.CreateLoad(Int8Ty, ByteAddr);

    Value *ByteAndMask =
        B.CreateAnd(Byte, ConstantExpr::getPtrToInt(TIL.BitMask, Int8Ty));
    return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
  }
}

static bool isKnownTypeIdMember(Metadata *TypeId, const DataLayout &DL,
                                Value *V, uint64_t COffset) {
  if (auto GV = dyn_cast<GlobalObject>(V)) {
    SmallVector<MDNode *, 2> Types;
    GV->getMetadata(LLVMContext::MD_type, Types);
    for (MDNode *Type : Types) {
      if (Type->getOperand(1) != TypeId)
        continue;
      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();
      if (COffset == Offset)
        return true;
    }
    return false;
  }

  if (auto GEP = dyn_cast<GEPOperator>(V)) {
    APInt APOffset(DL.getIndexSizeInBits(0), 0);
    bool Result = GEP->accumulateConstantOffset(DL, APOffset);
    if (!Result)
      return false;
    COffset += APOffset.getZExtValue();
    return isKnownTypeIdMember(TypeId, DL, GEP->getPointerOperand(), COffset);
  }

  if (auto Op = dyn_cast<Operator>(V)) {
    if (Op->getOpcode() == Instruction::BitCast)
      return isKnownTypeIdMember(TypeId, DL, Op->getOperand(0), COffset);

    if (Op->getOpcode() == Instruction::Select)
      return isKnownTypeIdMember(TypeId, DL, Op->getOperand(1), COffset) &&
             isKnownTypeIdMember(TypeId, DL, Op->getOperand(2), COffset);
  }

  return false;
}

/// Lower a llvm.type.test call to its implementation. Returns the value to
/// replace the call with.
Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI,
                                               const TypeIdLowering &TIL) {
  // Delay lowering if the resolution is currently unknown.
  if (TIL.TheKind == TypeTestResolution::Unknown)
    return nullptr;
  if (TIL.TheKind == TypeTestResolution::Unsat)
    return ConstantInt::getFalse(M.getContext());

  Value *Ptr = CI->getArgOperand(0);
  const DataLayout &DL = M.getDataLayout();
  if (isKnownTypeIdMember(TypeId, DL, Ptr, 0))
    return ConstantInt::getTrue(M.getContext());

  BasicBlock *InitialBB = CI->getParent();

  IRBuilder<> B(CI);

  Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);

  Constant *OffsetedGlobalAsInt =
      ConstantExpr::getPtrToInt(TIL.OffsetedGlobal, IntPtrTy);
  if (TIL.TheKind == TypeTestResolution::Single)
    return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);

  Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);

  // We need to check that the offset both falls within our range and is
  // suitably aligned. We can check both properties at the same time by
  // performing a right rotate by log2(alignment) followed by an integer
  // comparison against the bitset size. The rotate will move the lower
  // order bits that need to be zero into the higher order bits of the
  // result, causing the comparison to fail if they are nonzero. The rotate
  // also conveniently gives us a bit offset to use during the load from
  // the bitset.
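  //
  // Worked example (illustrative): with AlignLog2 = 3 and SizeM1 = 7 on a
  // 64-bit target, an aligned in-range PtrOffset of 0x18 rotates to 3, which
  // passes the unsigned comparison against 7; a misaligned offset such as
  // 0x19 rotates its low set bit into the high bits, so the comparison fails.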
  Value *OffsetSHR =
      B.CreateLShr(PtrOffset, B.CreateZExt(TIL.AlignLog2, IntPtrTy));
  Value *OffsetSHL = B.CreateShl(
      PtrOffset, B.CreateZExt(
                     ConstantExpr::getSub(
                         ConstantInt::get(Int8Ty, DL.getPointerSizeInBits(0)),
                         TIL.AlignLog2),
                     IntPtrTy));
  Value *BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);

  Value *OffsetInRange = B.CreateICmpULE(BitOffset, TIL.SizeM1);

  // If the bit set is all ones, testing against it is unnecessary.
  if (TIL.TheKind == TypeTestResolution::AllOnes)
    return OffsetInRange;

  // See if the intrinsic is used in the following common pattern:
  //   br(llvm.type.test(...), thenbb, elsebb)
  // where nothing happens between the type test and the br.
  // If so, create slightly simpler IR.
  if (CI->hasOneUse())
    if (auto *Br = dyn_cast<BranchInst>(*CI->user_begin()))
      if (CI->getNextNode() == Br) {
        BasicBlock *Then = InitialBB->splitBasicBlock(CI->getIterator());
        BasicBlock *Else = Br->getSuccessor(1);
        BranchInst *NewBr = BranchInst::Create(Then, Else, OffsetInRange);
        NewBr->setMetadata(LLVMContext::MD_prof,
                           Br->getMetadata(LLVMContext::MD_prof));
        ReplaceInstWithInst(InitialBB->getTerminator(), NewBr);

        // Update phis in Else resulting from InitialBB being split
        for (auto &Phi : Else->phis())
          Phi.addIncoming(Phi.getIncomingValueForBlock(Then), InitialBB);

        IRBuilder<> ThenB(CI);
        return createBitSetTest(ThenB, TIL, BitOffset);
      }

  IRBuilder<> ThenB(SplitBlockAndInsertIfThen(OffsetInRange, CI, false));

  // Now that we know that the offset is in range and aligned, load the
  // appropriate bit from the bitset.
  Value *Bit = createBitSetTest(ThenB, TIL, BitOffset);

  // The value we want is 0 if we came directly from the initial block
  // (having failed the range or alignment checks), or the loaded bit if
  // we came from the block in which we loaded it.
  B.SetInsertPoint(CI);
  PHINode *P = B.CreatePHI(Int1Ty, 2);
  P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
  P->addIncoming(Bit, ThenB.GetInsertBlock());
  return P;
}

/// Given a disjoint set of type identifiers and globals, lay out the globals,
/// build the bit sets and lower the llvm.type.test calls.
void LowerTypeTestsModule::buildBitSetsFromGlobalVariables(
    ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) {
  // Build a new global with the combined contents of the referenced globals.
  // This global is a struct whose even-indexed elements contain the original
  // contents of the referenced globals and whose odd-indexed elements contain
  // any padding required to align the next element to the next power of 2 plus
  // any additional padding required to meet its alignment requirements.
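  //
  // For example (hypothetical sizes): a 10-byte global would be followed by
  // NextPowerOf2(9) - 10 = 6 bytes of zero padding before the next member,
  // subject to the 32-byte cap applied below.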
  std::vector<Constant *> GlobalInits;
  const DataLayout &DL = M.getDataLayout();
  DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
  Align MaxAlign;
  uint64_t CurOffset = 0;
  uint64_t DesiredPadding = 0;
  for (GlobalTypeMember *G : Globals) {
    auto *GV = cast<GlobalVariable>(G->getGlobal());
    Align Alignment =
        DL.getValueOrABITypeAlignment(GV->getAlign(), GV->getValueType());
    MaxAlign = std::max(MaxAlign, Alignment);
    uint64_t GVOffset = alignTo(CurOffset + DesiredPadding, Alignment);
    GlobalLayout[G] = GVOffset;
    if (GVOffset != 0) {
      uint64_t Padding = GVOffset - CurOffset;
      GlobalInits.push_back(
          ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
    }

    GlobalInits.push_back(GV->getInitializer());
    uint64_t InitSize = DL.getTypeAllocSize(GV->getValueType());
    CurOffset = GVOffset + InitSize;

    // Compute the amount of padding that we'd like for the next element.
    DesiredPadding = NextPowerOf2(InitSize - 1) - InitSize;

    // Experiments of different caps with Chromium on both x64 and ARM64
    // have shown that the 32-byte cap generates the smallest binary on
    // both platforms while different caps yield similar performance.
    // (see https://lists.llvm.org/pipermail/llvm-dev/2018-July/124694.html)
    if (DesiredPadding > 32)
      DesiredPadding = alignTo(InitSize, 32) - InitSize;
  }

  Constant *NewInit = ConstantStruct::getAnon(M.getContext(), GlobalInits);
  auto *CombinedGlobal =
      new GlobalVariable(M, NewInit->getType(), /*isConstant=*/true,
                         GlobalValue::PrivateLinkage, NewInit);
  CombinedGlobal->setAlignment(MaxAlign);

  StructType *NewTy = cast<StructType>(NewInit->getType());
  lowerTypeTestCalls(TypeIds, CombinedGlobal, GlobalLayout);

  // Build aliases pointing to offsets into the combined global for each
  // global from which we built the combined global, and replace references
  // to the original globals with references to the aliases.
  for (unsigned I = 0; I != Globals.size(); ++I) {
    GlobalVariable *GV = cast<GlobalVariable>(Globals[I]->getGlobal());

    // Multiply by 2 to account for padding elements.
    Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
                                      ConstantInt::get(Int32Ty, I * 2)};
    Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr(
        NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
    assert(GV->getType()->getAddressSpace() == 0);
    GlobalAlias *GAlias =
        GlobalAlias::create(NewTy->getElementType(I * 2), 0, GV->getLinkage(),
                            "", CombinedGlobalElemPtr, &M);
    GAlias->setVisibility(GV->getVisibility());
    GAlias->takeName(GV);
    GV->replaceAllUsesWith(GAlias);
    GV->eraseFromParent();
  }
}

bool LowerTypeTestsModule::shouldExportConstantsAsAbsoluteSymbols() {
  return (Arch == Triple::x86 || Arch == Triple::x86_64) &&
         ObjectFormat == Triple::ELF;
}

/// Export the given type identifier so that ThinLTO backends may import it.
/// Type identifiers are exported by adding coarse-grained information about how
/// to test the type identifier to the summary, and creating symbols in the
/// object file (aliases and absolute symbols) containing fine-grained
/// information about the type identifier.
///
/// Returns a pointer to the location in which to store the bitmask, if
/// applicable.
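///
/// For example, a type identifier "typeid1" (hypothetical name) with a
/// ByteArray resolution exports hidden symbols such as
/// __typeid_typeid1_global_addr and __typeid_typeid1_byte_array, following
/// the "__typeid_" + TypeId + "_" + Name scheme used below.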
uint8_t *LowerTypeTestsModule::exportTypeId(StringRef TypeId,
                                            const TypeIdLowering &TIL) {
  TypeTestResolution &TTRes =
      ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes;
  TTRes.TheKind = TIL.TheKind;

  auto ExportGlobal = [&](StringRef Name, Constant *C) {
    GlobalAlias *GA =
        GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage,
                            "__typeid_" + TypeId + "_" + Name, C, &M);
    GA->setVisibility(GlobalValue::HiddenVisibility);
  };

  auto ExportConstant = [&](StringRef Name, uint64_t &Storage, Constant *C) {
    if (shouldExportConstantsAsAbsoluteSymbols())
      ExportGlobal(Name, ConstantExpr::getIntToPtr(C, Int8PtrTy));
    else
      Storage = cast<ConstantInt>(C)->getZExtValue();
  };

  if (TIL.TheKind != TypeTestResolution::Unsat)
    ExportGlobal("global_addr", TIL.OffsetedGlobal);

  if (TIL.TheKind == TypeTestResolution::ByteArray ||
      TIL.TheKind == TypeTestResolution::Inline ||
      TIL.TheKind == TypeTestResolution::AllOnes) {
    ExportConstant("align", TTRes.AlignLog2, TIL.AlignLog2);
    ExportConstant("size_m1", TTRes.SizeM1, TIL.SizeM1);

    uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1;
    if (TIL.TheKind == TypeTestResolution::Inline)
      TTRes.SizeM1BitWidth = (BitSize <= 32) ? 5 : 6;
    else
      TTRes.SizeM1BitWidth = (BitSize <= 128) ? 7 : 32;
  }

  if (TIL.TheKind == TypeTestResolution::ByteArray) {
    ExportGlobal("byte_array", TIL.TheByteArray);
    if (shouldExportConstantsAsAbsoluteSymbols())
      ExportGlobal("bit_mask", TIL.BitMask);
    else
      return &TTRes.BitMask;
  }

  if (TIL.TheKind == TypeTestResolution::Inline)
    ExportConstant("inline_bits", TTRes.InlineBits, TIL.InlineBits);

  return nullptr;
}

LowerTypeTestsModule::TypeIdLowering
LowerTypeTestsModule::importTypeId(StringRef TypeId) {
  const TypeIdSummary *TidSummary = ImportSummary->getTypeIdSummary(TypeId);
  if (!TidSummary)
    return {}; // Unsat: no globals match this type id.
  const TypeTestResolution &TTRes = TidSummary->TTRes;

  TypeIdLowering TIL;
  TIL.TheKind = TTRes.TheKind;

  auto ImportGlobal = [&](StringRef Name) {
    // Give the global a zero-length array type so that alias analysis does not
    // assume it cannot alias any other global.
    Constant *C = M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(),
                                      Int8Arr0Ty);
    if (auto *GV = dyn_cast<GlobalVariable>(C))
      GV->setVisibility(GlobalValue::HiddenVisibility);
    return C;
  };

  auto ImportConstant = [&](StringRef Name, uint64_t Const, unsigned AbsWidth,
                            Type *Ty) {
    if (!shouldExportConstantsAsAbsoluteSymbols()) {
      Constant *C =
          ConstantInt::get(isa<IntegerType>(Ty) ? Ty : Int64Ty, Const);
      if (!isa<IntegerType>(Ty))
        C = ConstantExpr::getIntToPtr(C, Ty);
      return C;
    }

    Constant *C = ImportGlobal(Name);
    auto *GV = cast<GlobalVariable>(C->stripPointerCasts());
    if (isa<IntegerType>(Ty))
      C = ConstantExpr::getPtrToInt(C, Ty);
    if (GV->getMetadata(LLVMContext::MD_absolute_symbol))
      return C;

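    // Record the symbol's known value range as !absolute_symbol metadata.
    // For instance (illustrative), AbsWidth == 8 gives the half-open range
    // [0, 0x100), while a full-width constant uses the all-ones "full set"
    // encoding.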
    auto SetAbsRange = [&](uint64_t Min, uint64_t Max) {
      auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min));
      auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max));
      GV->setMetadata(LLVMContext::MD_absolute_symbol,
                      MDNode::get(M.getContext(), {MinC, MaxC}));
    };
    if (AbsWidth == IntPtrTy->getBitWidth())
      SetAbsRange(~0ull, ~0ull); // Full set.
    else
      SetAbsRange(0, 1ull << AbsWidth);
    return C;
  };

  if (TIL.TheKind != TypeTestResolution::Unsat)
    TIL.OffsetedGlobal = ImportGlobal("global_addr");

  if (TIL.TheKind == TypeTestResolution::ByteArray ||
      TIL.TheKind == TypeTestResolution::Inline ||
      TIL.TheKind == TypeTestResolution::AllOnes) {
    TIL.AlignLog2 = ImportConstant("align", TTRes.AlignLog2, 8, Int8Ty);
    TIL.SizeM1 =
        ImportConstant("size_m1", TTRes.SizeM1, TTRes.SizeM1BitWidth, IntPtrTy);
  }

  if (TIL.TheKind == TypeTestResolution::ByteArray) {
    TIL.TheByteArray = ImportGlobal("byte_array");
    TIL.BitMask = ImportConstant("bit_mask", TTRes.BitMask, 8, Int8PtrTy);
  }

  if (TIL.TheKind == TypeTestResolution::Inline)
    TIL.InlineBits = ImportConstant(
        "inline_bits", TTRes.InlineBits, 1 << TTRes.SizeM1BitWidth,
        TTRes.SizeM1BitWidth <= 5 ? Int32Ty : Int64Ty);

  return TIL;
}

void LowerTypeTestsModule::importTypeTest(CallInst *CI) {
  auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
  if (!TypeIdMDVal)
    report_fatal_error("Second argument of llvm.type.test must be metadata");

  auto TypeIdStr = dyn_cast<MDString>(TypeIdMDVal->getMetadata());
  // If this is a local unpromoted type, which doesn't have a metadata string,
  // treat as Unknown and delay lowering, so that we can still utilize it for
  // later optimizations.
  if (!TypeIdStr)
    return;

  TypeIdLowering TIL = importTypeId(TypeIdStr->getString());
  Value *Lowered = lowerTypeTestCall(TypeIdStr, CI, TIL);
  if (Lowered) {
    CI->replaceAllUsesWith(Lowered);
    CI->eraseFromParent();
  }
}

// ThinLTO backend: the function F has a jump table entry; update this module
// accordingly. isJumpTableCanonical describes the type of the jump table entry.
void LowerTypeTestsModule::importFunction(
    Function *F, bool isJumpTableCanonical,
    std::vector<GlobalAlias *> &AliasesToErase) {
  assert(F->getType()->getAddressSpace() == 0);

  GlobalValue::VisibilityTypes Visibility = F->getVisibility();
  std::string Name = std::string(F->getName());

  if (F->isDeclarationForLinker() && isJumpTableCanonical) {
    // Non-dso_local functions may be overridden at run time,
    // so don't short-circuit them.
    if (F->isDSOLocal()) {
      Function *RealF = Function::Create(F->getFunctionType(),
                                         GlobalValue::ExternalLinkage,
                                         F->getAddressSpace(),
                                         Name + ".cfi", &M);
      RealF->setVisibility(GlobalVariable::HiddenVisibility);
      replaceDirectCalls(F, RealF);
    }
    return;
  }

  Function *FDecl;
  if (!isJumpTableCanonical) {
    // Either a declaration of an external function or a reference to a locally
    // defined jump table.
    FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
                             F->getAddressSpace(), Name + ".cfi_jt", &M);
    FDecl->setVisibility(GlobalValue::HiddenVisibility);
  } else {
    F->setName(Name + ".cfi");
    F->setLinkage(GlobalValue::ExternalLinkage);
    FDecl = Function::Create(F->getFunctionType(), GlobalValue::ExternalLinkage,
                             F->getAddressSpace(), Name, &M);
    FDecl->setVisibility(Visibility);
    Visibility = GlobalValue::HiddenVisibility;

    // Delete aliases pointing to this function, they'll be re-created in the
    // merged output. Don't do it yet though because ScopedSaveAliaseesAndUsed
    // will want to reset the aliasees first.
    for (auto &U : F->uses()) {
      if (auto *A = dyn_cast<GlobalAlias>(U.getUser())) {
        Function *AliasDecl = Function::Create(
            F->getFunctionType(), GlobalValue::ExternalLinkage,
            F->getAddressSpace(), "", &M);
        AliasDecl->takeName(A);
        A->replaceAllUsesWith(AliasDecl);
        AliasesToErase.push_back(A);
      }
    }
  }

  if (F->hasExternalWeakLinkage())
    replaceWeakDeclarationWithJumpTablePtr(F, FDecl, isJumpTableCanonical);
  else
    replaceCfiUses(F, FDecl, isJumpTableCanonical);

  // Set visibility late because it's used in replaceCfiUses() to determine
  // whether uses need to be replaced.
  F->setVisibility(Visibility);
}

void LowerTypeTestsModule::lowerTypeTestCalls(
    ArrayRef<Metadata *> TypeIds, Constant *CombinedGlobalAddr,
    const DenseMap<GlobalTypeMember *, uint64_t> &GlobalLayout) {
  // For each type identifier in this disjoint set...
  for (Metadata *TypeId : TypeIds) {
    // Build the bitset.
    BitSetInfo BSI = buildBitSet(TypeId, GlobalLayout);
    LLVM_DEBUG({
      if (auto MDS = dyn_cast<MDString>(TypeId))
        dbgs() << MDS->getString() << ": ";
      else
        dbgs() << "<unnamed>: ";
      BSI.print(dbgs());
    });

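    // Pick the cheapest lowering for this bitset (summarizing the logic
    // below): a single-member all-ones set becomes a pointer equality test,
    // a set of at most 64 bits becomes an inline constant, and anything
    // larger goes through a byte array.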
    ByteArrayInfo *BAI = nullptr;
    TypeIdLowering TIL;
    TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr(
        Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)),
    TIL.AlignLog2 = ConstantInt::get(Int8Ty, BSI.AlignLog2);
    TIL.SizeM1 = ConstantInt::get(IntPtrTy, BSI.BitSize - 1);
    if (BSI.isAllOnes()) {
      TIL.TheKind = (BSI.BitSize == 1) ? TypeTestResolution::Single
                                       : TypeTestResolution::AllOnes;
    } else if (BSI.BitSize <= 64) {
      TIL.TheKind = TypeTestResolution::Inline;
      uint64_t InlineBits = 0;
      for (auto Bit : BSI.Bits)
        InlineBits |= uint64_t(1) << Bit;
      if (InlineBits == 0)
        TIL.TheKind = TypeTestResolution::Unsat;
      else
        TIL.InlineBits = ConstantInt::get(
            (BSI.BitSize <= 32) ? Int32Ty : Int64Ty, InlineBits);
    } else {
      TIL.TheKind = TypeTestResolution::ByteArray;
      ++NumByteArraysCreated;
      BAI = createByteArray(BSI);
      TIL.TheByteArray = BAI->ByteArray;
      TIL.BitMask = BAI->MaskGlobal;
    }

    TypeIdUserInfo &TIUI = TypeIdUsers[TypeId];

    if (TIUI.IsExported) {
      uint8_t *MaskPtr = exportTypeId(cast<MDString>(TypeId)->getString(), TIL);
      if (BAI)
        BAI->MaskPtr = MaskPtr;
    }

    // Lower each call to llvm.type.test for this type identifier.
    for (CallInst *CI : TIUI.CallSites) {
      ++NumTypeTestCallsLowered;
      Value *Lowered = lowerTypeTestCall(TypeId, CI, TIL);
      if (Lowered) {
        CI->replaceAllUsesWith(Lowered);
        CI->eraseFromParent();
      }
    }
  }
}

void LowerTypeTestsModule::verifyTypeMDNode(GlobalObject *GO, MDNode *Type) {
  if (Type->getNumOperands() != 2)
    report_fatal_error("All operands of type metadata must have 2 elements");

  if (GO->isThreadLocal())
    report_fatal_error("Bit set element may not be thread-local");
  if (isa<GlobalVariable>(GO) && GO->hasSection())
    report_fatal_error(
        "A member of a type identifier may not have an explicit section");

  // FIXME: We previously checked that global var member of a type identifier
  // must be a definition, but the IR linker may leave type metadata on
  // declarations. We should restore this check after fixing PR31759.

  auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Type->getOperand(0));
  if (!OffsetConstMD)
    report_fatal_error("Type offset must be a constant");
  auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
  if (!OffsetInt)
    report_fatal_error("Type offset must be an integer constant");
}

static const unsigned kX86JumpTableEntrySize = 8;
static const unsigned kX86IBTJumpTableEntrySize = 16;
static const unsigned kARMJumpTableEntrySize = 4;
static const unsigned kARMBTIJumpTableEntrySize = 8;
static const unsigned kARMv6MJumpTableEntrySize = 16;
static const unsigned kRISCVJumpTableEntrySize = 8;
static const unsigned kLOONGARCH64JumpTableEntrySize = 8;

bool LowerTypeTestsModule::hasBranchTargetEnforcement() {
  if (HasBranchTargetEnforcement == -1) {
    // First time this query has been called. Find out the answer by checking
    // the module flags.
    if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
          M.getModuleFlag("branch-target-enforcement")))
      HasBranchTargetEnforcement = (BTE->getZExtValue() != 0);
    else
      HasBranchTargetEnforcement = 0;
  }
  return HasBranchTargetEnforcement;
}

unsigned LowerTypeTestsModule::getJumpTableEntrySize() {
  switch (JumpTableArch) {
  case Triple::x86:
  case Triple::x86_64:
    if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
            M.getModuleFlag("cf-protection-branch")))
      if (MD->getZExtValue())
        return kX86IBTJumpTableEntrySize;
    return kX86JumpTableEntrySize;
  case Triple::arm:
    return kARMJumpTableEntrySize;
  case Triple::thumb:
    if (CanUseThumbBWJumpTable) {
      if (hasBranchTargetEnforcement())
        return kARMBTIJumpTableEntrySize;
      return kARMJumpTableEntrySize;
    } else {
      return kARMv6MJumpTableEntrySize;
    }
  case Triple::aarch64:
    if (hasBranchTargetEnforcement())
      return kARMBTIJumpTableEntrySize;
    return kARMJumpTableEntrySize;
  case Triple::riscv32:
  case Triple::riscv64:
    return kRISCVJumpTableEntrySize;
  case Triple::loongarch64:
    return kLOONGARCH64JumpTableEntrySize;
  default:
    report_fatal_error("Unsupported architecture for jump tables");
  }
}

// Create a jump table entry for the target. This consists of an instruction
// sequence containing a relative branch to Dest. Appends inline asm text,
// constraints and arguments to AsmOS, ConstraintOS and AsmArgs.
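//
// For example, on x86-64 without IBT the emitted entry is:
//   jmp ${0:c}@plt
//   int3
//   int3
//   int3
// where the 5-byte jmp plus three int3 bytes fill the 8-byte
// kX86JumpTableEntrySize.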
createJumpTableEntry(raw_ostream & AsmOS,raw_ostream & ConstraintOS,Triple::ArchType JumpTableArch,SmallVectorImpl<Value * > & AsmArgs,Function * Dest)1257 void LowerTypeTestsModule::createJumpTableEntry(
1258     raw_ostream &AsmOS, raw_ostream &ConstraintOS,
1259     Triple::ArchType JumpTableArch, SmallVectorImpl<Value *> &AsmArgs,
1260     Function *Dest) {
1261   unsigned ArgIndex = AsmArgs.size();
1262 
1263   if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) {
1264     bool Endbr = false;
1265     if (const auto *MD = mdconst::extract_or_null<ConstantInt>(
1266           Dest->getParent()->getModuleFlag("cf-protection-branch")))
1267       Endbr = !MD->isZero();
1268     if (Endbr)
1269       AsmOS << (JumpTableArch == Triple::x86 ? "endbr32\n" : "endbr64\n");
1270     AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n";
1271     if (Endbr)
1272       AsmOS << ".balign 16, 0xcc\n";
1273     else
1274       AsmOS << "int3\nint3\nint3\n";
1275   } else if (JumpTableArch == Triple::arm) {
1276     AsmOS << "b $" << ArgIndex << "\n";
1277   } else if (JumpTableArch == Triple::aarch64) {
1278     if (hasBranchTargetEnforcement())
1279       AsmOS << "bti c\n";
1280     AsmOS << "b $" << ArgIndex << "\n";
1281   } else if (JumpTableArch == Triple::thumb) {
1282     if (!CanUseThumbBWJumpTable) {
1283       // In Armv6-M, this sequence will generate a branch without corrupting
1284       // any registers. We use two stack words; in the second, we construct the
1285       // address we'll pop into pc, and the first is used to save and restore
1286       // r0 which we use as a temporary register.
1287       //
1288       // To support position-independent use cases, the offset of the target
1289       // function is stored as a relative offset (which will expand into an
1290       // R_ARM_REL32 relocation in ELF, and presumably the equivalent in other
1291       // object file types), and added to pc after we load it. (The alternative
1292       // B.W is automatically pc-relative.)
1293       //
1294       // There are five 16-bit Thumb instructions here, so the .balign 4 adds a
1295       // sixth halfword of padding, and then the offset consumes a further 4
1296       // bytes, for a total of 16, which is very convenient since entries in
1297       // this jump table need to have power-of-two size.
1298       AsmOS << "push {r0,r1}\n"
1299             << "ldr r0, 1f\n"
1300             << "0: add r0, r0, pc\n"
1301             << "str r0, [sp, #4]\n"
1302             << "pop {r0,pc}\n"
1303             << ".balign 4\n"
1304             << "1: .word $" << ArgIndex << " - (0b + 4)\n";
1305     } else {
1306       if (hasBranchTargetEnforcement())
1307         AsmOS << "bti\n";
1308       AsmOS << "b.w $" << ArgIndex << "\n";
1309     }
1310   } else if (JumpTableArch == Triple::riscv32 ||
1311              JumpTableArch == Triple::riscv64) {
1312     AsmOS << "tail $" << ArgIndex << "@plt\n";
1313   } else if (JumpTableArch == Triple::loongarch64) {
1314     AsmOS << "pcalau12i $$t0, %pc_hi20($" << ArgIndex << ")\n"
1315           << "jirl $$r0, $$t0, %pc_lo12($" << ArgIndex << ")\n";
1316   } else {
1317     report_fatal_error("Unsupported architecture for jump tables");
1318   }
1319 
1320   ConstraintOS << (ArgIndex > 0 ? ",s" : "s");
1321   AsmArgs.push_back(Dest);
1322 }
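// For example (illustrative only): for an x86-64 module without IBT, each
// call to createJumpTableEntry appends roughly
//
//   jmp ${N:c}@plt
//   int3
//   int3
//   int3
//
// to AsmOS, where $N is the "s"-constrained argument pushed onto AsmArgs.
// The int3 padding rounds the 5-byte jmp up to the fixed power-of-two entry
// size that the bit-set arithmetic relies on (8 bytes for plain x86-64).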
1323 
1324 Type *LowerTypeTestsModule::getJumpTableEntryType() {
1325   return ArrayType::get(Int8Ty, getJumpTableEntrySize());
1326 }
1327 
1328 /// Given a disjoint set of type identifiers and functions, build the bit sets
1329 /// and lower the llvm.type.test calls, architecture dependently.
1330 void LowerTypeTestsModule::buildBitSetsFromFunctions(
1331     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1332   if (Arch == Triple::x86 || Arch == Triple::x86_64 || Arch == Triple::arm ||
1333       Arch == Triple::thumb || Arch == Triple::aarch64 ||
1334       Arch == Triple::riscv32 || Arch == Triple::riscv64 ||
1335       Arch == Triple::loongarch64)
1336     buildBitSetsFromFunctionsNative(TypeIds, Functions);
1337   else if (Arch == Triple::wasm32 || Arch == Triple::wasm64)
1338     buildBitSetsFromFunctionsWASM(TypeIds, Functions);
1339   else
1340     report_fatal_error("Unsupported architecture for jump tables");
1341 }
1342 
1343 void LowerTypeTestsModule::moveInitializerToModuleConstructor(
1344     GlobalVariable *GV) {
1345   if (WeakInitializerFn == nullptr) {
1346     WeakInitializerFn = Function::Create(
1347         FunctionType::get(Type::getVoidTy(M.getContext()),
1348                           /* IsVarArg */ false),
1349         GlobalValue::InternalLinkage,
1350         M.getDataLayout().getProgramAddressSpace(),
1351         "__cfi_global_var_init", &M);
1352     BasicBlock *BB =
1353         BasicBlock::Create(M.getContext(), "entry", WeakInitializerFn);
1354     ReturnInst::Create(M.getContext(), BB);
1355     WeakInitializerFn->setSection(
1356         ObjectFormat == Triple::MachO
1357             ? "__TEXT,__StaticInit,regular,pure_instructions"
1358             : ".text.startup");
1359     // This code is equivalent to relocation application, and should run at the
1360     // earliest possible time (i.e. with the highest priority).
1361     appendToGlobalCtors(M, WeakInitializerFn, /* Priority */ 0);
1362   }
1363 
1364   IRBuilder<> IRB(WeakInitializerFn->getEntryBlock().getTerminator());
1365   GV->setConstant(false);
1366   IRB.CreateAlignedStore(GV->getInitializer(), GV, GV->getAlign());
1367   GV->setInitializer(Constant::getNullValue(GV->getValueType()));
1368 }
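// A minimal sketch of the effect, under assumed IR (a global @g whose
// initializer takes the address of a CFI function @f):
//
//   before:  @g = constant ptr @f
//   after:   @g = global ptr null
//            ; __cfi_global_var_init, registered as the highest-priority
//            ; constructor, performs:  store ptr @f, ptr @g
//
// Running the store at startup mimics relocation application, so the
// jump-table address later substituted for @f is applied correctly.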
1369 
1370 void LowerTypeTestsModule::findGlobalVariableUsersOf(
1371     Constant *C, SmallSetVector<GlobalVariable *, 8> &Out) {
1372   for (auto *U : C->users()) {
1373     if (auto *GV = dyn_cast<GlobalVariable>(U))
1374       Out.insert(GV);
1375     else if (auto *C2 = dyn_cast<Constant>(U))
1376       findGlobalVariableUsersOf(C2, Out);
1377   }
1378 }
1379 
1380 // Replace all uses of F with (F ? JT : 0).
1381 void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr(
1382     Function *F, Constant *JT, bool IsJumpTableCanonical) {
1383   // The target expression cannot appear in a constant initializer on most
1384   // (all?) targets. Switch to a runtime initializer.
1385   SmallSetVector<GlobalVariable *, 8> GlobalVarUsers;
1386   findGlobalVariableUsersOf(F, GlobalVarUsers);
1387   for (auto *GV : GlobalVarUsers) {
1388     if (GV == GlobalAnnotation)
1389       continue;
1390     moveInitializerToModuleConstructor(GV);
1391   }
1392 
1393   // Cannot RAUW F with an expression that uses F. Replace with a temporary
1394   // placeholder first.
1395   Function *PlaceholderFn =
1396       Function::Create(cast<FunctionType>(F->getValueType()),
1397                        GlobalValue::ExternalWeakLinkage,
1398                        F->getAddressSpace(), "", &M);
1399   replaceCfiUses(F, PlaceholderFn, IsJumpTableCanonical);
1400 
1401   convertUsersOfConstantsToInstructions(PlaceholderFn);
1402   // Don't use a range-based loop, because the use list will be modified.
1403   while (!PlaceholderFn->use_empty()) {
1404     Use &U = *PlaceholderFn->use_begin();
1405     auto *InsertPt = dyn_cast<Instruction>(U.getUser());
1406     assert(InsertPt && "Non-instruction users should have been eliminated");
1407     auto *PN = dyn_cast<PHINode>(InsertPt);
1408     if (PN)
1409       InsertPt = PN->getIncomingBlock(U)->getTerminator();
1410     IRBuilder Builder(InsertPt);
1411     Value *ICmp = Builder.CreateICmp(CmpInst::ICMP_NE, F,
1412                                      Constant::getNullValue(F->getType()));
1413     Value *Select = Builder.CreateSelect(ICmp, JT,
1414                                          Constant::getNullValue(F->getType()));
1415     // For phi nodes, we need to update the incoming value for all operands
1416     // with the same predecessor.
1417     if (PN)
1418       PN->setIncomingValueForBlock(InsertPt->getParent(), Select);
1419     else
1420       U.set(Select);
1421   }
1422   PlaceholderFn->eraseFromParent();
1423 }
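// Sketch of the rewrite performed above (hypothetical IR): each use of the
// extern_weak function @f becomes
//
//   %cmp = icmp ne ptr @f, null
//   %sel = select i1 %cmp, ptr @jt, ptr null
//
// so the jump-table address @jt is only substituted when the weak symbol is
// actually defined at link time, preserving the (F ? JT : 0) semantics.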
1424 
1425 static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) {
1426   Attribute TFAttr = F->getFnAttribute("target-features");
1427   if (TFAttr.isValid()) {
1428     SmallVector<StringRef, 6> Features;
1429     TFAttr.getValueAsString().split(Features, ',');
1430     for (StringRef Feature : Features) {
1431       if (Feature == "-thumb-mode")
1432         return false;
1433       else if (Feature == "+thumb-mode")
1434         return true;
1435     }
1436   }
1437 
1438   return ModuleArch == Triple::thumb;
1439 }
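// For instance, a function built with -mthumb typically carries an attribute
// such as "target-features"="+thumb-mode,..." and is classified as Thumb
// here, while an explicit "-thumb-mode" forces the Arm encoding regardless
// of the module triple (example attribute strings are illustrative).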
1440 
1441 // Each jump table must be either ARM or Thumb as a whole for the bit-test math
1442 // to work. Pick one that matches the majority of members to minimize interop
1443 // veneers inserted by the linker.
1444 Triple::ArchType LowerTypeTestsModule::selectJumpTableArmEncoding(
1445     ArrayRef<GlobalTypeMember *> Functions) {
1446   if (Arch != Triple::arm && Arch != Triple::thumb)
1447     return Arch;
1448 
1449   if (!CanUseThumbBWJumpTable && CanUseArmJumpTable) {
1450     // In architectures that provide Arm and Thumb-1 but not Thumb-2,
1451     // we should always prefer the Arm jump table format, because the
1452     // Thumb-1 one is larger and slower.
1453     return Triple::arm;
1454   }
1455 
1456   // Otherwise, go with majority vote.
1457   unsigned ArmCount = 0, ThumbCount = 0;
1458   for (const auto GTM : Functions) {
1459     if (!GTM->isJumpTableCanonical()) {
1460       // PLT stubs are always ARM.
1461       // FIXME: This is the wrong heuristic for non-canonical jump tables.
1462       ++ArmCount;
1463       continue;
1464     }
1465 
1466     Function *F = cast<Function>(GTM->getGlobal());
1467     ++(isThumbFunction(F, Arch) ? ThumbCount : ArmCount);
1468   }
1469 
1470   return ArmCount > ThumbCount ? Triple::arm : Triple::thumb;
1471 }
1472 
1473 void LowerTypeTestsModule::createJumpTable(
1474     Function *F, ArrayRef<GlobalTypeMember *> Functions) {
1475   std::string AsmStr, ConstraintStr;
1476   raw_string_ostream AsmOS(AsmStr), ConstraintOS(ConstraintStr);
1477   SmallVector<Value *, 16> AsmArgs;
1478   AsmArgs.reserve(Functions.size() * 2);
1479 
1480   // Check if all entries have the NoUnwind attribute.
1481   // If all entries have it, we can safely mark the
1482   // cfi.jumptable as NoUnwind; otherwise, direct calls
1483   // to the jump table will not handle exceptions properly.
1484   bool areAllEntriesNounwind = true;
1485   for (GlobalTypeMember *GTM : Functions) {
1486     if (!llvm::cast<llvm::Function>(GTM->getGlobal())
1487              ->hasFnAttribute(llvm::Attribute::NoUnwind)) {
1488       areAllEntriesNounwind = false;
1489     }
1490     createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs,
1491                          cast<Function>(GTM->getGlobal()));
1492   }
1493 
1494   // Align the whole table by entry size.
1495   F->setAlignment(Align(getJumpTableEntrySize()));
1496   // Skip prologue.
1497   // Disabled on win32 due to https://llvm.org/bugs/show_bug.cgi?id=28641#c3.
1498   // Luckily, this function does not get any prologue even without the
1499   // attribute.
1500   if (OS != Triple::Win32)
1501     F->addFnAttr(Attribute::Naked);
1502   if (JumpTableArch == Triple::arm)
1503     F->addFnAttr("target-features", "-thumb-mode");
1504   if (JumpTableArch == Triple::thumb) {
1505     if (hasBranchTargetEnforcement()) {
1506       // If we're generating a Thumb jump table with BTI, add a target-features
1507       // setting to ensure BTI can be assembled.
1508       F->addFnAttr("target-features", "+thumb-mode,+pacbti");
1509     } else {
1510       F->addFnAttr("target-features", "+thumb-mode");
1511       if (CanUseThumbBWJumpTable) {
1512         // Thumb jump table assembly needs Thumb2. The following attribute is
1513         // added by Clang for -march=armv7.
1514         F->addFnAttr("target-cpu", "cortex-a8");
1515       }
1516     }
1517   }
1518   // When -mbranch-protection= is used, the inline asm adds a BTI. Suppress BTI
1519   // for the function to avoid double BTI. This is a no-op without
1520   // -mbranch-protection=.
1521   if (JumpTableArch == Triple::aarch64 || JumpTableArch == Triple::thumb) {
1522     if (F->hasFnAttribute("branch-target-enforcement"))
1523       F->removeFnAttr("branch-target-enforcement");
1524     if (F->hasFnAttribute("sign-return-address"))
1525       F->removeFnAttr("sign-return-address");
1526   }
1527   if (JumpTableArch == Triple::riscv32 || JumpTableArch == Triple::riscv64) {
1528     // Make sure the jump table assembly is not modified by the assembler or
1529     // the linker.
1530     F->addFnAttr("target-features", "-c,-relax");
1531   }
1532   // When -fcf-protection= is used, the inline asm adds an ENDBR. Suppress ENDBR
1533   // for the function to avoid double ENDBR. This is a no-op without
1534   // -fcf-protection=.
1535   if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64)
1536     F->addFnAttr(Attribute::NoCfCheck);
1537 
1538   // Make sure we don't emit .eh_frame for this function if it isn't needed.
1539   if (areAllEntriesNounwind)
1540     F->addFnAttr(Attribute::NoUnwind);
1541 
1542   // Make sure we do not inline any calls to the cfi.jumptable.
1543   F->addFnAttr(Attribute::NoInline);
1544 
1545   BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F);
1546   IRBuilder<> IRB(BB);
1547 
1548   SmallVector<Type *, 16> ArgTypes;
1549   ArgTypes.reserve(AsmArgs.size());
1550   for (const auto &Arg : AsmArgs)
1551     ArgTypes.push_back(Arg->getType());
1552   InlineAsm *JumpTableAsm =
1553       InlineAsm::get(FunctionType::get(IRB.getVoidTy(), ArgTypes, false),
1554                      AsmOS.str(), ConstraintOS.str(),
1555                      /*hasSideEffects=*/true);
1556 
1557   IRB.CreateCall(JumpTableAsm, AsmArgs);
1558   IRB.CreateUnreachable();
1559 }
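// Conceptually (illustrative, not literal output), the finished jump table
// is a naked private function of the form
//
//   define private void @.cfi.jumptable() #0 {
//   entry:
//     call void asm sideeffect "<one entry per function>", "s,s,..."
//         (ptr @f.cfi, ptr @g.cfi, ...)
//     unreachable
//   }
//
// where the asm text and constraints were accumulated by
// createJumpTableEntry for each member function.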
1560 
1561 /// Given a disjoint set of type identifiers and functions, build a jump table
1562 /// for the functions, build the bit sets and lower the llvm.type.test calls.
1563 void LowerTypeTestsModule::buildBitSetsFromFunctionsNative(
1564     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1565   // Unlike the global bitset builder, the function bitset builder cannot
1566   // re-arrange functions in a particular order and base its calculations on the
1567   // layout of the functions' entry points, as we have no idea how large a
1568   // particular function will end up being (the size could even depend on what
1569   // this pass does!). Instead, we build a jump table, which is a block of code
1570   // consisting of one branch instruction for each of the functions in the bit
1571   // set that branches to the target function, and redirect any taken function
1572   // addresses to the corresponding jump table entry. In the object file's
1573   // symbol table, the symbols for the target functions also refer to the jump
1574   // table entries, so that addresses taken outside the module will pass any
1575   // verification done inside the module.
1576   //
1577   // In more concrete terms, suppose we have three functions f, g, h which are
1578   // of the same type, and a function foo that returns their addresses:
1579   //
1580   // f:
1581   // mov 0, %eax
1582   // ret
1583   //
1584   // g:
1585   // mov 1, %eax
1586   // ret
1587   //
1588   // h:
1589   // mov 2, %eax
1590   // ret
1591   //
1592   // foo:
1593   // mov f, %eax
1594   // mov g, %edx
1595   // mov h, %ecx
1596   // ret
1597   //
1598   // We output the jump table as module-level inline asm string. The end result
1599   // will (conceptually) look like this:
1600   //
1601   // f = .cfi.jumptable
1602   // g = .cfi.jumptable + 4
1603   // h = .cfi.jumptable + 8
1604   // .cfi.jumptable:
1605   // jmp f.cfi  ; 5 bytes
1606   // int3       ; 1 byte
1607   // int3       ; 1 byte
1608   // int3       ; 1 byte
1609   // jmp g.cfi  ; 5 bytes
1610   // int3       ; 1 byte
1611   // int3       ; 1 byte
1612   // int3       ; 1 byte
1613   // jmp h.cfi  ; 5 bytes
1614   // int3       ; 1 byte
1615   // int3       ; 1 byte
1616   // int3       ; 1 byte
1617   //
1618   // f.cfi:
1619   // mov 0, %eax
1620   // ret
1621   //
1622   // g.cfi:
1623   // mov 1, %eax
1624   // ret
1625   //
1626   // h.cfi:
1627   // mov 2, %eax
1628   // ret
1629   //
1630   // foo:
1631   // mov f, %eax
1632   // mov g, %edx
1633   // mov h, %ecx
1634   // ret
1635   //
1636   // Because the addresses of f, g, h are evenly spaced at a power of 2, in the
1637   // normal case the check can be carried out using the same kind of simple
1638   // arithmetic that we normally use for globals.
1639 
1640   // FIXME: find a better way to represent the jumptable in the IR.
1641   assert(!Functions.empty());
1642 
1643   // Decide on the jump table encoding, so that we know how big the
1644   // entries will be.
1645   JumpTableArch = selectJumpTableArmEncoding(Functions);
1646 
1647   // Build a simple layout based on the regular layout of jump tables.
1648   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
1649   unsigned EntrySize = getJumpTableEntrySize();
1650   for (unsigned I = 0; I != Functions.size(); ++I)
1651     GlobalLayout[Functions[I]] = I * EntrySize;
1652 
1653   Function *JumpTableFn =
1654       Function::Create(FunctionType::get(Type::getVoidTy(M.getContext()),
1655                                          /* IsVarArg */ false),
1656                        GlobalValue::PrivateLinkage,
1657                        M.getDataLayout().getProgramAddressSpace(),
1658                        ".cfi.jumptable", &M);
1659   ArrayType *JumpTableType =
1660       ArrayType::get(getJumpTableEntryType(), Functions.size());
1661   auto JumpTable =
1662       ConstantExpr::getPointerCast(JumpTableFn, JumpTableType->getPointerTo(0));
1663 
1664   lowerTypeTestCalls(TypeIds, JumpTable, GlobalLayout);
1665 
1666   {
1667     ScopedSaveAliaseesAndUsed S(M);
1668 
1669     // Build aliases pointing to offsets into the jump table, and replace
1670     // references to the original functions with references to the aliases.
1671     for (unsigned I = 0; I != Functions.size(); ++I) {
1672       Function *F = cast<Function>(Functions[I]->getGlobal());
1673       bool IsJumpTableCanonical = Functions[I]->isJumpTableCanonical();
1674 
1675       Constant *CombinedGlobalElemPtr = ConstantExpr::getInBoundsGetElementPtr(
1676           JumpTableType, JumpTable,
1677           ArrayRef<Constant *>{ConstantInt::get(IntPtrTy, 0),
1678                                ConstantInt::get(IntPtrTy, I)});
1679 
1680       const bool IsExported = Functions[I]->isExported();
1681       if (!IsJumpTableCanonical) {
1682         GlobalValue::LinkageTypes LT = IsExported
1683                                            ? GlobalValue::ExternalLinkage
1684                                            : GlobalValue::InternalLinkage;
1685         GlobalAlias *JtAlias = GlobalAlias::create(F->getValueType(), 0, LT,
1686                                                    F->getName() + ".cfi_jt",
1687                                                    CombinedGlobalElemPtr, &M);
1688         if (IsExported)
1689           JtAlias->setVisibility(GlobalValue::HiddenVisibility);
1690         else
1691           appendToUsed(M, {JtAlias});
1692       }
1693 
1694       if (IsExported) {
1695         if (IsJumpTableCanonical)
1696           ExportSummary->cfiFunctionDefs().insert(std::string(F->getName()));
1697         else
1698           ExportSummary->cfiFunctionDecls().insert(std::string(F->getName()));
1699       }
1700 
1701       if (!IsJumpTableCanonical) {
1702         if (F->hasExternalWeakLinkage())
1703           replaceWeakDeclarationWithJumpTablePtr(F, CombinedGlobalElemPtr,
1704                                                  IsJumpTableCanonical);
1705         else
1706           replaceCfiUses(F, CombinedGlobalElemPtr, IsJumpTableCanonical);
1707       } else {
1708         assert(F->getType()->getAddressSpace() == 0);
1709 
1710         GlobalAlias *FAlias =
1711             GlobalAlias::create(F->getValueType(), 0, F->getLinkage(), "",
1712                                 CombinedGlobalElemPtr, &M);
1713         FAlias->setVisibility(F->getVisibility());
1714         FAlias->takeName(F);
1715         if (FAlias->hasName())
1716           F->setName(FAlias->getName() + ".cfi");
1717         replaceCfiUses(F, FAlias, IsJumpTableCanonical);
1718         if (!F->hasLocalLinkage())
1719           F->setVisibility(GlobalVariable::HiddenVisibility);
1720       }
1721     }
1722   }
1723 
1724   createJumpTable(JumpTableFn, Functions);
1725 }
1726 
1727 /// Assign a dummy layout using an incrementing counter, tag each function
1728 /// with its index represented as metadata, and lower each type test to an
1729 /// integer range comparison. During generation of the indirect function call
1730 /// table in the backend, it will assign the given indexes.
1731 /// Note: Dynamic linking is not supported, as the WebAssembly ABI has not yet
1732 /// been finalized.
1733 void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM(
1734     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Functions) {
1735   assert(!Functions.empty());
1736 
1737   // Build consecutive monotonic integer ranges for each call target set.
1738   DenseMap<GlobalTypeMember *, uint64_t> GlobalLayout;
1739 
1740   for (GlobalTypeMember *GTM : Functions) {
1741     Function *F = cast<Function>(GTM->getGlobal());
1742 
1743     // Skip functions that are not address taken, to avoid bloating the table.
1744     if (!F->hasAddressTaken())
1745       continue;
1746 
1747     // Store metadata with the index for each function.
1748     MDNode *MD = MDNode::get(F->getContext(),
1749                              ArrayRef<Metadata *>(ConstantAsMetadata::get(
1750                                  ConstantInt::get(Int64Ty, IndirectIndex))));
1751     F->setMetadata("wasm.index", MD);
1752 
1753     // Assign the counter value.
1754     GlobalLayout[GTM] = IndirectIndex++;
1755   }
1756 
1757   // The indirect function table index space starts at zero, so pass a NULL
1758   // pointer as the subtracted "jump table" offset.
1759   lowerTypeTestCalls(TypeIds, ConstantPointerNull::get(Int32PtrTy),
1760                      GlobalLayout);
1761 }
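// Illustrative sketch of the result: each address-taken member ends up
// tagged with its table index, e.g.
//
//   define void @f() !wasm.index !0 { ... }
//   !0 = !{i64 <index>}
//
// and lowerTypeTestCalls then emits a plain integer range comparison on
// that index (the concrete index values are assigned by IndirectIndex).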
1762 
1763 void LowerTypeTestsModule::buildBitSetsFromDisjointSet(
1764     ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals,
1765     ArrayRef<ICallBranchFunnel *> ICallBranchFunnels) {
1766   DenseMap<Metadata *, uint64_t> TypeIdIndices;
1767   for (unsigned I = 0; I != TypeIds.size(); ++I)
1768     TypeIdIndices[TypeIds[I]] = I;
1769 
1770   // For each type identifier, build a set of indices that refer to members of
1771   // the type identifier.
1772   std::vector<std::set<uint64_t>> TypeMembers(TypeIds.size());
1773   unsigned GlobalIndex = 0;
1774   DenseMap<GlobalTypeMember *, uint64_t> GlobalIndices;
1775   for (GlobalTypeMember *GTM : Globals) {
1776     for (MDNode *Type : GTM->types()) {
1777       // Type = { offset, type identifier }
1778       auto I = TypeIdIndices.find(Type->getOperand(1));
1779       if (I != TypeIdIndices.end())
1780         TypeMembers[I->second].insert(GlobalIndex);
1781     }
1782     GlobalIndices[GTM] = GlobalIndex;
1783     GlobalIndex++;
1784   }
1785 
1786   for (ICallBranchFunnel *JT : ICallBranchFunnels) {
1787     TypeMembers.emplace_back();
1788     std::set<uint64_t> &TMSet = TypeMembers.back();
1789     for (GlobalTypeMember *T : JT->targets())
1790       TMSet.insert(GlobalIndices[T]);
1791   }
1792 
1793   // Order the sets of indices by size. The GlobalLayoutBuilder works best
1794   // when given small index sets first.
1795   llvm::stable_sort(TypeMembers, [](const std::set<uint64_t> &O1,
1796                                     const std::set<uint64_t> &O2) {
1797     return O1.size() < O2.size();
1798   });
1799 
1800   // Create a GlobalLayoutBuilder and provide it with index sets as layout
1801   // fragments. The GlobalLayoutBuilder tries to lay out members of fragments as
1802   // close together as possible.
1803   GlobalLayoutBuilder GLB(Globals.size());
1804   for (auto &&MemSet : TypeMembers)
1805     GLB.addFragment(MemSet);
1806 
1807   // Build a vector of globals with the computed layout.
1808   bool IsGlobalSet =
1809       Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal());
1810   std::vector<GlobalTypeMember *> OrderedGTMs(Globals.size());
1811   auto OGTMI = OrderedGTMs.begin();
1812   for (auto &&F : GLB.Fragments) {
1813     for (auto &&Offset : F) {
1814       if (IsGlobalSet != isa<GlobalVariable>(Globals[Offset]->getGlobal()))
1815         report_fatal_error("Type identifier may not contain both global "
1816                            "variables and functions");
1817       *OGTMI++ = Globals[Offset];
1818     }
1819   }
1820 
1821   // Build the bitsets from this disjoint set.
1822   if (IsGlobalSet)
1823     buildBitSetsFromGlobalVariables(TypeIds, OrderedGTMs);
1824   else
1825     buildBitSetsFromFunctions(TypeIds, OrderedGTMs);
1826 }
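// Worked example (hypothetical): given globals g0..g3 with type T1 = {g0, g2}
// and type T2 = {g0, g1, g2, g3}, the smaller index set {0, 2} is added to
// the layout builder first, so g0 and g2 are placed adjacently; a resulting
// order such as g1, g0, g2, g3 keeps each type's members close together,
// which in turn keeps the per-type bit sets small.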
1827 
1828 /// Lower all type tests in this module.
1829 LowerTypeTestsModule::LowerTypeTestsModule(
1830     Module &M, ModuleAnalysisManager &AM, ModuleSummaryIndex *ExportSummary,
1831     const ModuleSummaryIndex *ImportSummary, bool DropTypeTests)
1832     : M(M), ExportSummary(ExportSummary), ImportSummary(ImportSummary),
1833       DropTypeTests(DropTypeTests || ClDropTypeTests) {
1834   assert(!(ExportSummary && ImportSummary));
1835   Triple TargetTriple(M.getTargetTriple());
1836   Arch = TargetTriple.getArch();
1837   if (Arch == Triple::arm)
1838     CanUseArmJumpTable = true;
1839   if (Arch == Triple::arm || Arch == Triple::thumb) {
1840     auto &FAM =
1841         AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
1842     for (Function &F : M) {
1843       auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
1844       if (TTI.hasArmWideBranch(false))
1845         CanUseArmJumpTable = true;
1846       if (TTI.hasArmWideBranch(true))
1847         CanUseThumbBWJumpTable = true;
1848     }
1849   }
1850   OS = TargetTriple.getOS();
1851   ObjectFormat = TargetTriple.getObjectFormat();
1852 
1853   // A function annotation describes or applies to the function itself, and
1854   // shouldn't be associated with the jump table thunk generated for CFI.
1855   GlobalAnnotation = M.getGlobalVariable("llvm.global.annotations");
1856   if (GlobalAnnotation && GlobalAnnotation->hasInitializer()) {
1857     const ConstantArray *CA =
1858         cast<ConstantArray>(GlobalAnnotation->getInitializer());
1859     for (Value *Op : CA->operands())
1860       FunctionAnnotations.insert(Op);
1861   }
1862 }
1863 
1864 bool LowerTypeTestsModule::runForTesting(Module &M, ModuleAnalysisManager &AM) {
1865   ModuleSummaryIndex Summary(/*HaveGVs=*/false);
1866 
1867   // Handle the command-line summary arguments. This code is for testing
1868   // purposes only, so we handle errors directly.
1869   if (!ClReadSummary.empty()) {
1870     ExitOnError ExitOnErr("-lowertypetests-read-summary: " + ClReadSummary +
1871                           ": ");
1872     auto ReadSummaryFile =
1873         ExitOnErr(errorOrToExpected(MemoryBuffer::getFile(ClReadSummary)));
1874 
1875     yaml::Input In(ReadSummaryFile->getBuffer());
1876     In >> Summary;
1877     ExitOnErr(errorCodeToError(In.error()));
1878   }
1879 
1880   bool Changed =
1881       LowerTypeTestsModule(
1882           M, AM,
1883           ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
1884           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr,
1885           /*DropTypeTests*/ false)
1886           .lower();
1887 
1888   if (!ClWriteSummary.empty()) {
1889     ExitOnError ExitOnErr("-lowertypetests-write-summary: " + ClWriteSummary +
1890                           ": ");
1891     std::error_code EC;
1892     raw_fd_ostream OS(ClWriteSummary, EC, sys::fs::OF_TextWithCRLF);
1893     ExitOnErr(errorCodeToError(EC));
1894 
1895     yaml::Output Out(OS);
1896     Out << Summary;
1897   }
1898 
1899   return Changed;
1900 }
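// Typical testing invocation (illustrative; the read/write flag names are
// taken from the ExitOnError prefixes above, the summary-action spelling is
// an assumption):
//
//   opt -passes=lowertypetests \
//       -lowertypetests-summary-action=export \
//       -lowertypetests-write-summary=summary.yaml -S in.ll -o out.ll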
1901 
1902 static bool isDirectCall(Use &U) {
1903   auto *Usr = dyn_cast<CallInst>(U.getUser());
1904   if (Usr) {
1905     auto *CB = dyn_cast<CallBase>(Usr);
1906     if (CB && CB->isCallee(&U))
1907       return true;
1908   }
1909   return false;
1910 }
1911 
1912 void LowerTypeTestsModule::replaceCfiUses(Function *Old, Value *New,
1913                                           bool IsJumpTableCanonical) {
1914   SmallSetVector<Constant *, 4> Constants;
1915   for (Use &U : llvm::make_early_inc_range(Old->uses())) {
1916     // Skip block addresses and no_cfi values, which refer to the function
1917     // body instead of the jump table.
1918     if (isa<BlockAddress, NoCFIValue>(U.getUser()))
1919       continue;
1920 
1921     // Skip direct calls to externally defined or non-dso_local functions.
1922     if (isDirectCall(U) && (Old->isDSOLocal() || !IsJumpTableCanonical))
1923       continue;
1924 
1925     // Skip function annotation.
1926     if (isFunctionAnnotation(U.getUser()))
1927       continue;
1928 
1929     // Must handle Constants specially; we cannot call replaceUsesOfWith on a
1930     // constant because they are uniqued.
1931     if (auto *C = dyn_cast<Constant>(U.getUser())) {
1932       if (!isa<GlobalValue>(C)) {
1933         // Save unique users to avoid processing operand replacement
1934         // more than once.
1935         Constants.insert(C);
1936         continue;
1937       }
1938     }
1939 
1940     U.set(New);
1941   }
1942 
1943   // Process operand replacement of saved constants.
1944   for (auto *C : Constants)
1945     C->handleOperandChange(Old, New);
1946 }
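// Note (illustrative): a use buried in a uniqued constant, e.g. @f inside
//
//   @tbl = global [1 x ptr] [ptr @f]
//
// is reached through the ConstantArray operand, which cannot be updated via
// Use::set(); handleOperandChange re-uniques the constant with the new
// operand instead.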
1947 
1948 void LowerTypeTestsModule::replaceDirectCalls(Value *Old, Value *New) {
1949   Old->replaceUsesWithIf(New, isDirectCall);
1950 }
1951 
1952 static void dropTypeTests(Module &M, Function &TypeTestFunc) {
1953   for (Use &U : llvm::make_early_inc_range(TypeTestFunc.uses())) {
1954     auto *CI = cast<CallInst>(U.getUser());
1955     // Find and erase llvm.assume intrinsics for this llvm.type.test call.
1956     for (Use &CIU : llvm::make_early_inc_range(CI->uses()))
1957       if (auto *Assume = dyn_cast<AssumeInst>(CIU.getUser()))
1958         Assume->eraseFromParent();
1959     // If the assume was merged with another assume, we might have a use on a
1960     // phi (which will feed the assume). Simply replace the use on the phi
1961     // with "true" and leave the merged assume.
1962     if (!CI->use_empty()) {
1963       assert(
1964           all_of(CI->users(), [](User *U) -> bool { return isa<PHINode>(U); }));
1965       CI->replaceAllUsesWith(ConstantInt::getTrue(M.getContext()));
1966     }
1967     CI->eraseFromParent();
1968   }
1969 }
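// Sketch of the pattern handled above (hypothetical IR):
//
//   %t = call i1 @llvm.type.test(ptr %p, metadata !"_ZTS1A")
//   call void @llvm.assume(i1 %t)
//
// The assume is erased, any surviving phi uses of %t are replaced with true,
// and the llvm.type.test call itself is then deleted.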
1970 
1971 bool LowerTypeTestsModule::lower() {
1972   Function *TypeTestFunc =
1973       M.getFunction(Intrinsic::getName(Intrinsic::type_test));
1974 
1975   if (DropTypeTests) {
1976     if (TypeTestFunc)
1977       dropTypeTests(M, *TypeTestFunc);
1978     // Normally we'd have already removed all @llvm.public.type.test calls,
1979     // except for in the case where we originally were performing ThinLTO but
1980     // decided not to in the backend.
1981     Function *PublicTypeTestFunc =
1982         M.getFunction(Intrinsic::getName(Intrinsic::public_type_test));
1983     if (PublicTypeTestFunc)
1984       dropTypeTests(M, *PublicTypeTestFunc);
1985     if (TypeTestFunc || PublicTypeTestFunc) {
1986       // We have deleted the type intrinsics, so we no longer have enough
1987       // information to reason about the liveness of virtual function pointers
1988       // in GlobalDCE.
1989       for (GlobalVariable &GV : M.globals())
1990         GV.eraseMetadata(LLVMContext::MD_vcall_visibility);
1991       return true;
1992     }
1993     return false;
1994   }
1995 
1996   // If only some of the modules were split, we cannot correctly perform
1997   // this transformation. We already checked for the presence of type tests
1998   // with partially split modules during the thin link, and would have emitted
1999   // an error if any were found, so here we can simply return.
2000   if ((ExportSummary && ExportSummary->partiallySplitLTOUnits()) ||
2001       (ImportSummary && ImportSummary->partiallySplitLTOUnits()))
2002     return false;
2003 
2004   Function *ICallBranchFunnelFunc =
2005       M.getFunction(Intrinsic::getName(Intrinsic::icall_branch_funnel));
2006   if ((!TypeTestFunc || TypeTestFunc->use_empty()) &&
2007       (!ICallBranchFunnelFunc || ICallBranchFunnelFunc->use_empty()) &&
2008       !ExportSummary && !ImportSummary)
2009     return false;
2010 
2011   if (ImportSummary) {
2012     if (TypeTestFunc)
2013       for (Use &U : llvm::make_early_inc_range(TypeTestFunc->uses()))
2014         importTypeTest(cast<CallInst>(U.getUser()));
2015 
2016     if (ICallBranchFunnelFunc && !ICallBranchFunnelFunc->use_empty())
2017       report_fatal_error(
2018           "unexpected call to llvm.icall.branch.funnel during import phase");
2019 
2020     SmallVector<Function *, 8> Defs;
2021     SmallVector<Function *, 8> Decls;
2022     for (auto &F : M) {
2023       // CFI functions are either external or promoted. A local function may
2024       // have the same name, but it's not the one we are looking for.
2025       if (F.hasLocalLinkage())
2026         continue;
2027       if (ImportSummary->cfiFunctionDefs().count(std::string(F.getName())))
2028         Defs.push_back(&F);
2029       else if (ImportSummary->cfiFunctionDecls().count(
2030                    std::string(F.getName())))
2031         Decls.push_back(&F);
2032     }
2033 
2034     std::vector<GlobalAlias *> AliasesToErase;
2035     {
2036       ScopedSaveAliaseesAndUsed S(M);
2037       for (auto *F : Defs)
2038         importFunction(F, /*isJumpTableCanonical*/ true, AliasesToErase);
2039       for (auto *F : Decls)
2040         importFunction(F, /*isJumpTableCanonical*/ false, AliasesToErase);
2041     }
2042     for (GlobalAlias *GA : AliasesToErase)
2043       GA->eraseFromParent();
2044 
2045     return true;
2046   }
2047 
2048   // Equivalence class set containing type identifiers and the globals that
2049   // reference them. This is used to partition the set of type identifiers in
2050   // the module into disjoint sets.
2051   using GlobalClassesTy = EquivalenceClasses<
2052       PointerUnion<GlobalTypeMember *, Metadata *, ICallBranchFunnel *>>;
2053   GlobalClassesTy GlobalClasses;
2054 
2055   // Verify the type metadata and build a few data structures to let us
2056   // efficiently enumerate the type identifiers associated with a global:
2057   // a list of GlobalTypeMembers (a GlobalObject stored alongside a vector
2058   // of associated type metadata) and a mapping from type identifiers to their
2059   // list of GlobalTypeMembers and last observed index in the list of globals.
2060   // The indices will be used later to deterministically order the list of type
2061   // identifiers.
2062   BumpPtrAllocator Alloc;
2063   struct TIInfo {
2064     unsigned UniqueId;
2065     std::vector<GlobalTypeMember *> RefGlobals;
2066   };
2067   DenseMap<Metadata *, TIInfo> TypeIdInfo;
2068   unsigned CurUniqueId = 0;
2069   SmallVector<MDNode *, 2> Types;
2070 
2071   // Cross-DSO CFI emits jumptable entries for exported functions as well as
2072   // address taken functions in case they are address taken in other modules.
2073   const bool CrossDsoCfi = M.getModuleFlag("Cross-DSO CFI") != nullptr;
2074 
2075   struct ExportedFunctionInfo {
2076     CfiFunctionLinkage Linkage;
2077     MDNode *FuncMD; // {name, linkage, type[, type...]}
2078   };
2079   MapVector<StringRef, ExportedFunctionInfo> ExportedFunctions;
2080   if (ExportSummary) {
2081     // A set of all functions that are address taken by a live global object.
2082     DenseSet<GlobalValue::GUID> AddressTaken;
2083     for (auto &I : *ExportSummary)
2084       for (auto &GVS : I.second.SummaryList)
2085         if (GVS->isLive())
2086           for (const auto &Ref : GVS->refs())
2087             AddressTaken.insert(Ref.getGUID());
2088 
2089     NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");
2090     if (CfiFunctionsMD) {
2091       for (auto *FuncMD : CfiFunctionsMD->operands()) {
2092         assert(FuncMD->getNumOperands() >= 2);
2093         StringRef FunctionName =
2094             cast<MDString>(FuncMD->getOperand(0))->getString();
2095         CfiFunctionLinkage Linkage = static_cast<CfiFunctionLinkage>(
2096             cast<ConstantAsMetadata>(FuncMD->getOperand(1))
2097                 ->getValue()
2098                 ->getUniqueInteger()
2099                 .getZExtValue());
2100         const GlobalValue::GUID GUID = GlobalValue::getGUID(
2101                 GlobalValue::dropLLVMManglingEscape(FunctionName));
2102         // Do not emit jumptable entries for functions that are not live and
2103         // have no live references (and are not exported with cross-DSO CFI).
2104         if (!ExportSummary->isGUIDLive(GUID))
2105           continue;
2106         if (!AddressTaken.count(GUID)) {
2107           if (!CrossDsoCfi || Linkage != CFL_Definition)
2108             continue;
2109 
2110           bool Exported = false;
2111           if (auto VI = ExportSummary->getValueInfo(GUID))
2112             for (const auto &GVS : VI.getSummaryList())
2113               if (GVS->isLive() && !GlobalValue::isLocalLinkage(GVS->linkage()))
2114                 Exported = true;
2115 
2116           if (!Exported)
2117             continue;
2118         }
2119         auto P = ExportedFunctions.insert({FunctionName, {Linkage, FuncMD}});
2120         if (!P.second && P.first->second.Linkage != CFL_Definition)
2121           P.first->second = {Linkage, FuncMD};
2122       }
2123 
2124       for (const auto &P : ExportedFunctions) {
2125         StringRef FunctionName = P.first;
2126         CfiFunctionLinkage Linkage = P.second.Linkage;
2127         MDNode *FuncMD = P.second.FuncMD;
2128         Function *F = M.getFunction(FunctionName);
2129         if (F && F->hasLocalLinkage()) {
2130           // Locally defined function that happens to have the same name as a
2131           // function defined in a ThinLTO module. Rename it to move it out of
2132           // the way of the external reference that we're about to create.
2133           // Note that setName will find a unique name for the function, so even
2134           // if there is an existing function with the suffix there won't be a
2135           // name collision.
2136           F->setName(F->getName() + ".1");
2137           F = nullptr;
2138         }
2139 
2140         if (!F)
2141           F = Function::Create(
2142               FunctionType::get(Type::getVoidTy(M.getContext()), false),
2143               GlobalVariable::ExternalLinkage,
2144               M.getDataLayout().getProgramAddressSpace(), FunctionName, &M);
2145 
2146         // If the function is available_externally, remove its definition so
2147         // that it is handled the same way as a declaration. Later we will try
2148         // to create an alias using this function's linkage, which will fail if
2149         // the linkage is available_externally. This will also result in us
2150         // following the code path below to replace the type metadata.
2151         if (F->hasAvailableExternallyLinkage()) {
2152           F->setLinkage(GlobalValue::ExternalLinkage);
2153           F->deleteBody();
2154           F->setComdat(nullptr);
2155           F->clearMetadata();
2156         }
2157 
2158         // Update the linkage for extern_weak declarations when a definition
2159         // exists.
2160         if (Linkage == CFL_Definition && F->hasExternalWeakLinkage())
2161           F->setLinkage(GlobalValue::ExternalLinkage);
2162 
2163         // If the function in the full LTO module is a declaration, replace its
2164         // type metadata with the type metadata we found in cfi.functions. That
2165         // metadata is presumed to be more accurate than the metadata attached
2166         // to the declaration.
2167         if (F->isDeclaration()) {
2168           if (Linkage == CFL_WeakDeclaration)
2169             F->setLinkage(GlobalValue::ExternalWeakLinkage);
2170 
2171           F->eraseMetadata(LLVMContext::MD_type);
2172           for (unsigned I = 2; I < FuncMD->getNumOperands(); ++I)
2173             F->addMetadata(LLVMContext::MD_type,
2174                            *cast<MDNode>(FuncMD->getOperand(I).get()));
2175         }
2176       }
2177     }
2178   }
2179 
2180   DenseMap<GlobalObject *, GlobalTypeMember *> GlobalTypeMembers;
2181   for (GlobalObject &GO : M.global_objects()) {
2182     if (isa<GlobalVariable>(GO) && GO.isDeclarationForLinker())
2183       continue;
2184 
2185     Types.clear();
2186     GO.getMetadata(LLVMContext::MD_type, Types);
2187 
2188     bool IsJumpTableCanonical = false;
2189     bool IsExported = false;
2190     if (Function *F = dyn_cast<Function>(&GO)) {
2191       IsJumpTableCanonical = isJumpTableCanonical(F);
2192       if (ExportedFunctions.count(F->getName())) {
2193         IsJumpTableCanonical |=
2194             ExportedFunctions[F->getName()].Linkage == CFL_Definition;
2195         IsExported = true;
2196       // TODO: The logic here checks only that the function is address taken,
2197       // not that the address takers are live. This can be updated to check
2198       // their liveness and emit fewer jumptable entries once monolithic LTO
2199       // builds also emit summaries.
2200       } else if (!F->hasAddressTaken()) {
2201         if (!CrossDsoCfi || !IsJumpTableCanonical || F->hasLocalLinkage())
2202           continue;
2203       }
2204     }
2205 
2206     auto *GTM = GlobalTypeMember::create(Alloc, &GO, IsJumpTableCanonical,
2207                                          IsExported, Types);
2208     GlobalTypeMembers[&GO] = GTM;
2209     for (MDNode *Type : Types) {
2210       verifyTypeMDNode(&GO, Type);
2211       auto &Info = TypeIdInfo[Type->getOperand(1)];
2212       Info.UniqueId = ++CurUniqueId;
2213       Info.RefGlobals.push_back(GTM);
2214     }
2215   }
2216 
2217   auto AddTypeIdUse = [&](Metadata *TypeId) -> TypeIdUserInfo & {
2218     // Add the call site to the list of call sites for this type identifier. We
2219     // also use TypeIdUsers to keep track of whether we have seen this type
2220     // identifier before. If we have, we don't need to re-add the referenced
2221     // globals to the equivalence class.
2222     auto Ins = TypeIdUsers.insert({TypeId, {}});
2223     if (Ins.second) {
2224       // Add the type identifier to the equivalence class.
2225       GlobalClassesTy::iterator GCI = GlobalClasses.insert(TypeId);
2226       GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
2227 
2228       // Add the referenced globals to the type identifier's equivalence class.
2229       for (GlobalTypeMember *GTM : TypeIdInfo[TypeId].RefGlobals)
2230         CurSet = GlobalClasses.unionSets(
2231             CurSet, GlobalClasses.findLeader(GlobalClasses.insert(GTM)));
2232     }
2233 
2234     return Ins.first->second;
2235   };
2236 
2237   if (TypeTestFunc) {
2238     for (const Use &U : TypeTestFunc->uses()) {
2239       auto CI = cast<CallInst>(U.getUser());
2240       // If this type test is only used by llvm.assume instructions, it
2241       // was used for whole program devirtualization, and is being kept
2242       // for use by other optimization passes. We do not need or want to
2243       // lower it here. We also don't want to rewrite any associated globals
2244       // unnecessarily. These will be removed by a subsequent LTT invocation
2245       // with the DropTypeTests flag set.
2246       bool OnlyAssumeUses = !CI->use_empty();
2247       for (const Use &CIU : CI->uses()) {
2248         if (isa<AssumeInst>(CIU.getUser()))
2249           continue;
2250         OnlyAssumeUses = false;
2251         break;
2252       }
2253       if (OnlyAssumeUses)
2254         continue;
2255 
2256       auto TypeIdMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
2257       if (!TypeIdMDVal)
2258         report_fatal_error("Second argument of llvm.type.test must be metadata");
2259       auto TypeId = TypeIdMDVal->getMetadata();
2260       AddTypeIdUse(TypeId).CallSites.push_back(CI);
2261     }
2262   }
2263 
2264   if (ICallBranchFunnelFunc) {
2265     for (const Use &U : ICallBranchFunnelFunc->uses()) {
2266       if (Arch != Triple::x86_64)
2267         report_fatal_error(
2268             "llvm.icall.branch.funnel not supported on this target");
2269 
2270       auto CI = cast<CallInst>(U.getUser());
2271 
2272       std::vector<GlobalTypeMember *> Targets;
2273       if (CI->arg_size() % 2 != 1)
2274         report_fatal_error("number of arguments should be odd");
2275 
2276       GlobalClassesTy::member_iterator CurSet;
2277       for (unsigned I = 1; I != CI->arg_size(); I += 2) {
2278         int64_t Offset;
2279         auto *Base = dyn_cast<GlobalObject>(GetPointerBaseWithConstantOffset(
2280             CI->getOperand(I), Offset, M.getDataLayout()));
2281         if (!Base)
2282           report_fatal_error(
2283               "Expected branch funnel operand to be global value");
2284 
2285         GlobalTypeMember *GTM = GlobalTypeMembers[Base];
2286         Targets.push_back(GTM);
2287         GlobalClassesTy::member_iterator NewSet =
2288             GlobalClasses.findLeader(GlobalClasses.insert(GTM));
2289         if (I == 1)
2290           CurSet = NewSet;
2291         else
2292           CurSet = GlobalClasses.unionSets(CurSet, NewSet);
2293       }
2294 
2295       GlobalClasses.unionSets(
2296           CurSet, GlobalClasses.findLeader(
2297                       GlobalClasses.insert(ICallBranchFunnel::create(
2298                           Alloc, CI, Targets, ++CurUniqueId))));
2299     }
2300   }
2301 
2302   if (ExportSummary) {
2303     DenseMap<GlobalValue::GUID, TinyPtrVector<Metadata *>> MetadataByGUID;
2304     for (auto &P : TypeIdInfo) {
2305       if (auto *TypeId = dyn_cast<MDString>(P.first))
2306         MetadataByGUID[GlobalValue::getGUID(TypeId->getString())].push_back(
2307             TypeId);
2308     }
2309 
2310     for (auto &P : *ExportSummary) {
2311       for (auto &S : P.second.SummaryList) {
2312         if (!ExportSummary->isGlobalValueLive(S.get()))
2313           continue;
2314         if (auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject()))
2315           for (GlobalValue::GUID G : FS->type_tests())
2316             for (Metadata *MD : MetadataByGUID[G])
2317               AddTypeIdUse(MD).IsExported = true;
2318       }
2319     }
2320   }
2321 
2322   if (GlobalClasses.empty())
2323     return false;
2324 
2325   // Build a list of disjoint sets ordered by their maximum global index for
2326   // determinism.
2327   std::vector<std::pair<GlobalClassesTy::iterator, unsigned>> Sets;
2328   for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
2329                                  E = GlobalClasses.end();
2330        I != E; ++I) {
2331     if (!I->isLeader())
2332       continue;
2333     ++NumTypeIdDisjointSets;
2334 
2335     unsigned MaxUniqueId = 0;
2336     for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
2337          MI != GlobalClasses.member_end(); ++MI) {
2338       if (auto *MD = dyn_cast_if_present<Metadata *>(*MI))
2339         MaxUniqueId = std::max(MaxUniqueId, TypeIdInfo[MD].UniqueId);
2340       else if (auto *BF = dyn_cast_if_present<ICallBranchFunnel *>(*MI))
2341         MaxUniqueId = std::max(MaxUniqueId, BF->UniqueId);
2342     }
2343     Sets.emplace_back(I, MaxUniqueId);
2344   }
2345   llvm::sort(Sets, llvm::less_second());
2346 
2347   // For each disjoint set we found...
2348   for (const auto &S : Sets) {
2349     // Build the list of type identifiers in this disjoint set.
2350     std::vector<Metadata *> TypeIds;
2351     std::vector<GlobalTypeMember *> Globals;
2352     std::vector<ICallBranchFunnel *> ICallBranchFunnels;
2353     for (GlobalClassesTy::member_iterator MI =
2354              GlobalClasses.member_begin(S.first);
2355          MI != GlobalClasses.member_end(); ++MI) {
2356       if (isa<Metadata *>(*MI))
2357         TypeIds.push_back(cast<Metadata *>(*MI));
2358       else if (isa<GlobalTypeMember *>(*MI))
2359         Globals.push_back(cast<GlobalTypeMember *>(*MI));
2360       else
2361         ICallBranchFunnels.push_back(cast<ICallBranchFunnel *>(*MI));
2362     }
2363 
2364     // Order type identifiers by unique ID for determinism. This ordering is
2365     // stable as there is a one-to-one mapping between metadata and unique IDs.
2366     llvm::sort(TypeIds, [&](Metadata *M1, Metadata *M2) {
2367       return TypeIdInfo[M1].UniqueId < TypeIdInfo[M2].UniqueId;
2368     });
2369 
2370     // Same for the branch funnels.
2371     llvm::sort(ICallBranchFunnels,
2372                [&](ICallBranchFunnel *F1, ICallBranchFunnel *F2) {
2373                  return F1->UniqueId < F2->UniqueId;
2374                });
2375 
2376     // Build bitsets for this disjoint set.
2377     buildBitSetsFromDisjointSet(TypeIds, Globals, ICallBranchFunnels);
2378   }
2379 
2380   allocateByteArrays();
2381 
2382   // Parse alias data to replace stand-in function declarations for aliases
2383   // with an alias to the intended target.
2384   if (ExportSummary) {
2385     if (NamedMDNode *AliasesMD = M.getNamedMetadata("aliases")) {
2386       for (auto *AliasMD : AliasesMD->operands()) {
2387         assert(AliasMD->getNumOperands() >= 4);
2388         StringRef AliasName =
2389             cast<MDString>(AliasMD->getOperand(0))->getString();
2390         StringRef Aliasee = cast<MDString>(AliasMD->getOperand(1))->getString();
2391 
2392         if (!ExportedFunctions.count(Aliasee) ||
2393             ExportedFunctions[Aliasee].Linkage != CFL_Definition ||
2394             !M.getNamedAlias(Aliasee))
2395           continue;
2396 
2397         GlobalValue::VisibilityTypes Visibility =
2398             static_cast<GlobalValue::VisibilityTypes>(
2399                 cast<ConstantAsMetadata>(AliasMD->getOperand(2))
2400                     ->getValue()
2401                     ->getUniqueInteger()
2402                     .getZExtValue());
2403         bool Weak =
2404             static_cast<bool>(cast<ConstantAsMetadata>(AliasMD->getOperand(3))
2405                                   ->getValue()
2406                                   ->getUniqueInteger()
2407                                   .getZExtValue());
2408 
2409         auto *Alias = GlobalAlias::create("", M.getNamedAlias(Aliasee));
2410         Alias->setVisibility(Visibility);
2411         if (Weak)
2412           Alias->setLinkage(GlobalValue::WeakAnyLinkage);
2413 
2414         if (auto *F = M.getFunction(AliasName)) {
2415           Alias->takeName(F);
2416           F->replaceAllUsesWith(Alias);
2417           F->eraseFromParent();
2418         } else {
2419           Alias->setName(AliasName);
2420         }
2421       }
2422     }
2423   }
2424 
2425   // Emit .symver directives for exported functions, if they exist.
2426   if (ExportSummary) {
2427     if (NamedMDNode *SymversMD = M.getNamedMetadata("symvers")) {
2428       for (auto *Symver : SymversMD->operands()) {
2429         assert(Symver->getNumOperands() >= 2);
2430         StringRef SymbolName =
2431             cast<MDString>(Symver->getOperand(0))->getString();
2432         StringRef Alias = cast<MDString>(Symver->getOperand(1))->getString();
2433 
2434         if (!ExportedFunctions.count(SymbolName))
2435           continue;
2436 
2437         M.appendModuleInlineAsm(
2438             (llvm::Twine(".symver ") + SymbolName + ", " + Alias).str());
2439       }
2440     }
2441   }
2442 
2443   return true;
2444 }
2445 
2446 PreservedAnalyses LowerTypeTestsPass::run(Module &M,
2447                                           ModuleAnalysisManager &AM) {
2448   bool Changed;
2449   if (UseCommandLine)
2450     Changed = LowerTypeTestsModule::runForTesting(M, AM);
2451   else
2452     Changed =
2453         LowerTypeTestsModule(M, AM, ExportSummary, ImportSummary, DropTypeTests)
2454             .lower();
2455   if (!Changed)
2456     return PreservedAnalyses::all();
2457   return PreservedAnalyses::none();
2458 }
2459