xref: /freebsd/contrib/llvm-project/llvm/lib/SandboxIR/Context.cpp (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/SandboxIR/Context.h"
10 #include "llvm/IR/InlineAsm.h"
11 #include "llvm/SandboxIR/Function.h"
12 #include "llvm/SandboxIR/Instruction.h"
13 #include "llvm/SandboxIR/Module.h"
14 
15 namespace llvm::sandboxir {
16 
detachLLVMValue(llvm::Value * V)17 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18   std::unique_ptr<Value> Erased;
19   auto It = LLVMValueToValueMap.find(V);
20   if (It != LLVMValueToValueMap.end()) {
21     auto *Val = It->second.release();
22     Erased = std::unique_ptr<Value>(Val);
23     LLVMValueToValueMap.erase(It);
24   }
25   return Erased;
26 }
27 
detach(Value * V)28 std::unique_ptr<Value> Context::detach(Value *V) {
29   assert(V->getSubclassID() != Value::ClassID::Constant &&
30          "Can't detach a constant!");
31   assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32   return detachLLVMValue(V->Val);
33 }
34 
registerValue(std::unique_ptr<Value> && VPtr)35 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36   assert(VPtr->getSubclassID() != Value::ClassID::User &&
37          "Can't register a user!");
38 
39   Value *V = VPtr.get();
40   [[maybe_unused]] auto Pair =
41       LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
42   assert(Pair.second && "Already exists!");
43 
44   // Track creation of instructions.
45   // Please note that we don't allow the creation of detached instructions,
46   // meaning that the instructions need to be inserted into a block upon
47   // creation. This is why the tracker class combines creation and insertion.
48   if (auto *I = dyn_cast<Instruction>(V)) {
49     getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
50     runCreateInstrCallbacks(I);
51   }
52 
53   return V;
54 }
55 
getOrCreateValueInternal(llvm::Value * LLVMV,llvm::User * U)56 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57   auto Pair = LLVMValueToValueMap.try_emplace(LLVMV);
58   auto It = Pair.first;
59   if (!Pair.second)
60     return It->second.get();
61 
62   // Instruction
63   if (auto *LLVMI = dyn_cast<llvm::Instruction>(LLVMV)) {
64     switch (LLVMI->getOpcode()) {
65     case llvm::Instruction::VAArg: {
66       auto *LLVMVAArg = cast<llvm::VAArgInst>(LLVMV);
67       It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68       return It->second.get();
69     }
70     case llvm::Instruction::Freeze: {
71       auto *LLVMFreeze = cast<llvm::FreezeInst>(LLVMV);
72       It->second =
73           std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74       return It->second.get();
75     }
76     case llvm::Instruction::Fence: {
77       auto *LLVMFence = cast<llvm::FenceInst>(LLVMV);
78       It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79       return It->second.get();
80     }
81     case llvm::Instruction::Select: {
82       auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
83       It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84       return It->second.get();
85     }
86     case llvm::Instruction::ExtractElement: {
87       auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
88       It->second = std::unique_ptr<ExtractElementInst>(
89           new ExtractElementInst(LLVMIns, *this));
90       return It->second.get();
91     }
92     case llvm::Instruction::InsertElement: {
93       auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
94       It->second = std::unique_ptr<InsertElementInst>(
95           new InsertElementInst(LLVMIns, *this));
96       return It->second.get();
97     }
98     case llvm::Instruction::ShuffleVector: {
99       auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV);
100       It->second = std::unique_ptr<ShuffleVectorInst>(
101           new ShuffleVectorInst(LLVMIns, *this));
102       return It->second.get();
103     }
104     case llvm::Instruction::ExtractValue: {
105       auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV);
106       It->second = std::unique_ptr<ExtractValueInst>(
107           new ExtractValueInst(LLVMIns, *this));
108       return It->second.get();
109     }
110     case llvm::Instruction::InsertValue: {
111       auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
112       It->second =
113           std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114       return It->second.get();
115     }
116     case llvm::Instruction::Br: {
117       auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
118       It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
119       return It->second.get();
120     }
121     case llvm::Instruction::Load: {
122       auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
123       It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
124       return It->second.get();
125     }
126     case llvm::Instruction::Store: {
127       auto *LLVMSt = cast<llvm::StoreInst>(LLVMV);
128       It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
129       return It->second.get();
130     }
131     case llvm::Instruction::Ret: {
132       auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
133       It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
134       return It->second.get();
135     }
136     case llvm::Instruction::Call: {
137       auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
138       It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
139       return It->second.get();
140     }
141     case llvm::Instruction::Invoke: {
142       auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
143       It->second =
144           std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
145       return It->second.get();
146     }
147     case llvm::Instruction::CallBr: {
148       auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV);
149       It->second =
150           std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
151       return It->second.get();
152     }
153     case llvm::Instruction::LandingPad: {
154       auto *LLVMLPad = cast<llvm::LandingPadInst>(LLVMV);
155       It->second =
156           std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
157       return It->second.get();
158     }
159     case llvm::Instruction::CatchPad: {
160       auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV);
161       It->second =
162           std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
163       return It->second.get();
164     }
165     case llvm::Instruction::CleanupPad: {
166       auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV);
167       It->second =
168           std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
169       return It->second.get();
170     }
171     case llvm::Instruction::CatchRet: {
172       auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV);
173       It->second =
174           std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
175       return It->second.get();
176     }
177     case llvm::Instruction::CleanupRet: {
178       auto *LLVMCRI = cast<llvm::CleanupReturnInst>(LLVMV);
179       It->second = std::unique_ptr<CleanupReturnInst>(
180           new CleanupReturnInst(LLVMCRI, *this));
181       return It->second.get();
182     }
183     case llvm::Instruction::GetElementPtr: {
184       auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
185       It->second = std::unique_ptr<GetElementPtrInst>(
186           new GetElementPtrInst(LLVMGEP, *this));
187       return It->second.get();
188     }
189     case llvm::Instruction::CatchSwitch: {
190       auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV);
191       It->second = std::unique_ptr<CatchSwitchInst>(
192           new CatchSwitchInst(LLVMCatchSwitchInst, *this));
193       return It->second.get();
194     }
195     case llvm::Instruction::Resume: {
196       auto *LLVMResumeInst = cast<llvm::ResumeInst>(LLVMV);
197       It->second =
198           std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
199       return It->second.get();
200     }
201     case llvm::Instruction::Switch: {
202       auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
203       It->second =
204           std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
205       return It->second.get();
206     }
207     case llvm::Instruction::FNeg: {
208       auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
209       It->second = std::unique_ptr<UnaryOperator>(
210           new UnaryOperator(LLVMUnaryOperator, *this));
211       return It->second.get();
212     }
213     case llvm::Instruction::Add:
214     case llvm::Instruction::FAdd:
215     case llvm::Instruction::Sub:
216     case llvm::Instruction::FSub:
217     case llvm::Instruction::Mul:
218     case llvm::Instruction::FMul:
219     case llvm::Instruction::UDiv:
220     case llvm::Instruction::SDiv:
221     case llvm::Instruction::FDiv:
222     case llvm::Instruction::URem:
223     case llvm::Instruction::SRem:
224     case llvm::Instruction::FRem:
225     case llvm::Instruction::Shl:
226     case llvm::Instruction::LShr:
227     case llvm::Instruction::AShr:
228     case llvm::Instruction::And:
229     case llvm::Instruction::Or:
230     case llvm::Instruction::Xor: {
231       auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
232       It->second = std::unique_ptr<BinaryOperator>(
233           new BinaryOperator(LLVMBinaryOperator, *this));
234       return It->second.get();
235     }
236     case llvm::Instruction::AtomicRMW: {
237       auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
238       It->second = std::unique_ptr<AtomicRMWInst>(
239           new AtomicRMWInst(LLVMAtomicRMW, *this));
240       return It->second.get();
241     }
242     case llvm::Instruction::AtomicCmpXchg: {
243       auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
244       It->second = std::unique_ptr<AtomicCmpXchgInst>(
245           new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
246       return It->second.get();
247     }
248     case llvm::Instruction::Alloca: {
249       auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV);
250       It->second =
251           std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
252       return It->second.get();
253     }
254     case llvm::Instruction::ZExt:
255     case llvm::Instruction::SExt:
256     case llvm::Instruction::FPToUI:
257     case llvm::Instruction::FPToSI:
258     case llvm::Instruction::FPExt:
259     case llvm::Instruction::PtrToInt:
260     case llvm::Instruction::IntToPtr:
261     case llvm::Instruction::SIToFP:
262     case llvm::Instruction::UIToFP:
263     case llvm::Instruction::Trunc:
264     case llvm::Instruction::FPTrunc:
265     case llvm::Instruction::BitCast:
266     case llvm::Instruction::AddrSpaceCast: {
267       auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
268       It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
269       return It->second.get();
270     }
271     case llvm::Instruction::PHI: {
272       auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
273       It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
274       return It->second.get();
275     }
276     case llvm::Instruction::ICmp: {
277       auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
278       It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
279       return It->second.get();
280     }
281     case llvm::Instruction::FCmp: {
282       auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
283       It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
284       return It->second.get();
285     }
286     case llvm::Instruction::Unreachable: {
287       auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
288       It->second = std::unique_ptr<UnreachableInst>(
289           new UnreachableInst(LLVMUnreachable, *this));
290       return It->second.get();
291     }
292     default:
293       break;
294     }
295     It->second = std::unique_ptr<OpaqueInst>(
296         new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
297     return It->second.get();
298   }
299   // Constant
300   if (auto *LLVMC = dyn_cast<llvm::Constant>(LLVMV)) {
301     switch (LLVMC->getValueID()) {
302     case llvm::Value::ConstantIntVal:
303       It->second = std::unique_ptr<ConstantInt>(
304           new ConstantInt(cast<llvm::ConstantInt>(LLVMC), *this));
305       return It->second.get();
306     case llvm::Value::ConstantFPVal:
307       It->second = std::unique_ptr<ConstantFP>(
308           new ConstantFP(cast<llvm::ConstantFP>(LLVMC), *this));
309       return It->second.get();
310     case llvm::Value::BlockAddressVal:
311       It->second = std::unique_ptr<BlockAddress>(
312           new BlockAddress(cast<llvm::BlockAddress>(LLVMC), *this));
313       return It->second.get();
314     case llvm::Value::ConstantTokenNoneVal:
315       It->second = std::unique_ptr<ConstantTokenNone>(
316           new ConstantTokenNone(cast<llvm::ConstantTokenNone>(LLVMC), *this));
317       return It->second.get();
318     case llvm::Value::ConstantAggregateZeroVal: {
319       auto *CAZ = cast<llvm::ConstantAggregateZero>(LLVMC);
320       It->second = std::unique_ptr<ConstantAggregateZero>(
321           new ConstantAggregateZero(CAZ, *this));
322       auto *Ret = It->second.get();
323       // Must create sandboxir for elements.
324       auto EC = CAZ->getElementCount();
325       if (EC.isFixed()) {
326         for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue()))
327           getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
328       }
329       return Ret;
330     }
331     case llvm::Value::ConstantPointerNullVal:
332       It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
333           cast<llvm::ConstantPointerNull>(LLVMC), *this));
334       return It->second.get();
335     case llvm::Value::PoisonValueVal:
336       It->second = std::unique_ptr<PoisonValue>(
337           new PoisonValue(cast<llvm::PoisonValue>(LLVMC), *this));
338       return It->second.get();
339     case llvm::Value::UndefValueVal:
340       It->second = std::unique_ptr<UndefValue>(
341           new UndefValue(cast<llvm::UndefValue>(LLVMC), *this));
342       return It->second.get();
343     case llvm::Value::DSOLocalEquivalentVal: {
344       auto *DSOLE = cast<llvm::DSOLocalEquivalent>(LLVMC);
345       It->second = std::unique_ptr<DSOLocalEquivalent>(
346           new DSOLocalEquivalent(DSOLE, *this));
347       auto *Ret = It->second.get();
348       getOrCreateValueInternal(DSOLE->getGlobalValue(), DSOLE);
349       return Ret;
350     }
351     case llvm::Value::ConstantArrayVal:
352       It->second = std::unique_ptr<ConstantArray>(
353           new ConstantArray(cast<llvm::ConstantArray>(LLVMC), *this));
354       break;
355     case llvm::Value::ConstantStructVal:
356       It->second = std::unique_ptr<ConstantStruct>(
357           new ConstantStruct(cast<llvm::ConstantStruct>(LLVMC), *this));
358       break;
359     case llvm::Value::ConstantVectorVal:
360       It->second = std::unique_ptr<ConstantVector>(
361           new ConstantVector(cast<llvm::ConstantVector>(LLVMC), *this));
362       break;
363     case llvm::Value::ConstantDataArrayVal:
364       It->second = std::unique_ptr<ConstantDataArray>(
365           new ConstantDataArray(cast<llvm::ConstantDataArray>(LLVMC), *this));
366       break;
367     case llvm::Value::ConstantDataVectorVal:
368       It->second = std::unique_ptr<ConstantDataVector>(
369           new ConstantDataVector(cast<llvm::ConstantDataVector>(LLVMC), *this));
370       break;
371     case llvm::Value::FunctionVal:
372       It->second = std::unique_ptr<Function>(
373           new Function(cast<llvm::Function>(LLVMC), *this));
374       break;
375     case llvm::Value::GlobalIFuncVal:
376       It->second = std::unique_ptr<GlobalIFunc>(
377           new GlobalIFunc(cast<llvm::GlobalIFunc>(LLVMC), *this));
378       break;
379     case llvm::Value::GlobalVariableVal:
380       It->second = std::unique_ptr<GlobalVariable>(
381           new GlobalVariable(cast<llvm::GlobalVariable>(LLVMC), *this));
382       break;
383     case llvm::Value::GlobalAliasVal:
384       It->second = std::unique_ptr<GlobalAlias>(
385           new GlobalAlias(cast<llvm::GlobalAlias>(LLVMC), *this));
386       break;
387     case llvm::Value::NoCFIValueVal:
388       It->second = std::unique_ptr<NoCFIValue>(
389           new NoCFIValue(cast<llvm::NoCFIValue>(LLVMC), *this));
390       break;
391     case llvm::Value::ConstantPtrAuthVal:
392       It->second = std::unique_ptr<ConstantPtrAuth>(
393           new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(LLVMC), *this));
394       break;
395     case llvm::Value::ConstantExprVal:
396       It->second = std::unique_ptr<ConstantExpr>(
397           new ConstantExpr(cast<llvm::ConstantExpr>(LLVMC), *this));
398       break;
399     default:
400       It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
401       break;
402     }
403     auto *NewC = It->second.get();
404     for (llvm::Value *COp : LLVMC->operands())
405       getOrCreateValueInternal(COp, LLVMC);
406     return NewC;
407   }
408   // Argument
409   if (auto *LLVMArg = dyn_cast<llvm::Argument>(LLVMV)) {
410     It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
411     return It->second.get();
412   }
413   // BasicBlock
414   if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(LLVMV)) {
415     assert(isa<llvm::BlockAddress>(U) &&
416            "This won't create a SBBB, don't call this function directly!");
417     if (auto *SBBB = getValue(LLVMBB))
418       return SBBB;
419     return nullptr;
420   }
421   // Metadata
422   if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(LLVMV)) {
423     It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
424     return It->second.get();
425   }
426   // InlineAsm
427   if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(LLVMV)) {
428     It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
429     return It->second.get();
430   }
431   llvm_unreachable("Unhandled LLVMV type!");
432 }
433 
getOrCreateArgument(llvm::Argument * LLVMArg)434 Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
435   auto Pair = LLVMValueToValueMap.try_emplace(LLVMArg);
436   auto It = Pair.first;
437   if (Pair.second) {
438     It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
439     return cast<Argument>(It->second.get());
440   }
441   return cast<Argument>(It->second.get());
442 }
443 
getOrCreateConstant(llvm::Constant * LLVMC)444 Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
445   return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
446 }
447 
createBasicBlock(llvm::BasicBlock * LLVMBB)448 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
449   assert(getValue(LLVMBB) == nullptr && "Already exists!");
450   auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
451   auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr)));
452   // Create SandboxIR for BB's body.
453   BB->buildBasicBlockFromLLVMIR(LLVMBB);
454   return BB;
455 }
456 
createVAArgInst(llvm::VAArgInst * SI)457 VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
458   auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
459   return cast<VAArgInst>(registerValue(std::move(NewPtr)));
460 }
461 
createFreezeInst(llvm::FreezeInst * SI)462 FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
463   auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
464   return cast<FreezeInst>(registerValue(std::move(NewPtr)));
465 }
466 
createFenceInst(llvm::FenceInst * SI)467 FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
468   auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
469   return cast<FenceInst>(registerValue(std::move(NewPtr)));
470 }
471 
createSelectInst(llvm::SelectInst * SI)472 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
473   auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
474   return cast<SelectInst>(registerValue(std::move(NewPtr)));
475 }
476 
477 ExtractElementInst *
createExtractElementInst(llvm::ExtractElementInst * EEI)478 Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
479   auto NewPtr =
480       std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
481   return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
482 }
483 
484 InsertElementInst *
createInsertElementInst(llvm::InsertElementInst * IEI)485 Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
486   auto NewPtr =
487       std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
488   return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
489 }
490 
491 ShuffleVectorInst *
createShuffleVectorInst(llvm::ShuffleVectorInst * SVI)492 Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
493   auto NewPtr =
494       std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
495   return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
496 }
497 
createExtractValueInst(llvm::ExtractValueInst * EVI)498 ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
499   auto NewPtr =
500       std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
501   return cast<ExtractValueInst>(registerValue(std::move(NewPtr)));
502 }
503 
createInsertValueInst(llvm::InsertValueInst * IVI)504 InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
505   auto NewPtr =
506       std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
507   return cast<InsertValueInst>(registerValue(std::move(NewPtr)));
508 }
509 
createBranchInst(llvm::BranchInst * BI)510 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
511   auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
512   return cast<BranchInst>(registerValue(std::move(NewPtr)));
513 }
514 
createLoadInst(llvm::LoadInst * LI)515 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
516   auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
517   return cast<LoadInst>(registerValue(std::move(NewPtr)));
518 }
519 
createStoreInst(llvm::StoreInst * SI)520 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
521   auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
522   return cast<StoreInst>(registerValue(std::move(NewPtr)));
523 }
524 
createReturnInst(llvm::ReturnInst * I)525 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
526   auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
527   return cast<ReturnInst>(registerValue(std::move(NewPtr)));
528 }
529 
createCallInst(llvm::CallInst * I)530 CallInst *Context::createCallInst(llvm::CallInst *I) {
531   auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
532   return cast<CallInst>(registerValue(std::move(NewPtr)));
533 }
534 
createInvokeInst(llvm::InvokeInst * I)535 InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
536   auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
537   return cast<InvokeInst>(registerValue(std::move(NewPtr)));
538 }
539 
createCallBrInst(llvm::CallBrInst * I)540 CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
541   auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
542   return cast<CallBrInst>(registerValue(std::move(NewPtr)));
543 }
544 
createUnreachableInst(llvm::UnreachableInst * UI)545 UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
546   auto NewPtr =
547       std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
548   return cast<UnreachableInst>(registerValue(std::move(NewPtr)));
549 }
createLandingPadInst(llvm::LandingPadInst * I)550 LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
551   auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
552   return cast<LandingPadInst>(registerValue(std::move(NewPtr)));
553 }
createCatchPadInst(llvm::CatchPadInst * I)554 CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
555   auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
556   return cast<CatchPadInst>(registerValue(std::move(NewPtr)));
557 }
createCleanupPadInst(llvm::CleanupPadInst * I)558 CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
559   auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
560   return cast<CleanupPadInst>(registerValue(std::move(NewPtr)));
561 }
createCatchReturnInst(llvm::CatchReturnInst * I)562 CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
563   auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
564   return cast<CatchReturnInst>(registerValue(std::move(NewPtr)));
565 }
566 CleanupReturnInst *
createCleanupReturnInst(llvm::CleanupReturnInst * I)567 Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
568   auto NewPtr =
569       std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
570   return cast<CleanupReturnInst>(registerValue(std::move(NewPtr)));
571 }
572 GetElementPtrInst *
createGetElementPtrInst(llvm::GetElementPtrInst * I)573 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
574   auto NewPtr =
575       std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
576   return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
577 }
createCatchSwitchInst(llvm::CatchSwitchInst * I)578 CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
579   auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
580   return cast<CatchSwitchInst>(registerValue(std::move(NewPtr)));
581 }
createResumeInst(llvm::ResumeInst * I)582 ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
583   auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
584   return cast<ResumeInst>(registerValue(std::move(NewPtr)));
585 }
createSwitchInst(llvm::SwitchInst * I)586 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
587   auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
588   return cast<SwitchInst>(registerValue(std::move(NewPtr)));
589 }
createUnaryOperator(llvm::UnaryOperator * I)590 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
591   auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
592   return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
593 }
createBinaryOperator(llvm::BinaryOperator * I)594 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
595   auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
596   return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
597 }
createAtomicRMWInst(llvm::AtomicRMWInst * I)598 AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
599   auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
600   return cast<AtomicRMWInst>(registerValue(std::move(NewPtr)));
601 }
602 AtomicCmpXchgInst *
createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst * I)603 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
604   auto NewPtr =
605       std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
606   return cast<AtomicCmpXchgInst>(registerValue(std::move(NewPtr)));
607 }
createAllocaInst(llvm::AllocaInst * I)608 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
609   auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
610   return cast<AllocaInst>(registerValue(std::move(NewPtr)));
611 }
createCastInst(llvm::CastInst * I)612 CastInst *Context::createCastInst(llvm::CastInst *I) {
613   auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
614   return cast<CastInst>(registerValue(std::move(NewPtr)));
615 }
createPHINode(llvm::PHINode * I)616 PHINode *Context::createPHINode(llvm::PHINode *I) {
617   auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
618   return cast<PHINode>(registerValue(std::move(NewPtr)));
619 }
createICmpInst(llvm::ICmpInst * I)620 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
621   auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
622   return cast<ICmpInst>(registerValue(std::move(NewPtr)));
623 }
createFCmpInst(llvm::FCmpInst * I)624 FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
625   auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
626   return cast<FCmpInst>(registerValue(std::move(NewPtr)));
627 }
getValue(llvm::Value * V) const628 Value *Context::getValue(llvm::Value *V) const {
629   auto It = LLVMValueToValueMap.find(V);
630   if (It != LLVMValueToValueMap.end())
631     return It->second.get();
632   return nullptr;
633 }
634 
Context(LLVMContext & LLVMCtx)635 Context::Context(LLVMContext &LLVMCtx)
636     : LLVMCtx(LLVMCtx), IRTracker(*this),
637       LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
638 
~Context()639 Context::~Context() {}
640 
clear()641 void Context::clear() {
642   // TODO: Ideally we should clear only function-scope objects, and keep global
643   // objects, like Constants to avoid recreating them.
644   LLVMValueToValueMap.clear();
645 }
646 
getModule(llvm::Module * LLVMM) const647 Module *Context::getModule(llvm::Module *LLVMM) const {
648   auto It = LLVMModuleToModuleMap.find(LLVMM);
649   if (It != LLVMModuleToModuleMap.end())
650     return It->second.get();
651   return nullptr;
652 }
653 
getOrCreateModule(llvm::Module * LLVMM)654 Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
655   auto Pair = LLVMModuleToModuleMap.try_emplace(LLVMM);
656   auto It = Pair.first;
657   if (!Pair.second)
658     return It->second.get();
659   It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
660   return It->second.get();
661 }
662 
createFunction(llvm::Function * F)663 Function *Context::createFunction(llvm::Function *F) {
664   // Create the module if needed before we create the new sandboxir::Function.
665   // Note: this won't fully populate the module. The only globals that will be
666   // available will be the ones being used within the function.
667   getOrCreateModule(F->getParent());
668 
669   // There may be a function declaration already defined. Regardless destroy it.
670   if (Function *ExistingF = cast_or_null<Function>(getValue(F)))
671     detach(ExistingF);
672 
673   auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
674   auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
675   // Create arguments.
676   for (auto &Arg : F->args())
677     getOrCreateArgument(&Arg);
678   // Create BBs.
679   for (auto &BB : *F)
680     createBasicBlock(&BB);
681   return SBF;
682 }
683 
createModule(llvm::Module * LLVMM)684 Module *Context::createModule(llvm::Module *LLVMM) {
685   auto *M = getOrCreateModule(LLVMM);
686   // Create the functions.
687   for (auto &LLVMF : *LLVMM)
688     createFunction(&LLVMF);
689   // Create globals.
690   for (auto &Global : LLVMM->globals())
691     getOrCreateValue(&Global);
692   // Create aliases.
693   for (auto &Alias : LLVMM->aliases())
694     getOrCreateValue(&Alias);
695   // Create ifuncs.
696   for (auto &IFunc : LLVMM->ifuncs())
697     getOrCreateValue(&IFunc);
698 
699   return M;
700 }
701 
runEraseInstrCallbacks(Instruction * I)702 void Context::runEraseInstrCallbacks(Instruction *I) {
703   for (const auto &CBEntry : EraseInstrCallbacks)
704     CBEntry.second(I);
705 }
706 
runCreateInstrCallbacks(Instruction * I)707 void Context::runCreateInstrCallbacks(Instruction *I) {
708   for (auto &CBEntry : CreateInstrCallbacks)
709     CBEntry.second(I);
710 }
711 
runMoveInstrCallbacks(Instruction * I,const BBIterator & WhereIt)712 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
713   for (auto &CBEntry : MoveInstrCallbacks)
714     CBEntry.second(I, WhereIt);
715 }
716 
runSetUseCallbacks(const Use & U,Value * NewSrc)717 void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
718   for (auto &CBEntry : SetUseCallbacks)
719     CBEntry.second(U, NewSrc);
720 }
721 
722 // An arbitrary limit, to check for accidental misuse. We expect a small number
723 // of callbacks to be registered at a time, but we can increase this number if
724 // we discover we needed more.
725 [[maybe_unused]] static constexpr int MaxRegisteredCallbacks = 16;
726 
registerEraseInstrCallback(EraseInstrCallback CB)727 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
728   assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
729          "EraseInstrCallbacks size limit exceeded");
730   CallbackID ID{NextCallbackID++};
731   EraseInstrCallbacks[ID] = CB;
732   return ID;
733 }
unregisterEraseInstrCallback(CallbackID ID)734 void Context::unregisterEraseInstrCallback(CallbackID ID) {
735   [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
736   assert(Erased &&
737          "Callback ID not found in EraseInstrCallbacks during deregistration");
738 }
739 
740 Context::CallbackID
registerCreateInstrCallback(CreateInstrCallback CB)741 Context::registerCreateInstrCallback(CreateInstrCallback CB) {
742   assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
743          "CreateInstrCallbacks size limit exceeded");
744   CallbackID ID{NextCallbackID++};
745   CreateInstrCallbacks[ID] = CB;
746   return ID;
747 }
unregisterCreateInstrCallback(CallbackID ID)748 void Context::unregisterCreateInstrCallback(CallbackID ID) {
749   [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
750   assert(Erased &&
751          "Callback ID not found in CreateInstrCallbacks during deregistration");
752 }
753 
registerMoveInstrCallback(MoveInstrCallback CB)754 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
755   assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
756          "MoveInstrCallbacks size limit exceeded");
757   CallbackID ID{NextCallbackID++};
758   MoveInstrCallbacks[ID] = CB;
759   return ID;
760 }
unregisterMoveInstrCallback(CallbackID ID)761 void Context::unregisterMoveInstrCallback(CallbackID ID) {
762   [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
763   assert(Erased &&
764          "Callback ID not found in MoveInstrCallbacks during deregistration");
765 }
766 
registerSetUseCallback(SetUseCallback CB)767 Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
768   assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
769          "SetUseCallbacks size limit exceeded");
770   CallbackID ID{NextCallbackID++};
771   SetUseCallbacks[ID] = CB;
772   return ID;
773 }
unregisterSetUseCallback(CallbackID ID)774 void Context::unregisterSetUseCallback(CallbackID ID) {
775   [[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID);
776   assert(Erased &&
777          "Callback ID not found in SetUseCallbacks during deregistration");
778 }
779 
780 } // namespace llvm::sandboxir
781