// xref: /freebsd/contrib/llvm-project/llvm/lib/Analysis/TypeMetadataUtils.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains functions that make it easier to manipulate type metadata
// for devirtualization.
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Module.h"

using namespace llvm;

23 // Search for virtual calls that call FPtr and add them to DevirtCalls.
24 static void
findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> & DevirtCalls,bool * HasNonCallUses,Value * FPtr,uint64_t Offset,const CallInst * CI,DominatorTree & DT)25 findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls,
26                           bool *HasNonCallUses, Value *FPtr, uint64_t Offset,
27                           const CallInst *CI, DominatorTree &DT) {
28   for (const Use &U : FPtr->uses()) {
29     Instruction *User = cast<Instruction>(U.getUser());
30     // Ignore this instruction if it is not dominated by the type intrinsic
31     // being analyzed. Otherwise we may transform a call sharing the same
32     // vtable pointer incorrectly. Specifically, this situation can arise
33     // after indirect call promotion and inlining, where we may have uses
34     // of the vtable pointer guarded by a function pointer check, and a fallback
35     // indirect call.
36     if (CI->getFunction() != User->getFunction())
37       continue;
38     if (!DT.dominates(CI, User))
39       continue;
40     if (isa<BitCastInst>(User)) {
41       findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI,
42                                 DT);
43     } else if (auto *CI = dyn_cast<CallInst>(User)) {
44       DevirtCalls.push_back({Offset, *CI});
45     } else if (auto *II = dyn_cast<InvokeInst>(User)) {
46       DevirtCalls.push_back({Offset, *II});
47     } else if (HasNonCallUses) {
48       *HasNonCallUses = true;
49     }
50   }
51 }
52 
53 // Search for virtual calls that load from VPtr and add them to DevirtCalls.
findLoadCallsAtConstantOffset(const Module * M,SmallVectorImpl<DevirtCallSite> & DevirtCalls,Value * VPtr,int64_t Offset,const CallInst * CI,DominatorTree & DT)54 static void findLoadCallsAtConstantOffset(
55     const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr,
56     int64_t Offset, const CallInst *CI, DominatorTree &DT) {
57   if (!VPtr->hasUseList())
58     return;
59 
60   for (const Use &U : VPtr->uses()) {
61     Value *User = U.getUser();
62     if (isa<BitCastInst>(User)) {
63       findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT);
64     } else if (isa<LoadInst>(User)) {
65       findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT);
66     } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
67       // Take into account the GEP offset.
68       if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
69         SmallVector<Value *, 8> Indices(drop_begin(GEP->operands()));
70         int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType(
71             GEP->getSourceElementType(), Indices);
72         findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset,
73                                       CI, DT);
74       }
75     } else if (auto *Call = dyn_cast<CallInst>(User)) {
76       if (Call->getIntrinsicID() == llvm::Intrinsic::load_relative) {
77         if (auto *LoadOffset = dyn_cast<ConstantInt>(Call->getOperand(1))) {
78           findCallsAtConstantOffset(DevirtCalls, nullptr, User,
79                                     Offset + LoadOffset->getSExtValue(), CI,
80                                     DT);
81         }
82       }
83     }
84   }
85 }
86 
findDevirtualizableCallsForTypeTest(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<CallInst * > & Assumes,const CallInst * CI,DominatorTree & DT)87 void llvm::findDevirtualizableCallsForTypeTest(
88     SmallVectorImpl<DevirtCallSite> &DevirtCalls,
89     SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI,
90     DominatorTree &DT) {
91   assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test ||
92          CI->getCalledFunction()->getIntrinsicID() ==
93              Intrinsic::public_type_test);
94 
95   const Module *M = CI->getParent()->getParent()->getParent();
96 
97   // Find llvm.assume intrinsics for this llvm.type.test call.
98   for (const Use &CIU : CI->uses())
99     if (auto *Assume = dyn_cast<AssumeInst>(CIU.getUser()))
100       Assumes.push_back(Assume);
101 
102   // If we found any, search for virtual calls based on %p and add them to
103   // DevirtCalls.
104   if (!Assumes.empty())
105     findLoadCallsAtConstantOffset(
106         M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT);
107 }
108 
findDevirtualizableCallsForTypeCheckedLoad(SmallVectorImpl<DevirtCallSite> & DevirtCalls,SmallVectorImpl<Instruction * > & LoadedPtrs,SmallVectorImpl<Instruction * > & Preds,bool & HasNonCallUses,const CallInst * CI,DominatorTree & DT)109 void llvm::findDevirtualizableCallsForTypeCheckedLoad(
110     SmallVectorImpl<DevirtCallSite> &DevirtCalls,
111     SmallVectorImpl<Instruction *> &LoadedPtrs,
112     SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses,
113     const CallInst *CI, DominatorTree &DT) {
114   assert(CI->getCalledFunction()->getIntrinsicID() ==
115              Intrinsic::type_checked_load ||
116          CI->getCalledFunction()->getIntrinsicID() ==
117              Intrinsic::type_checked_load_relative);
118 
119   auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1));
120   if (!Offset) {
121     HasNonCallUses = true;
122     return;
123   }
124 
125   for (const Use &U : CI->uses()) {
126     auto CIU = U.getUser();
127     if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) {
128       if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) {
129         LoadedPtrs.push_back(EVI);
130         continue;
131       }
132       if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) {
133         Preds.push_back(EVI);
134         continue;
135       }
136     }
137     HasNonCallUses = true;
138   }
139 
140   for (Value *LoadedPtr : LoadedPtrs)
141     findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr,
142                               Offset->getZExtValue(), CI, DT);
143 }
144 
getPointerAtOffset(Constant * I,uint64_t Offset,Module & M,Constant * TopLevelGlobal)145 Constant *llvm::getPointerAtOffset(Constant *I, uint64_t Offset, Module &M,
146                                    Constant *TopLevelGlobal) {
147   // TODO: Ideally it would be the caller who knows if it's appropriate to strip
148   // the DSOLocalEquicalent. More generally, it would feel more appropriate to
149   // have two functions that handle absolute and relative pointers separately.
150   if (auto *Equiv = dyn_cast<DSOLocalEquivalent>(I))
151     I = Equiv->getGlobalValue();
152 
153   if (I->getType()->isPointerTy()) {
154     if (Offset == 0)
155       return I;
156     return nullptr;
157   }
158 
159   const DataLayout &DL = M.getDataLayout();
160 
161   if (auto *C = dyn_cast<ConstantStruct>(I)) {
162     const StructLayout *SL = DL.getStructLayout(C->getType());
163     if (Offset >= SL->getSizeInBytes())
164       return nullptr;
165 
166     unsigned Op = SL->getElementContainingOffset(Offset);
167     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
168                               Offset - SL->getElementOffset(Op), M,
169                               TopLevelGlobal);
170   }
171   if (auto *C = dyn_cast<ConstantArray>(I)) {
172     ArrayType *VTableTy = C->getType();
173     uint64_t ElemSize = DL.getTypeAllocSize(VTableTy->getElementType());
174 
175     unsigned Op = Offset / ElemSize;
176     if (Op >= C->getNumOperands())
177       return nullptr;
178 
179     return getPointerAtOffset(cast<Constant>(I->getOperand(Op)),
180                               Offset % ElemSize, M, TopLevelGlobal);
181   }
182 
183   // Relative-pointer support starts here.
184   if (auto *CI = dyn_cast<ConstantInt>(I)) {
185     if (Offset == 0 && CI->isZero()) {
186       return I;
187     }
188   }
189   if (auto *C = dyn_cast<ConstantExpr>(I)) {
190     switch (C->getOpcode()) {
191     case Instruction::Trunc:
192     case Instruction::PtrToInt:
193       return getPointerAtOffset(cast<Constant>(C->getOperand(0)), Offset, M,
194                                 TopLevelGlobal);
195     case Instruction::Sub: {
196       auto *Operand0 = cast<Constant>(C->getOperand(0));
197       auto *Operand1 = cast<Constant>(C->getOperand(1));
198 
199       auto StripGEP = [](Constant *C) {
200         auto *CE = dyn_cast<ConstantExpr>(C);
201         if (!CE)
202           return C;
203         if (CE->getOpcode() != Instruction::GetElementPtr)
204           return C;
205         return CE->getOperand(0);
206       };
207       auto *Operand1TargetGlobal = StripGEP(getPointerAtOffset(Operand1, 0, M));
208 
209       // Check that in the "sub (@a, @b)" expression, @b points back to the top
210       // level global (or a GEP thereof) that we're processing. Otherwise bail.
211       if (Operand1TargetGlobal != TopLevelGlobal)
212         return nullptr;
213 
214       return getPointerAtOffset(Operand0, Offset, M, TopLevelGlobal);
215     }
216     default:
217       return nullptr;
218     }
219   }
220   return nullptr;
221 }
222 
223 std::pair<Function *, Constant *>
getFunctionAtVTableOffset(GlobalVariable * GV,uint64_t Offset,Module & M)224 llvm::getFunctionAtVTableOffset(GlobalVariable *GV, uint64_t Offset,
225                                 Module &M) {
226   Constant *Ptr = getPointerAtOffset(GV->getInitializer(), Offset, M, GV);
227   if (!Ptr)
228     return std::pair<Function *, Constant *>(nullptr, nullptr);
229 
230   auto C = Ptr->stripPointerCasts();
231   // Make sure this is a function or alias to a function.
232   auto Fn = dyn_cast<Function>(C);
233   auto A = dyn_cast<GlobalAlias>(C);
234   if (!Fn && A)
235     Fn = dyn_cast<Function>(A->getAliasee());
236 
237   if (!Fn)
238     return std::pair<Function *, Constant *>(nullptr, nullptr);
239 
240   return std::pair<Function *, Constant *>(Fn, C);
241 }
242 
replaceRelativePointerUserWithZero(User * U)243 static void replaceRelativePointerUserWithZero(User *U) {
244   auto *PtrExpr = dyn_cast<ConstantExpr>(U);
245   if (!PtrExpr || PtrExpr->getOpcode() != Instruction::PtrToInt)
246     return;
247 
248   for (auto *PtrToIntUser : PtrExpr->users()) {
249     auto *SubExpr = dyn_cast<ConstantExpr>(PtrToIntUser);
250     if (!SubExpr || SubExpr->getOpcode() != Instruction::Sub)
251       return;
252 
253     SubExpr->replaceNonMetadataUsesWith(
254         ConstantInt::get(SubExpr->getType(), 0));
255   }
256 }
257 
replaceRelativePointerUsersWithZero(Constant * C)258 void llvm::replaceRelativePointerUsersWithZero(Constant *C) {
259   for (auto *U : C->users()) {
260     if (auto *Equiv = dyn_cast<DSOLocalEquivalent>(U))
261       replaceRelativePointerUsersWithZero(Equiv);
262     else
263       replaceRelativePointerUserWithZero(U);
264   }
265 }
266