1 //===- Context.cpp - The Context class of Sandbox IR ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/SandboxIR/Context.h"
10 #include "llvm/IR/InlineAsm.h"
11 #include "llvm/SandboxIR/Function.h"
12 #include "llvm/SandboxIR/Instruction.h"
13 #include "llvm/SandboxIR/Module.h"
14
15 namespace llvm::sandboxir {
16
detachLLVMValue(llvm::Value * V)17 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
18 std::unique_ptr<Value> Erased;
19 auto It = LLVMValueToValueMap.find(V);
20 if (It != LLVMValueToValueMap.end()) {
21 auto *Val = It->second.release();
22 Erased = std::unique_ptr<Value>(Val);
23 LLVMValueToValueMap.erase(It);
24 }
25 return Erased;
26 }
27
detach(Value * V)28 std::unique_ptr<Value> Context::detach(Value *V) {
29 assert(V->getSubclassID() != Value::ClassID::Constant &&
30 "Can't detach a constant!");
31 assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
32 return detachLLVMValue(V->Val);
33 }
34
registerValue(std::unique_ptr<Value> && VPtr)35 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
36 assert(VPtr->getSubclassID() != Value::ClassID::User &&
37 "Can't register a user!");
38
39 Value *V = VPtr.get();
40 [[maybe_unused]] auto Pair =
41 LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
42 assert(Pair.second && "Already exists!");
43
44 // Track creation of instructions.
45 // Please note that we don't allow the creation of detached instructions,
46 // meaning that the instructions need to be inserted into a block upon
47 // creation. This is why the tracker class combines creation and insertion.
48 if (auto *I = dyn_cast<Instruction>(V)) {
49 getTracker().emplaceIfTracking<CreateAndInsertInst>(I);
50 runCreateInstrCallbacks(I);
51 }
52
53 return V;
54 }
55
getOrCreateValueInternal(llvm::Value * LLVMV,llvm::User * U)56 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
57 auto Pair = LLVMValueToValueMap.try_emplace(LLVMV);
58 auto It = Pair.first;
59 if (!Pair.second)
60 return It->second.get();
61
62 // Instruction
63 if (auto *LLVMI = dyn_cast<llvm::Instruction>(LLVMV)) {
64 switch (LLVMI->getOpcode()) {
65 case llvm::Instruction::VAArg: {
66 auto *LLVMVAArg = cast<llvm::VAArgInst>(LLVMV);
67 It->second = std::unique_ptr<VAArgInst>(new VAArgInst(LLVMVAArg, *this));
68 return It->second.get();
69 }
70 case llvm::Instruction::Freeze: {
71 auto *LLVMFreeze = cast<llvm::FreezeInst>(LLVMV);
72 It->second =
73 std::unique_ptr<FreezeInst>(new FreezeInst(LLVMFreeze, *this));
74 return It->second.get();
75 }
76 case llvm::Instruction::Fence: {
77 auto *LLVMFence = cast<llvm::FenceInst>(LLVMV);
78 It->second = std::unique_ptr<FenceInst>(new FenceInst(LLVMFence, *this));
79 return It->second.get();
80 }
81 case llvm::Instruction::Select: {
82 auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
83 It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
84 return It->second.get();
85 }
86 case llvm::Instruction::ExtractElement: {
87 auto *LLVMIns = cast<llvm::ExtractElementInst>(LLVMV);
88 It->second = std::unique_ptr<ExtractElementInst>(
89 new ExtractElementInst(LLVMIns, *this));
90 return It->second.get();
91 }
92 case llvm::Instruction::InsertElement: {
93 auto *LLVMIns = cast<llvm::InsertElementInst>(LLVMV);
94 It->second = std::unique_ptr<InsertElementInst>(
95 new InsertElementInst(LLVMIns, *this));
96 return It->second.get();
97 }
98 case llvm::Instruction::ShuffleVector: {
99 auto *LLVMIns = cast<llvm::ShuffleVectorInst>(LLVMV);
100 It->second = std::unique_ptr<ShuffleVectorInst>(
101 new ShuffleVectorInst(LLVMIns, *this));
102 return It->second.get();
103 }
104 case llvm::Instruction::ExtractValue: {
105 auto *LLVMIns = cast<llvm::ExtractValueInst>(LLVMV);
106 It->second = std::unique_ptr<ExtractValueInst>(
107 new ExtractValueInst(LLVMIns, *this));
108 return It->second.get();
109 }
110 case llvm::Instruction::InsertValue: {
111 auto *LLVMIns = cast<llvm::InsertValueInst>(LLVMV);
112 It->second =
113 std::unique_ptr<InsertValueInst>(new InsertValueInst(LLVMIns, *this));
114 return It->second.get();
115 }
116 case llvm::Instruction::Br: {
117 auto *LLVMBr = cast<llvm::BranchInst>(LLVMV);
118 It->second = std::unique_ptr<BranchInst>(new BranchInst(LLVMBr, *this));
119 return It->second.get();
120 }
121 case llvm::Instruction::Load: {
122 auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
123 It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
124 return It->second.get();
125 }
126 case llvm::Instruction::Store: {
127 auto *LLVMSt = cast<llvm::StoreInst>(LLVMV);
128 It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
129 return It->second.get();
130 }
131 case llvm::Instruction::Ret: {
132 auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
133 It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
134 return It->second.get();
135 }
136 case llvm::Instruction::Call: {
137 auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
138 It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
139 return It->second.get();
140 }
141 case llvm::Instruction::Invoke: {
142 auto *LLVMInvoke = cast<llvm::InvokeInst>(LLVMV);
143 It->second =
144 std::unique_ptr<InvokeInst>(new InvokeInst(LLVMInvoke, *this));
145 return It->second.get();
146 }
147 case llvm::Instruction::CallBr: {
148 auto *LLVMCallBr = cast<llvm::CallBrInst>(LLVMV);
149 It->second =
150 std::unique_ptr<CallBrInst>(new CallBrInst(LLVMCallBr, *this));
151 return It->second.get();
152 }
153 case llvm::Instruction::LandingPad: {
154 auto *LLVMLPad = cast<llvm::LandingPadInst>(LLVMV);
155 It->second =
156 std::unique_ptr<LandingPadInst>(new LandingPadInst(LLVMLPad, *this));
157 return It->second.get();
158 }
159 case llvm::Instruction::CatchPad: {
160 auto *LLVMCPI = cast<llvm::CatchPadInst>(LLVMV);
161 It->second =
162 std::unique_ptr<CatchPadInst>(new CatchPadInst(LLVMCPI, *this));
163 return It->second.get();
164 }
165 case llvm::Instruction::CleanupPad: {
166 auto *LLVMCPI = cast<llvm::CleanupPadInst>(LLVMV);
167 It->second =
168 std::unique_ptr<CleanupPadInst>(new CleanupPadInst(LLVMCPI, *this));
169 return It->second.get();
170 }
171 case llvm::Instruction::CatchRet: {
172 auto *LLVMCRI = cast<llvm::CatchReturnInst>(LLVMV);
173 It->second =
174 std::unique_ptr<CatchReturnInst>(new CatchReturnInst(LLVMCRI, *this));
175 return It->second.get();
176 }
177 case llvm::Instruction::CleanupRet: {
178 auto *LLVMCRI = cast<llvm::CleanupReturnInst>(LLVMV);
179 It->second = std::unique_ptr<CleanupReturnInst>(
180 new CleanupReturnInst(LLVMCRI, *this));
181 return It->second.get();
182 }
183 case llvm::Instruction::GetElementPtr: {
184 auto *LLVMGEP = cast<llvm::GetElementPtrInst>(LLVMV);
185 It->second = std::unique_ptr<GetElementPtrInst>(
186 new GetElementPtrInst(LLVMGEP, *this));
187 return It->second.get();
188 }
189 case llvm::Instruction::CatchSwitch: {
190 auto *LLVMCatchSwitchInst = cast<llvm::CatchSwitchInst>(LLVMV);
191 It->second = std::unique_ptr<CatchSwitchInst>(
192 new CatchSwitchInst(LLVMCatchSwitchInst, *this));
193 return It->second.get();
194 }
195 case llvm::Instruction::Resume: {
196 auto *LLVMResumeInst = cast<llvm::ResumeInst>(LLVMV);
197 It->second =
198 std::unique_ptr<ResumeInst>(new ResumeInst(LLVMResumeInst, *this));
199 return It->second.get();
200 }
201 case llvm::Instruction::Switch: {
202 auto *LLVMSwitchInst = cast<llvm::SwitchInst>(LLVMV);
203 It->second =
204 std::unique_ptr<SwitchInst>(new SwitchInst(LLVMSwitchInst, *this));
205 return It->second.get();
206 }
207 case llvm::Instruction::FNeg: {
208 auto *LLVMUnaryOperator = cast<llvm::UnaryOperator>(LLVMV);
209 It->second = std::unique_ptr<UnaryOperator>(
210 new UnaryOperator(LLVMUnaryOperator, *this));
211 return It->second.get();
212 }
213 case llvm::Instruction::Add:
214 case llvm::Instruction::FAdd:
215 case llvm::Instruction::Sub:
216 case llvm::Instruction::FSub:
217 case llvm::Instruction::Mul:
218 case llvm::Instruction::FMul:
219 case llvm::Instruction::UDiv:
220 case llvm::Instruction::SDiv:
221 case llvm::Instruction::FDiv:
222 case llvm::Instruction::URem:
223 case llvm::Instruction::SRem:
224 case llvm::Instruction::FRem:
225 case llvm::Instruction::Shl:
226 case llvm::Instruction::LShr:
227 case llvm::Instruction::AShr:
228 case llvm::Instruction::And:
229 case llvm::Instruction::Or:
230 case llvm::Instruction::Xor: {
231 auto *LLVMBinaryOperator = cast<llvm::BinaryOperator>(LLVMV);
232 It->second = std::unique_ptr<BinaryOperator>(
233 new BinaryOperator(LLVMBinaryOperator, *this));
234 return It->second.get();
235 }
236 case llvm::Instruction::AtomicRMW: {
237 auto *LLVMAtomicRMW = cast<llvm::AtomicRMWInst>(LLVMV);
238 It->second = std::unique_ptr<AtomicRMWInst>(
239 new AtomicRMWInst(LLVMAtomicRMW, *this));
240 return It->second.get();
241 }
242 case llvm::Instruction::AtomicCmpXchg: {
243 auto *LLVMAtomicCmpXchg = cast<llvm::AtomicCmpXchgInst>(LLVMV);
244 It->second = std::unique_ptr<AtomicCmpXchgInst>(
245 new AtomicCmpXchgInst(LLVMAtomicCmpXchg, *this));
246 return It->second.get();
247 }
248 case llvm::Instruction::Alloca: {
249 auto *LLVMAlloca = cast<llvm::AllocaInst>(LLVMV);
250 It->second =
251 std::unique_ptr<AllocaInst>(new AllocaInst(LLVMAlloca, *this));
252 return It->second.get();
253 }
254 case llvm::Instruction::ZExt:
255 case llvm::Instruction::SExt:
256 case llvm::Instruction::FPToUI:
257 case llvm::Instruction::FPToSI:
258 case llvm::Instruction::FPExt:
259 case llvm::Instruction::PtrToInt:
260 case llvm::Instruction::IntToPtr:
261 case llvm::Instruction::SIToFP:
262 case llvm::Instruction::UIToFP:
263 case llvm::Instruction::Trunc:
264 case llvm::Instruction::FPTrunc:
265 case llvm::Instruction::BitCast:
266 case llvm::Instruction::AddrSpaceCast: {
267 auto *LLVMCast = cast<llvm::CastInst>(LLVMV);
268 It->second = std::unique_ptr<CastInst>(new CastInst(LLVMCast, *this));
269 return It->second.get();
270 }
271 case llvm::Instruction::PHI: {
272 auto *LLVMPhi = cast<llvm::PHINode>(LLVMV);
273 It->second = std::unique_ptr<PHINode>(new PHINode(LLVMPhi, *this));
274 return It->second.get();
275 }
276 case llvm::Instruction::ICmp: {
277 auto *LLVMICmp = cast<llvm::ICmpInst>(LLVMV);
278 It->second = std::unique_ptr<ICmpInst>(new ICmpInst(LLVMICmp, *this));
279 return It->second.get();
280 }
281 case llvm::Instruction::FCmp: {
282 auto *LLVMFCmp = cast<llvm::FCmpInst>(LLVMV);
283 It->second = std::unique_ptr<FCmpInst>(new FCmpInst(LLVMFCmp, *this));
284 return It->second.get();
285 }
286 case llvm::Instruction::Unreachable: {
287 auto *LLVMUnreachable = cast<llvm::UnreachableInst>(LLVMV);
288 It->second = std::unique_ptr<UnreachableInst>(
289 new UnreachableInst(LLVMUnreachable, *this));
290 return It->second.get();
291 }
292 default:
293 break;
294 }
295 It->second = std::unique_ptr<OpaqueInst>(
296 new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
297 return It->second.get();
298 }
299 // Constant
300 if (auto *LLVMC = dyn_cast<llvm::Constant>(LLVMV)) {
301 switch (LLVMC->getValueID()) {
302 case llvm::Value::ConstantIntVal:
303 It->second = std::unique_ptr<ConstantInt>(
304 new ConstantInt(cast<llvm::ConstantInt>(LLVMC), *this));
305 return It->second.get();
306 case llvm::Value::ConstantFPVal:
307 It->second = std::unique_ptr<ConstantFP>(
308 new ConstantFP(cast<llvm::ConstantFP>(LLVMC), *this));
309 return It->second.get();
310 case llvm::Value::BlockAddressVal:
311 It->second = std::unique_ptr<BlockAddress>(
312 new BlockAddress(cast<llvm::BlockAddress>(LLVMC), *this));
313 return It->second.get();
314 case llvm::Value::ConstantTokenNoneVal:
315 It->second = std::unique_ptr<ConstantTokenNone>(
316 new ConstantTokenNone(cast<llvm::ConstantTokenNone>(LLVMC), *this));
317 return It->second.get();
318 case llvm::Value::ConstantAggregateZeroVal: {
319 auto *CAZ = cast<llvm::ConstantAggregateZero>(LLVMC);
320 It->second = std::unique_ptr<ConstantAggregateZero>(
321 new ConstantAggregateZero(CAZ, *this));
322 auto *Ret = It->second.get();
323 // Must create sandboxir for elements.
324 auto EC = CAZ->getElementCount();
325 if (EC.isFixed()) {
326 for (auto ElmIdx : seq<unsigned>(0, EC.getFixedValue()))
327 getOrCreateValueInternal(CAZ->getElementValue(ElmIdx), CAZ);
328 }
329 return Ret;
330 }
331 case llvm::Value::ConstantPointerNullVal:
332 It->second = std::unique_ptr<ConstantPointerNull>(new ConstantPointerNull(
333 cast<llvm::ConstantPointerNull>(LLVMC), *this));
334 return It->second.get();
335 case llvm::Value::PoisonValueVal:
336 It->second = std::unique_ptr<PoisonValue>(
337 new PoisonValue(cast<llvm::PoisonValue>(LLVMC), *this));
338 return It->second.get();
339 case llvm::Value::UndefValueVal:
340 It->second = std::unique_ptr<UndefValue>(
341 new UndefValue(cast<llvm::UndefValue>(LLVMC), *this));
342 return It->second.get();
343 case llvm::Value::DSOLocalEquivalentVal: {
344 auto *DSOLE = cast<llvm::DSOLocalEquivalent>(LLVMC);
345 It->second = std::unique_ptr<DSOLocalEquivalent>(
346 new DSOLocalEquivalent(DSOLE, *this));
347 auto *Ret = It->second.get();
348 getOrCreateValueInternal(DSOLE->getGlobalValue(), DSOLE);
349 return Ret;
350 }
351 case llvm::Value::ConstantArrayVal:
352 It->second = std::unique_ptr<ConstantArray>(
353 new ConstantArray(cast<llvm::ConstantArray>(LLVMC), *this));
354 break;
355 case llvm::Value::ConstantStructVal:
356 It->second = std::unique_ptr<ConstantStruct>(
357 new ConstantStruct(cast<llvm::ConstantStruct>(LLVMC), *this));
358 break;
359 case llvm::Value::ConstantVectorVal:
360 It->second = std::unique_ptr<ConstantVector>(
361 new ConstantVector(cast<llvm::ConstantVector>(LLVMC), *this));
362 break;
363 case llvm::Value::ConstantDataArrayVal:
364 It->second = std::unique_ptr<ConstantDataArray>(
365 new ConstantDataArray(cast<llvm::ConstantDataArray>(LLVMC), *this));
366 break;
367 case llvm::Value::ConstantDataVectorVal:
368 It->second = std::unique_ptr<ConstantDataVector>(
369 new ConstantDataVector(cast<llvm::ConstantDataVector>(LLVMC), *this));
370 break;
371 case llvm::Value::FunctionVal:
372 It->second = std::unique_ptr<Function>(
373 new Function(cast<llvm::Function>(LLVMC), *this));
374 break;
375 case llvm::Value::GlobalIFuncVal:
376 It->second = std::unique_ptr<GlobalIFunc>(
377 new GlobalIFunc(cast<llvm::GlobalIFunc>(LLVMC), *this));
378 break;
379 case llvm::Value::GlobalVariableVal:
380 It->second = std::unique_ptr<GlobalVariable>(
381 new GlobalVariable(cast<llvm::GlobalVariable>(LLVMC), *this));
382 break;
383 case llvm::Value::GlobalAliasVal:
384 It->second = std::unique_ptr<GlobalAlias>(
385 new GlobalAlias(cast<llvm::GlobalAlias>(LLVMC), *this));
386 break;
387 case llvm::Value::NoCFIValueVal:
388 It->second = std::unique_ptr<NoCFIValue>(
389 new NoCFIValue(cast<llvm::NoCFIValue>(LLVMC), *this));
390 break;
391 case llvm::Value::ConstantPtrAuthVal:
392 It->second = std::unique_ptr<ConstantPtrAuth>(
393 new ConstantPtrAuth(cast<llvm::ConstantPtrAuth>(LLVMC), *this));
394 break;
395 case llvm::Value::ConstantExprVal:
396 It->second = std::unique_ptr<ConstantExpr>(
397 new ConstantExpr(cast<llvm::ConstantExpr>(LLVMC), *this));
398 break;
399 default:
400 It->second = std::unique_ptr<Constant>(new Constant(LLVMC, *this));
401 break;
402 }
403 auto *NewC = It->second.get();
404 for (llvm::Value *COp : LLVMC->operands())
405 getOrCreateValueInternal(COp, LLVMC);
406 return NewC;
407 }
408 // Argument
409 if (auto *LLVMArg = dyn_cast<llvm::Argument>(LLVMV)) {
410 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
411 return It->second.get();
412 }
413 // BasicBlock
414 if (auto *LLVMBB = dyn_cast<llvm::BasicBlock>(LLVMV)) {
415 assert(isa<llvm::BlockAddress>(U) &&
416 "This won't create a SBBB, don't call this function directly!");
417 if (auto *SBBB = getValue(LLVMBB))
418 return SBBB;
419 return nullptr;
420 }
421 // Metadata
422 if (auto *LLVMMD = dyn_cast<llvm::MetadataAsValue>(LLVMV)) {
423 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMMD, *this));
424 return It->second.get();
425 }
426 // InlineAsm
427 if (auto *LLVMAsm = dyn_cast<llvm::InlineAsm>(LLVMV)) {
428 It->second = std::unique_ptr<OpaqueValue>(new OpaqueValue(LLVMAsm, *this));
429 return It->second.get();
430 }
431 llvm_unreachable("Unhandled LLVMV type!");
432 }
433
getOrCreateArgument(llvm::Argument * LLVMArg)434 Argument *Context::getOrCreateArgument(llvm::Argument *LLVMArg) {
435 auto Pair = LLVMValueToValueMap.try_emplace(LLVMArg);
436 auto It = Pair.first;
437 if (Pair.second) {
438 It->second = std::unique_ptr<Argument>(new Argument(LLVMArg, *this));
439 return cast<Argument>(It->second.get());
440 }
441 return cast<Argument>(It->second.get());
442 }
443
getOrCreateConstant(llvm::Constant * LLVMC)444 Constant *Context::getOrCreateConstant(llvm::Constant *LLVMC) {
445 return cast<Constant>(getOrCreateValueInternal(LLVMC, 0));
446 }
447
createBasicBlock(llvm::BasicBlock * LLVMBB)448 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
449 assert(getValue(LLVMBB) == nullptr && "Already exists!");
450 auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
451 auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr)));
452 // Create SandboxIR for BB's body.
453 BB->buildBasicBlockFromLLVMIR(LLVMBB);
454 return BB;
455 }
456
createVAArgInst(llvm::VAArgInst * SI)457 VAArgInst *Context::createVAArgInst(llvm::VAArgInst *SI) {
458 auto NewPtr = std::unique_ptr<VAArgInst>(new VAArgInst(SI, *this));
459 return cast<VAArgInst>(registerValue(std::move(NewPtr)));
460 }
461
createFreezeInst(llvm::FreezeInst * SI)462 FreezeInst *Context::createFreezeInst(llvm::FreezeInst *SI) {
463 auto NewPtr = std::unique_ptr<FreezeInst>(new FreezeInst(SI, *this));
464 return cast<FreezeInst>(registerValue(std::move(NewPtr)));
465 }
466
createFenceInst(llvm::FenceInst * SI)467 FenceInst *Context::createFenceInst(llvm::FenceInst *SI) {
468 auto NewPtr = std::unique_ptr<FenceInst>(new FenceInst(SI, *this));
469 return cast<FenceInst>(registerValue(std::move(NewPtr)));
470 }
471
createSelectInst(llvm::SelectInst * SI)472 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
473 auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
474 return cast<SelectInst>(registerValue(std::move(NewPtr)));
475 }
476
477 ExtractElementInst *
createExtractElementInst(llvm::ExtractElementInst * EEI)478 Context::createExtractElementInst(llvm::ExtractElementInst *EEI) {
479 auto NewPtr =
480 std::unique_ptr<ExtractElementInst>(new ExtractElementInst(EEI, *this));
481 return cast<ExtractElementInst>(registerValue(std::move(NewPtr)));
482 }
483
484 InsertElementInst *
createInsertElementInst(llvm::InsertElementInst * IEI)485 Context::createInsertElementInst(llvm::InsertElementInst *IEI) {
486 auto NewPtr =
487 std::unique_ptr<InsertElementInst>(new InsertElementInst(IEI, *this));
488 return cast<InsertElementInst>(registerValue(std::move(NewPtr)));
489 }
490
491 ShuffleVectorInst *
createShuffleVectorInst(llvm::ShuffleVectorInst * SVI)492 Context::createShuffleVectorInst(llvm::ShuffleVectorInst *SVI) {
493 auto NewPtr =
494 std::unique_ptr<ShuffleVectorInst>(new ShuffleVectorInst(SVI, *this));
495 return cast<ShuffleVectorInst>(registerValue(std::move(NewPtr)));
496 }
497
createExtractValueInst(llvm::ExtractValueInst * EVI)498 ExtractValueInst *Context::createExtractValueInst(llvm::ExtractValueInst *EVI) {
499 auto NewPtr =
500 std::unique_ptr<ExtractValueInst>(new ExtractValueInst(EVI, *this));
501 return cast<ExtractValueInst>(registerValue(std::move(NewPtr)));
502 }
503
createInsertValueInst(llvm::InsertValueInst * IVI)504 InsertValueInst *Context::createInsertValueInst(llvm::InsertValueInst *IVI) {
505 auto NewPtr =
506 std::unique_ptr<InsertValueInst>(new InsertValueInst(IVI, *this));
507 return cast<InsertValueInst>(registerValue(std::move(NewPtr)));
508 }
509
createBranchInst(llvm::BranchInst * BI)510 BranchInst *Context::createBranchInst(llvm::BranchInst *BI) {
511 auto NewPtr = std::unique_ptr<BranchInst>(new BranchInst(BI, *this));
512 return cast<BranchInst>(registerValue(std::move(NewPtr)));
513 }
514
createLoadInst(llvm::LoadInst * LI)515 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
516 auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
517 return cast<LoadInst>(registerValue(std::move(NewPtr)));
518 }
519
createStoreInst(llvm::StoreInst * SI)520 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
521 auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
522 return cast<StoreInst>(registerValue(std::move(NewPtr)));
523 }
524
createReturnInst(llvm::ReturnInst * I)525 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
526 auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
527 return cast<ReturnInst>(registerValue(std::move(NewPtr)));
528 }
529
createCallInst(llvm::CallInst * I)530 CallInst *Context::createCallInst(llvm::CallInst *I) {
531 auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
532 return cast<CallInst>(registerValue(std::move(NewPtr)));
533 }
534
createInvokeInst(llvm::InvokeInst * I)535 InvokeInst *Context::createInvokeInst(llvm::InvokeInst *I) {
536 auto NewPtr = std::unique_ptr<InvokeInst>(new InvokeInst(I, *this));
537 return cast<InvokeInst>(registerValue(std::move(NewPtr)));
538 }
539
createCallBrInst(llvm::CallBrInst * I)540 CallBrInst *Context::createCallBrInst(llvm::CallBrInst *I) {
541 auto NewPtr = std::unique_ptr<CallBrInst>(new CallBrInst(I, *this));
542 return cast<CallBrInst>(registerValue(std::move(NewPtr)));
543 }
544
createUnreachableInst(llvm::UnreachableInst * UI)545 UnreachableInst *Context::createUnreachableInst(llvm::UnreachableInst *UI) {
546 auto NewPtr =
547 std::unique_ptr<UnreachableInst>(new UnreachableInst(UI, *this));
548 return cast<UnreachableInst>(registerValue(std::move(NewPtr)));
549 }
createLandingPadInst(llvm::LandingPadInst * I)550 LandingPadInst *Context::createLandingPadInst(llvm::LandingPadInst *I) {
551 auto NewPtr = std::unique_ptr<LandingPadInst>(new LandingPadInst(I, *this));
552 return cast<LandingPadInst>(registerValue(std::move(NewPtr)));
553 }
createCatchPadInst(llvm::CatchPadInst * I)554 CatchPadInst *Context::createCatchPadInst(llvm::CatchPadInst *I) {
555 auto NewPtr = std::unique_ptr<CatchPadInst>(new CatchPadInst(I, *this));
556 return cast<CatchPadInst>(registerValue(std::move(NewPtr)));
557 }
createCleanupPadInst(llvm::CleanupPadInst * I)558 CleanupPadInst *Context::createCleanupPadInst(llvm::CleanupPadInst *I) {
559 auto NewPtr = std::unique_ptr<CleanupPadInst>(new CleanupPadInst(I, *this));
560 return cast<CleanupPadInst>(registerValue(std::move(NewPtr)));
561 }
createCatchReturnInst(llvm::CatchReturnInst * I)562 CatchReturnInst *Context::createCatchReturnInst(llvm::CatchReturnInst *I) {
563 auto NewPtr = std::unique_ptr<CatchReturnInst>(new CatchReturnInst(I, *this));
564 return cast<CatchReturnInst>(registerValue(std::move(NewPtr)));
565 }
566 CleanupReturnInst *
createCleanupReturnInst(llvm::CleanupReturnInst * I)567 Context::createCleanupReturnInst(llvm::CleanupReturnInst *I) {
568 auto NewPtr =
569 std::unique_ptr<CleanupReturnInst>(new CleanupReturnInst(I, *this));
570 return cast<CleanupReturnInst>(registerValue(std::move(NewPtr)));
571 }
572 GetElementPtrInst *
createGetElementPtrInst(llvm::GetElementPtrInst * I)573 Context::createGetElementPtrInst(llvm::GetElementPtrInst *I) {
574 auto NewPtr =
575 std::unique_ptr<GetElementPtrInst>(new GetElementPtrInst(I, *this));
576 return cast<GetElementPtrInst>(registerValue(std::move(NewPtr)));
577 }
createCatchSwitchInst(llvm::CatchSwitchInst * I)578 CatchSwitchInst *Context::createCatchSwitchInst(llvm::CatchSwitchInst *I) {
579 auto NewPtr = std::unique_ptr<CatchSwitchInst>(new CatchSwitchInst(I, *this));
580 return cast<CatchSwitchInst>(registerValue(std::move(NewPtr)));
581 }
createResumeInst(llvm::ResumeInst * I)582 ResumeInst *Context::createResumeInst(llvm::ResumeInst *I) {
583 auto NewPtr = std::unique_ptr<ResumeInst>(new ResumeInst(I, *this));
584 return cast<ResumeInst>(registerValue(std::move(NewPtr)));
585 }
createSwitchInst(llvm::SwitchInst * I)586 SwitchInst *Context::createSwitchInst(llvm::SwitchInst *I) {
587 auto NewPtr = std::unique_ptr<SwitchInst>(new SwitchInst(I, *this));
588 return cast<SwitchInst>(registerValue(std::move(NewPtr)));
589 }
createUnaryOperator(llvm::UnaryOperator * I)590 UnaryOperator *Context::createUnaryOperator(llvm::UnaryOperator *I) {
591 auto NewPtr = std::unique_ptr<UnaryOperator>(new UnaryOperator(I, *this));
592 return cast<UnaryOperator>(registerValue(std::move(NewPtr)));
593 }
createBinaryOperator(llvm::BinaryOperator * I)594 BinaryOperator *Context::createBinaryOperator(llvm::BinaryOperator *I) {
595 auto NewPtr = std::unique_ptr<BinaryOperator>(new BinaryOperator(I, *this));
596 return cast<BinaryOperator>(registerValue(std::move(NewPtr)));
597 }
createAtomicRMWInst(llvm::AtomicRMWInst * I)598 AtomicRMWInst *Context::createAtomicRMWInst(llvm::AtomicRMWInst *I) {
599 auto NewPtr = std::unique_ptr<AtomicRMWInst>(new AtomicRMWInst(I, *this));
600 return cast<AtomicRMWInst>(registerValue(std::move(NewPtr)));
601 }
602 AtomicCmpXchgInst *
createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst * I)603 Context::createAtomicCmpXchgInst(llvm::AtomicCmpXchgInst *I) {
604 auto NewPtr =
605 std::unique_ptr<AtomicCmpXchgInst>(new AtomicCmpXchgInst(I, *this));
606 return cast<AtomicCmpXchgInst>(registerValue(std::move(NewPtr)));
607 }
createAllocaInst(llvm::AllocaInst * I)608 AllocaInst *Context::createAllocaInst(llvm::AllocaInst *I) {
609 auto NewPtr = std::unique_ptr<AllocaInst>(new AllocaInst(I, *this));
610 return cast<AllocaInst>(registerValue(std::move(NewPtr)));
611 }
createCastInst(llvm::CastInst * I)612 CastInst *Context::createCastInst(llvm::CastInst *I) {
613 auto NewPtr = std::unique_ptr<CastInst>(new CastInst(I, *this));
614 return cast<CastInst>(registerValue(std::move(NewPtr)));
615 }
createPHINode(llvm::PHINode * I)616 PHINode *Context::createPHINode(llvm::PHINode *I) {
617 auto NewPtr = std::unique_ptr<PHINode>(new PHINode(I, *this));
618 return cast<PHINode>(registerValue(std::move(NewPtr)));
619 }
createICmpInst(llvm::ICmpInst * I)620 ICmpInst *Context::createICmpInst(llvm::ICmpInst *I) {
621 auto NewPtr = std::unique_ptr<ICmpInst>(new ICmpInst(I, *this));
622 return cast<ICmpInst>(registerValue(std::move(NewPtr)));
623 }
createFCmpInst(llvm::FCmpInst * I)624 FCmpInst *Context::createFCmpInst(llvm::FCmpInst *I) {
625 auto NewPtr = std::unique_ptr<FCmpInst>(new FCmpInst(I, *this));
626 return cast<FCmpInst>(registerValue(std::move(NewPtr)));
627 }
getValue(llvm::Value * V) const628 Value *Context::getValue(llvm::Value *V) const {
629 auto It = LLVMValueToValueMap.find(V);
630 if (It != LLVMValueToValueMap.end())
631 return It->second.get();
632 return nullptr;
633 }
634
Context(LLVMContext & LLVMCtx)635 Context::Context(LLVMContext &LLVMCtx)
636 : LLVMCtx(LLVMCtx), IRTracker(*this),
637 LLVMIRBuilder(LLVMCtx, ConstantFolder()) {}
638
~Context()639 Context::~Context() {}
640
clear()641 void Context::clear() {
642 // TODO: Ideally we should clear only function-scope objects, and keep global
643 // objects, like Constants to avoid recreating them.
644 LLVMValueToValueMap.clear();
645 }
646
getModule(llvm::Module * LLVMM) const647 Module *Context::getModule(llvm::Module *LLVMM) const {
648 auto It = LLVMModuleToModuleMap.find(LLVMM);
649 if (It != LLVMModuleToModuleMap.end())
650 return It->second.get();
651 return nullptr;
652 }
653
getOrCreateModule(llvm::Module * LLVMM)654 Module *Context::getOrCreateModule(llvm::Module *LLVMM) {
655 auto Pair = LLVMModuleToModuleMap.try_emplace(LLVMM);
656 auto It = Pair.first;
657 if (!Pair.second)
658 return It->second.get();
659 It->second = std::unique_ptr<Module>(new Module(*LLVMM, *this));
660 return It->second.get();
661 }
662
createFunction(llvm::Function * F)663 Function *Context::createFunction(llvm::Function *F) {
664 // Create the module if needed before we create the new sandboxir::Function.
665 // Note: this won't fully populate the module. The only globals that will be
666 // available will be the ones being used within the function.
667 getOrCreateModule(F->getParent());
668
669 // There may be a function declaration already defined. Regardless destroy it.
670 if (Function *ExistingF = cast_or_null<Function>(getValue(F)))
671 detach(ExistingF);
672
673 auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
674 auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
675 // Create arguments.
676 for (auto &Arg : F->args())
677 getOrCreateArgument(&Arg);
678 // Create BBs.
679 for (auto &BB : *F)
680 createBasicBlock(&BB);
681 return SBF;
682 }
683
createModule(llvm::Module * LLVMM)684 Module *Context::createModule(llvm::Module *LLVMM) {
685 auto *M = getOrCreateModule(LLVMM);
686 // Create the functions.
687 for (auto &LLVMF : *LLVMM)
688 createFunction(&LLVMF);
689 // Create globals.
690 for (auto &Global : LLVMM->globals())
691 getOrCreateValue(&Global);
692 // Create aliases.
693 for (auto &Alias : LLVMM->aliases())
694 getOrCreateValue(&Alias);
695 // Create ifuncs.
696 for (auto &IFunc : LLVMM->ifuncs())
697 getOrCreateValue(&IFunc);
698
699 return M;
700 }
701
runEraseInstrCallbacks(Instruction * I)702 void Context::runEraseInstrCallbacks(Instruction *I) {
703 for (const auto &CBEntry : EraseInstrCallbacks)
704 CBEntry.second(I);
705 }
706
runCreateInstrCallbacks(Instruction * I)707 void Context::runCreateInstrCallbacks(Instruction *I) {
708 for (auto &CBEntry : CreateInstrCallbacks)
709 CBEntry.second(I);
710 }
711
runMoveInstrCallbacks(Instruction * I,const BBIterator & WhereIt)712 void Context::runMoveInstrCallbacks(Instruction *I, const BBIterator &WhereIt) {
713 for (auto &CBEntry : MoveInstrCallbacks)
714 CBEntry.second(I, WhereIt);
715 }
716
runSetUseCallbacks(const Use & U,Value * NewSrc)717 void Context::runSetUseCallbacks(const Use &U, Value *NewSrc) {
718 for (auto &CBEntry : SetUseCallbacks)
719 CBEntry.second(U, NewSrc);
720 }
721
// Arbitrary cap on the number of callbacks of each kind, checked only in
// assertions to catch accidental misuse. Only a small number of callbacks is
// expected to be registered at a time; raise the cap if more turn out to be
// needed.
[[maybe_unused]] static constexpr int MaxRegisteredCallbacks{16};
726
registerEraseInstrCallback(EraseInstrCallback CB)727 Context::CallbackID Context::registerEraseInstrCallback(EraseInstrCallback CB) {
728 assert(EraseInstrCallbacks.size() <= MaxRegisteredCallbacks &&
729 "EraseInstrCallbacks size limit exceeded");
730 CallbackID ID{NextCallbackID++};
731 EraseInstrCallbacks[ID] = CB;
732 return ID;
733 }
unregisterEraseInstrCallback(CallbackID ID)734 void Context::unregisterEraseInstrCallback(CallbackID ID) {
735 [[maybe_unused]] bool Erased = EraseInstrCallbacks.erase(ID);
736 assert(Erased &&
737 "Callback ID not found in EraseInstrCallbacks during deregistration");
738 }
739
740 Context::CallbackID
registerCreateInstrCallback(CreateInstrCallback CB)741 Context::registerCreateInstrCallback(CreateInstrCallback CB) {
742 assert(CreateInstrCallbacks.size() <= MaxRegisteredCallbacks &&
743 "CreateInstrCallbacks size limit exceeded");
744 CallbackID ID{NextCallbackID++};
745 CreateInstrCallbacks[ID] = CB;
746 return ID;
747 }
unregisterCreateInstrCallback(CallbackID ID)748 void Context::unregisterCreateInstrCallback(CallbackID ID) {
749 [[maybe_unused]] bool Erased = CreateInstrCallbacks.erase(ID);
750 assert(Erased &&
751 "Callback ID not found in CreateInstrCallbacks during deregistration");
752 }
753
registerMoveInstrCallback(MoveInstrCallback CB)754 Context::CallbackID Context::registerMoveInstrCallback(MoveInstrCallback CB) {
755 assert(MoveInstrCallbacks.size() <= MaxRegisteredCallbacks &&
756 "MoveInstrCallbacks size limit exceeded");
757 CallbackID ID{NextCallbackID++};
758 MoveInstrCallbacks[ID] = CB;
759 return ID;
760 }
unregisterMoveInstrCallback(CallbackID ID)761 void Context::unregisterMoveInstrCallback(CallbackID ID) {
762 [[maybe_unused]] bool Erased = MoveInstrCallbacks.erase(ID);
763 assert(Erased &&
764 "Callback ID not found in MoveInstrCallbacks during deregistration");
765 }
766
registerSetUseCallback(SetUseCallback CB)767 Context::CallbackID Context::registerSetUseCallback(SetUseCallback CB) {
768 assert(SetUseCallbacks.size() <= MaxRegisteredCallbacks &&
769 "SetUseCallbacks size limit exceeded");
770 CallbackID ID{NextCallbackID++};
771 SetUseCallbacks[ID] = CB;
772 return ID;
773 }
unregisterSetUseCallback(CallbackID ID)774 void Context::unregisterSetUseCallback(CallbackID ID) {
775 [[maybe_unused]] bool Erased = SetUseCallbacks.erase(ID);
776 assert(Erased &&
777 "Callback ID not found in SetUseCallbacks during deregistration");
778 }
779
780 } // namespace llvm::sandboxir
781