xref: /freebsd/contrib/llvm-project/llvm/lib/SandboxIR/SandboxIR.cpp (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- SandboxIR.cpp - A transactional overlay IR on top of LLVM IR -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/SandboxIR/SandboxIR.h"
10 #include "llvm/ADT/SmallPtrSet.h"
11 #include "llvm/IR/Constants.h"
12 #include "llvm/Support/Debug.h"
13 #include <sstream>
14 
15 using namespace llvm::sandboxir;
16 
get() const17 Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); }
18 
set(Value * V)19 void Use::set(Value *V) { LLVMUse->set(V->Val); }
20 
getOperandNo() const21 unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }
22 
23 #ifndef NDEBUG
dump(raw_ostream & OS) const24 void Use::dump(raw_ostream &OS) const {
25   Value *Def = nullptr;
26   if (LLVMUse == nullptr)
27     OS << "<null> LLVM Use! ";
28   else
29     Def = Ctx->getValue(LLVMUse->get());
30   OS << "Def:  ";
31   if (Def == nullptr)
32     OS << "NULL";
33   else
34     OS << *Def;
35   OS << "\n";
36 
37   OS << "User: ";
38   if (Usr == nullptr)
39     OS << "NULL";
40   else
41     OS << *Usr;
42   OS << "\n";
43 
44   OS << "OperandNo: ";
45   if (Usr == nullptr)
46     OS << "N/A";
47   else
48     OS << getOperandNo();
49   OS << "\n";
50 }
51 
dump() const52 void Use::dump() const { dump(dbgs()); }
53 #endif // NDEBUG
54 
operator *() const55 Use OperandUseIterator::operator*() const { return Use; }
56 
operator ++()57 OperandUseIterator &OperandUseIterator::operator++() {
58   assert(Use.LLVMUse != nullptr && "Already at end!");
59   User *User = Use.getUser();
60   Use = User->getOperandUseInternal(Use.getOperandNo() + 1, /*Verify=*/false);
61   return *this;
62 }
63 
operator ++()64 UserUseIterator &UserUseIterator::operator++() {
65   // Get the corresponding llvm::Use, get the next in the list, and update the
66   // sandboxir::Use.
67   llvm::Use *&LLVMUse = Use.LLVMUse;
68   assert(LLVMUse != nullptr && "Already at end!");
69   LLVMUse = LLVMUse->getNext();
70   if (LLVMUse == nullptr) {
71     Use.Usr = nullptr;
72     return *this;
73   }
74   auto *Ctx = Use.Ctx;
75   auto *LLVMUser = LLVMUse->getUser();
76   Use.Usr = cast_or_null<sandboxir::User>(Ctx->getValue(LLVMUser));
77   return *this;
78 }
79 
Value(ClassID SubclassID,llvm::Value * Val,Context & Ctx)80 Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
81     : SubclassID(SubclassID), Val(Val), Ctx(Ctx) {
82 #ifndef NDEBUG
83   UID = Ctx.getNumValues();
84 #endif
85 }
86 
use_begin()87 Value::use_iterator Value::use_begin() {
88   llvm::Use *LLVMUse = nullptr;
89   if (Val->use_begin() != Val->use_end())
90     LLVMUse = &*Val->use_begin();
91   User *User = LLVMUse != nullptr ? cast_or_null<sandboxir::User>(Ctx.getValue(
92                                         Val->use_begin()->getUser()))
93                                   : nullptr;
94   return use_iterator(Use(LLVMUse, User, Ctx));
95 }
96 
user_begin()97 Value::user_iterator Value::user_begin() {
98   auto UseBegin = Val->use_begin();
99   auto UseEnd = Val->use_end();
100   bool AtEnd = UseBegin == UseEnd;
101   llvm::Use *LLVMUse = AtEnd ? nullptr : &*UseBegin;
102   User *User =
103       AtEnd ? nullptr
104             : cast_or_null<sandboxir::User>(Ctx.getValue(&*LLVMUse->getUser()));
105   return user_iterator(Use(LLVMUse, User, Ctx), UseToUser());
106 }
107 
getNumUses() const108 unsigned Value::getNumUses() const { return range_size(Val->users()); }
109 
replaceUsesWithIf(Value * OtherV,llvm::function_ref<bool (const Use &)> ShouldReplace)110 void Value::replaceUsesWithIf(
111     Value *OtherV, llvm::function_ref<bool(const Use &)> ShouldReplace) {
112   assert(getType() == OtherV->getType() && "Can't replace with different type");
113   llvm::Value *OtherVal = OtherV->Val;
114   // We are delegating RUWIf to LLVM IR's RUWIf.
115   Val->replaceUsesWithIf(
116       OtherVal, [&ShouldReplace, this](llvm::Use &LLVMUse) -> bool {
117         User *DstU = cast_or_null<User>(Ctx.getValue(LLVMUse.getUser()));
118         if (DstU == nullptr)
119           return false;
120         Use UseToReplace(&LLVMUse, DstU, Ctx);
121         if (!ShouldReplace(UseToReplace))
122           return false;
123         auto &Tracker = Ctx.getTracker();
124         if (Tracker.isTracking())
125           Tracker.track(std::make_unique<UseSet>(UseToReplace, Tracker));
126         return true;
127       });
128 }
129 
replaceAllUsesWith(Value * Other)130 void Value::replaceAllUsesWith(Value *Other) {
131   assert(getType() == Other->getType() &&
132          "Replacing with Value of different type!");
133   auto &Tracker = Ctx.getTracker();
134   if (Tracker.isTracking()) {
135     for (auto Use : uses())
136       Tracker.track(std::make_unique<UseSet>(Use, Tracker));
137   }
138   // We are delegating RAUW to LLVM IR's RAUW.
139   Val->replaceAllUsesWith(Other->Val);
140 }
141 
142 #ifndef NDEBUG
getUid() const143 std::string Value::getUid() const {
144   std::stringstream SS;
145   SS << "SB" << UID << ".";
146   return SS.str();
147 }
148 
dumpCommonHeader(raw_ostream & OS) const149 void Value::dumpCommonHeader(raw_ostream &OS) const {
150   OS << getUid() << " " << getSubclassIDStr(SubclassID) << " ";
151 }
152 
dumpCommonFooter(raw_ostream & OS) const153 void Value::dumpCommonFooter(raw_ostream &OS) const {
154   OS.indent(2) << "Val: ";
155   if (Val)
156     OS << *Val;
157   else
158     OS << "NULL";
159   OS << "\n";
160 }
161 
dumpCommonPrefix(raw_ostream & OS) const162 void Value::dumpCommonPrefix(raw_ostream &OS) const {
163   if (Val)
164     OS << *Val;
165   else
166     OS << "NULL ";
167 }
168 
dumpCommonSuffix(raw_ostream & OS) const169 void Value::dumpCommonSuffix(raw_ostream &OS) const {
170   OS << " ; " << getUid() << " (" << getSubclassIDStr(SubclassID) << ")";
171 }
172 
printAsOperandCommon(raw_ostream & OS) const173 void Value::printAsOperandCommon(raw_ostream &OS) const {
174   if (Val)
175     Val->printAsOperand(OS);
176   else
177     OS << "NULL ";
178 }
179 
printAsOperand(raw_ostream & OS) const180 void Argument::printAsOperand(raw_ostream &OS) const {
181   printAsOperandCommon(OS);
182 }
dump(raw_ostream & OS) const183 void Argument::dump(raw_ostream &OS) const {
184   dumpCommonPrefix(OS);
185   dumpCommonSuffix(OS);
186 }
dump() const187 void Argument::dump() const {
188   dump(dbgs());
189   dbgs() << "\n";
190 }
191 #endif // NDEBUG
192 
getOperandUseDefault(unsigned OpIdx,bool Verify) const193 Use User::getOperandUseDefault(unsigned OpIdx, bool Verify) const {
194   assert((!Verify || OpIdx < getNumOperands()) && "Out of bounds!");
195   assert(isa<llvm::User>(Val) && "Non-users have no operands!");
196   llvm::Use *LLVMUse;
197   if (OpIdx != getNumOperands())
198     LLVMUse = &cast<llvm::User>(Val)->getOperandUse(OpIdx);
199   else
200     LLVMUse = cast<llvm::User>(Val)->op_end();
201   return Use(LLVMUse, const_cast<User *>(this), Ctx);
202 }
203 
204 #ifndef NDEBUG
verifyUserOfLLVMUse(const llvm::Use & Use) const205 void User::verifyUserOfLLVMUse(const llvm::Use &Use) const {
206   assert(Ctx.getValue(Use.getUser()) == this &&
207          "Use not found in this SBUser's operands!");
208 }
209 #endif
210 
classof(const Value * From)211 bool User::classof(const Value *From) {
212   switch (From->getSubclassID()) {
213 #define DEF_VALUE(ID, CLASS)
214 #define DEF_USER(ID, CLASS)                                                    \
215   case ClassID::ID:                                                            \
216     return true;
217 #define DEF_INSTR(ID, OPC, CLASS)                                              \
218   case ClassID::ID:                                                            \
219     return true;
220 #include "llvm/SandboxIR/SandboxIRValues.def"
221   default:
222     return false;
223   }
224 }
225 
setOperand(unsigned OperandIdx,Value * Operand)226 void User::setOperand(unsigned OperandIdx, Value *Operand) {
227   assert(isa<llvm::User>(Val) && "No operands!");
228   auto &Tracker = Ctx.getTracker();
229   if (Tracker.isTracking())
230     Tracker.track(std::make_unique<UseSet>(getOperandUse(OperandIdx), Tracker));
231   // We are delegating to llvm::User::setOperand().
232   cast<llvm::User>(Val)->setOperand(OperandIdx, Operand->Val);
233 }
234 
replaceUsesOfWith(Value * FromV,Value * ToV)235 bool User::replaceUsesOfWith(Value *FromV, Value *ToV) {
236   auto &Tracker = Ctx.getTracker();
237   if (Tracker.isTracking()) {
238     for (auto OpIdx : seq<unsigned>(0, getNumOperands())) {
239       auto Use = getOperandUse(OpIdx);
240       if (Use.get() == FromV)
241         Tracker.track(std::make_unique<UseSet>(Use, Tracker));
242     }
243   }
244   // We are delegating RUOW to LLVM IR's RUOW.
245   return cast<llvm::User>(Val)->replaceUsesOfWith(FromV->Val, ToV->Val);
246 }
247 
248 #ifndef NDEBUG
dumpCommonHeader(raw_ostream & OS) const249 void User::dumpCommonHeader(raw_ostream &OS) const {
250   Value::dumpCommonHeader(OS);
251   // TODO: This is incomplete
252 }
253 #endif // NDEBUG
254 
operator ++()255 BBIterator &BBIterator::operator++() {
256   auto ItE = BB->end();
257   assert(It != ItE && "Already at end!");
258   ++It;
259   if (It == ItE)
260     return *this;
261   Instruction &NextI = *cast<sandboxir::Instruction>(Ctx->getValue(&*It));
262   unsigned Num = NextI.getNumOfIRInstrs();
263   assert(Num > 0 && "Bad getNumOfIRInstrs()");
264   It = std::next(It, Num - 1);
265   return *this;
266 }
267 
operator --()268 BBIterator &BBIterator::operator--() {
269   assert(It != BB->begin() && "Already at begin!");
270   if (It == BB->end()) {
271     --It;
272     return *this;
273   }
274   Instruction &CurrI = **this;
275   unsigned Num = CurrI.getNumOfIRInstrs();
276   assert(Num > 0 && "Bad getNumOfIRInstrs()");
277   assert(std::prev(It, Num - 1) != BB->begin() && "Already at begin!");
278   It = std::prev(It, Num);
279   return *this;
280 }
281 
getOpcodeName(Opcode Opc)282 const char *Instruction::getOpcodeName(Opcode Opc) {
283   switch (Opc) {
284 #define DEF_VALUE(ID, CLASS)
285 #define DEF_USER(ID, CLASS)
286 #define OP(OPC)                                                                \
287   case Opcode::OPC:                                                            \
288     return #OPC;
289 #define DEF_INSTR(ID, OPC, CLASS) OPC
290 #include "llvm/SandboxIR/SandboxIRValues.def"
291   }
292   llvm_unreachable("Unknown Opcode");
293 }
294 
getTopmostLLVMInstruction() const295 llvm::Instruction *Instruction::getTopmostLLVMInstruction() const {
296   Instruction *Prev = getPrevNode();
297   if (Prev == nullptr) {
298     // If at top of the BB, return the first BB instruction.
299     return &*cast<llvm::BasicBlock>(getParent()->Val)->begin();
300   }
301   // Else get the Previous sandbox IR instruction's bottom IR instruction and
302   // return its successor.
303   llvm::Instruction *PrevBotI = cast<llvm::Instruction>(Prev->Val);
304   return PrevBotI->getNextNode();
305 }
306 
getIterator() const307 BBIterator Instruction::getIterator() const {
308   auto *I = cast<llvm::Instruction>(Val);
309   return BasicBlock::iterator(I->getParent(), I->getIterator(), &Ctx);
310 }
311 
getNextNode() const312 Instruction *Instruction::getNextNode() const {
313   assert(getParent() != nullptr && "Detached!");
314   assert(getIterator() != getParent()->end() && "Already at end!");
315   // `Val` is the bottom-most LLVM IR instruction. Get the next in the chain,
316   // and get the corresponding sandboxir Instruction that maps to it. This works
317   // even for SandboxIR Instructions that map to more than one LLVM Instruction.
318   auto *LLVMI = cast<llvm::Instruction>(Val);
319   assert(LLVMI->getParent() != nullptr && "LLVM IR instr is detached!");
320   auto *NextLLVMI = LLVMI->getNextNode();
321   auto *NextI = cast_or_null<Instruction>(Ctx.getValue(NextLLVMI));
322   if (NextI == nullptr)
323     return nullptr;
324   return NextI;
325 }
326 
getPrevNode() const327 Instruction *Instruction::getPrevNode() const {
328   assert(getParent() != nullptr && "Detached!");
329   auto It = getIterator();
330   if (It != getParent()->begin())
331     return std::prev(getIterator()).get();
332   return nullptr;
333 }
334 
removeFromParent()335 void Instruction::removeFromParent() {
336   auto &Tracker = Ctx.getTracker();
337   if (Tracker.isTracking())
338     Tracker.track(std::make_unique<RemoveFromParent>(this, Tracker));
339 
340   // Detach all the LLVM IR instructions from their parent BB.
341   for (llvm::Instruction *I : getLLVMInstrs())
342     I->removeFromParent();
343 }
344 
eraseFromParent()345 void Instruction::eraseFromParent() {
346   assert(users().empty() && "Still connected to users, can't erase!");
347   std::unique_ptr<Value> Detached = Ctx.detach(this);
348   auto LLVMInstrs = getLLVMInstrs();
349 
350   auto &Tracker = Ctx.getTracker();
351   if (Tracker.isTracking()) {
352     Tracker.track(
353         std::make_unique<EraseFromParent>(std::move(Detached), Tracker));
354     // We don't actually delete the IR instruction, because then it would be
355     // impossible to bring it back from the dead at the same memory location.
356     // Instead we remove it from its BB and track its current location.
357     for (llvm::Instruction *I : LLVMInstrs)
358       I->removeFromParent();
359     // TODO: Multi-instructions need special treatment because some of the
360     // references are internal to the instruction.
361     for (llvm::Instruction *I : LLVMInstrs)
362       I->dropAllReferences();
363   } else {
364     // Erase in reverse to avoid erasing nstructions with attached uses.
365     for (llvm::Instruction *I : reverse(LLVMInstrs))
366       I->eraseFromParent();
367   }
368 }
369 
moveBefore(BasicBlock & BB,const BBIterator & WhereIt)370 void Instruction::moveBefore(BasicBlock &BB, const BBIterator &WhereIt) {
371   if (std::next(getIterator()) == WhereIt)
372     // Destination is same as origin, nothing to do.
373     return;
374 
375   auto &Tracker = Ctx.getTracker();
376   if (Tracker.isTracking())
377     Tracker.track(std::make_unique<MoveInstr>(this, Tracker));
378 
379   auto *LLVMBB = cast<llvm::BasicBlock>(BB.Val);
380   llvm::BasicBlock::iterator It;
381   if (WhereIt == BB.end()) {
382     It = LLVMBB->end();
383   } else {
384     Instruction *WhereI = &*WhereIt;
385     It = WhereI->getTopmostLLVMInstruction()->getIterator();
386   }
387   // TODO: Move this to the verifier of sandboxir::Instruction.
388   assert(is_sorted(getLLVMInstrs(),
389                    [](auto *I1, auto *I2) { return I1->comesBefore(I2); }) &&
390          "Expected program order!");
391   // Do the actual move in LLVM IR.
392   for (auto *I : getLLVMInstrs())
393     I->moveBefore(*LLVMBB, It);
394 }
395 
insertBefore(Instruction * BeforeI)396 void Instruction::insertBefore(Instruction *BeforeI) {
397   llvm::Instruction *BeforeTopI = BeforeI->getTopmostLLVMInstruction();
398   // TODO: Move this to the verifier of sandboxir::Instruction.
399   assert(is_sorted(getLLVMInstrs(),
400                    [](auto *I1, auto *I2) { return I1->comesBefore(I2); }) &&
401          "Expected program order!");
402   // Insert the LLVM IR Instructions in program order.
403   for (llvm::Instruction *I : getLLVMInstrs())
404     I->insertBefore(BeforeTopI);
405 }
406 
insertAfter(Instruction * AfterI)407 void Instruction::insertAfter(Instruction *AfterI) {
408   insertInto(AfterI->getParent(), std::next(AfterI->getIterator()));
409 }
410 
insertInto(BasicBlock * BB,const BBIterator & WhereIt)411 void Instruction::insertInto(BasicBlock *BB, const BBIterator &WhereIt) {
412   llvm::BasicBlock *LLVMBB = cast<llvm::BasicBlock>(BB->Val);
413   llvm::Instruction *LLVMBeforeI;
414   llvm::BasicBlock::iterator LLVMBeforeIt;
415   if (WhereIt != BB->end()) {
416     Instruction *BeforeI = &*WhereIt;
417     LLVMBeforeI = BeforeI->getTopmostLLVMInstruction();
418     LLVMBeforeIt = LLVMBeforeI->getIterator();
419   } else {
420     LLVMBeforeI = nullptr;
421     LLVMBeforeIt = LLVMBB->end();
422   }
423   // Insert the LLVM IR Instructions in program order.
424   for (llvm::Instruction *I : getLLVMInstrs())
425     I->insertInto(LLVMBB, LLVMBeforeIt);
426 }
427 
getParent() const428 BasicBlock *Instruction::getParent() const {
429   // Get the LLVM IR Instruction that this maps to, get its parent, and get the
430   // corresponding sandboxir::BasicBlock by looking it up in sandboxir::Context.
431   auto *BB = cast<llvm::Instruction>(Val)->getParent();
432   if (BB == nullptr)
433     return nullptr;
434   return cast<BasicBlock>(Ctx.getValue(BB));
435 }
436 
classof(const sandboxir::Value * From)437 bool Instruction::classof(const sandboxir::Value *From) {
438   switch (From->getSubclassID()) {
439 #define DEF_INSTR(ID, OPC, CLASS)                                              \
440   case ClassID::ID:                                                            \
441     return true;
442 #include "llvm/SandboxIR/SandboxIRValues.def"
443   default:
444     return false;
445   }
446 }
447 
448 #ifndef NDEBUG
dump(raw_ostream & OS) const449 void Instruction::dump(raw_ostream &OS) const {
450   OS << "Unimplemented! Please override dump().";
451 }
dump() const452 void Instruction::dump() const {
453   dump(dbgs());
454   dbgs() << "\n";
455 }
456 #endif // NDEBUG
457 
createCommon(Value * Cond,Value * True,Value * False,const Twine & Name,IRBuilder<> & Builder,Context & Ctx)458 Value *SelectInst::createCommon(Value *Cond, Value *True, Value *False,
459                                 const Twine &Name, IRBuilder<> &Builder,
460                                 Context &Ctx) {
461   llvm::Value *NewV =
462       Builder.CreateSelect(Cond->Val, True->Val, False->Val, Name);
463   if (auto *NewSI = dyn_cast<llvm::SelectInst>(NewV))
464     return Ctx.createSelectInst(NewSI);
465   assert(isa<llvm::Constant>(NewV) && "Expected constant");
466   return Ctx.getOrCreateConstant(cast<llvm::Constant>(NewV));
467 }
468 
create(Value * Cond,Value * True,Value * False,Instruction * InsertBefore,Context & Ctx,const Twine & Name)469 Value *SelectInst::create(Value *Cond, Value *True, Value *False,
470                           Instruction *InsertBefore, Context &Ctx,
471                           const Twine &Name) {
472   llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
473   auto &Builder = Ctx.getLLVMIRBuilder();
474   Builder.SetInsertPoint(BeforeIR);
475   return createCommon(Cond, True, False, Name, Builder, Ctx);
476 }
477 
create(Value * Cond,Value * True,Value * False,BasicBlock * InsertAtEnd,Context & Ctx,const Twine & Name)478 Value *SelectInst::create(Value *Cond, Value *True, Value *False,
479                           BasicBlock *InsertAtEnd, Context &Ctx,
480                           const Twine &Name) {
481   auto *IRInsertAtEnd = cast<llvm::BasicBlock>(InsertAtEnd->Val);
482   auto &Builder = Ctx.getLLVMIRBuilder();
483   Builder.SetInsertPoint(IRInsertAtEnd);
484   return createCommon(Cond, True, False, Name, Builder, Ctx);
485 }
486 
classof(const Value * From)487 bool SelectInst::classof(const Value *From) {
488   return From->getSubclassID() == ClassID::Select;
489 }
490 
491 #ifndef NDEBUG
dump(raw_ostream & OS) const492 void SelectInst::dump(raw_ostream &OS) const {
493   dumpCommonPrefix(OS);
494   dumpCommonSuffix(OS);
495 }
496 
dump() const497 void SelectInst::dump() const {
498   dump(dbgs());
499   dbgs() << "\n";
500 }
501 #endif // NDEBUG
502 
create(Type * Ty,Value * Ptr,MaybeAlign Align,Instruction * InsertBefore,Context & Ctx,const Twine & Name)503 LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
504                            Instruction *InsertBefore, Context &Ctx,
505                            const Twine &Name) {
506   llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
507   auto &Builder = Ctx.getLLVMIRBuilder();
508   Builder.SetInsertPoint(BeforeIR);
509   auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
510                                           /*isVolatile=*/false, Name);
511   auto *NewSBI = Ctx.createLoadInst(NewLI);
512   return NewSBI;
513 }
514 
create(Type * Ty,Value * Ptr,MaybeAlign Align,BasicBlock * InsertAtEnd,Context & Ctx,const Twine & Name)515 LoadInst *LoadInst::create(Type *Ty, Value *Ptr, MaybeAlign Align,
516                            BasicBlock *InsertAtEnd, Context &Ctx,
517                            const Twine &Name) {
518   auto &Builder = Ctx.getLLVMIRBuilder();
519   Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
520   auto *NewLI = Builder.CreateAlignedLoad(Ty, Ptr->Val, Align,
521                                           /*isVolatile=*/false, Name);
522   auto *NewSBI = Ctx.createLoadInst(NewLI);
523   return NewSBI;
524 }
525 
classof(const Value * From)526 bool LoadInst::classof(const Value *From) {
527   return From->getSubclassID() == ClassID::Load;
528 }
529 
getPointerOperand() const530 Value *LoadInst::getPointerOperand() const {
531   return Ctx.getValue(cast<llvm::LoadInst>(Val)->getPointerOperand());
532 }
533 
534 #ifndef NDEBUG
dump(raw_ostream & OS) const535 void LoadInst::dump(raw_ostream &OS) const {
536   dumpCommonPrefix(OS);
537   dumpCommonSuffix(OS);
538 }
539 
dump() const540 void LoadInst::dump() const {
541   dump(dbgs());
542   dbgs() << "\n";
543 }
544 #endif // NDEBUG
create(Value * V,Value * Ptr,MaybeAlign Align,Instruction * InsertBefore,Context & Ctx)545 StoreInst *StoreInst::create(Value *V, Value *Ptr, MaybeAlign Align,
546                              Instruction *InsertBefore, Context &Ctx) {
547   llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
548   auto &Builder = Ctx.getLLVMIRBuilder();
549   Builder.SetInsertPoint(BeforeIR);
550   auto *NewSI =
551       Builder.CreateAlignedStore(V->Val, Ptr->Val, Align, /*isVolatile=*/false);
552   auto *NewSBI = Ctx.createStoreInst(NewSI);
553   return NewSBI;
554 }
create(Value * V,Value * Ptr,MaybeAlign Align,BasicBlock * InsertAtEnd,Context & Ctx)555 StoreInst *StoreInst::create(Value *V, Value *Ptr, MaybeAlign Align,
556                              BasicBlock *InsertAtEnd, Context &Ctx) {
557   auto *InsertAtEndIR = cast<llvm::BasicBlock>(InsertAtEnd->Val);
558   auto &Builder = Ctx.getLLVMIRBuilder();
559   Builder.SetInsertPoint(InsertAtEndIR);
560   auto *NewSI =
561       Builder.CreateAlignedStore(V->Val, Ptr->Val, Align, /*isVolatile=*/false);
562   auto *NewSBI = Ctx.createStoreInst(NewSI);
563   return NewSBI;
564 }
565 
classof(const Value * From)566 bool StoreInst::classof(const Value *From) {
567   return From->getSubclassID() == ClassID::Store;
568 }
569 
getValueOperand() const570 Value *StoreInst::getValueOperand() const {
571   return Ctx.getValue(cast<llvm::StoreInst>(Val)->getValueOperand());
572 }
573 
getPointerOperand() const574 Value *StoreInst::getPointerOperand() const {
575   return Ctx.getValue(cast<llvm::StoreInst>(Val)->getPointerOperand());
576 }
577 
578 #ifndef NDEBUG
dump(raw_ostream & OS) const579 void StoreInst::dump(raw_ostream &OS) const {
580   dumpCommonPrefix(OS);
581   dumpCommonSuffix(OS);
582 }
583 
dump() const584 void StoreInst::dump() const {
585   dump(dbgs());
586   dbgs() << "\n";
587 }
588 #endif // NDEBUG
589 
createCommon(Value * RetVal,IRBuilder<> & Builder,Context & Ctx)590 ReturnInst *ReturnInst::createCommon(Value *RetVal, IRBuilder<> &Builder,
591                                      Context &Ctx) {
592   llvm::ReturnInst *NewRI;
593   if (RetVal != nullptr)
594     NewRI = Builder.CreateRet(RetVal->Val);
595   else
596     NewRI = Builder.CreateRetVoid();
597   return Ctx.createReturnInst(NewRI);
598 }
599 
create(Value * RetVal,Instruction * InsertBefore,Context & Ctx)600 ReturnInst *ReturnInst::create(Value *RetVal, Instruction *InsertBefore,
601                                Context &Ctx) {
602   llvm::Instruction *BeforeIR = InsertBefore->getTopmostLLVMInstruction();
603   auto &Builder = Ctx.getLLVMIRBuilder();
604   Builder.SetInsertPoint(BeforeIR);
605   return createCommon(RetVal, Builder, Ctx);
606 }
607 
create(Value * RetVal,BasicBlock * InsertAtEnd,Context & Ctx)608 ReturnInst *ReturnInst::create(Value *RetVal, BasicBlock *InsertAtEnd,
609                                Context &Ctx) {
610   auto &Builder = Ctx.getLLVMIRBuilder();
611   Builder.SetInsertPoint(cast<llvm::BasicBlock>(InsertAtEnd->Val));
612   return createCommon(RetVal, Builder, Ctx);
613 }
614 
getReturnValue() const615 Value *ReturnInst::getReturnValue() const {
616   auto *LLVMRetVal = cast<llvm::ReturnInst>(Val)->getReturnValue();
617   return LLVMRetVal != nullptr ? Ctx.getValue(LLVMRetVal) : nullptr;
618 }
619 
620 #ifndef NDEBUG
dump(raw_ostream & OS) const621 void ReturnInst::dump(raw_ostream &OS) const {
622   dumpCommonPrefix(OS);
623   dumpCommonSuffix(OS);
624 }
625 
dump() const626 void ReturnInst::dump() const {
627   dump(dbgs());
628   dbgs() << "\n";
629 }
630 
dump(raw_ostream & OS) const631 void OpaqueInst::dump(raw_ostream &OS) const {
632   dumpCommonPrefix(OS);
633   dumpCommonSuffix(OS);
634 }
635 
dump() const636 void OpaqueInst::dump() const {
637   dump(dbgs());
638   dbgs() << "\n";
639 }
640 #endif // NDEBUG
641 
createInt(Type * Ty,uint64_t V,Context & Ctx,bool IsSigned)642 Constant *Constant::createInt(Type *Ty, uint64_t V, Context &Ctx,
643                               bool IsSigned) {
644   llvm::Constant *LLVMC = llvm::ConstantInt::get(Ty, V, IsSigned);
645   return Ctx.getOrCreateConstant(LLVMC);
646 }
647 
648 #ifndef NDEBUG
dump(raw_ostream & OS) const649 void Constant::dump(raw_ostream &OS) const {
650   dumpCommonPrefix(OS);
651   dumpCommonSuffix(OS);
652 }
653 
dump() const654 void Constant::dump() const {
655   dump(dbgs());
656   dbgs() << "\n";
657 }
658 
dumpNameAndArgs(raw_ostream & OS) const659 void Function::dumpNameAndArgs(raw_ostream &OS) const {
660   auto *F = cast<llvm::Function>(Val);
661   OS << *F->getReturnType() << " @" << F->getName() << "(";
662   interleave(
663       F->args(),
664       [this, &OS](const llvm::Argument &LLVMArg) {
665         auto *SBArg = cast_or_null<Argument>(Ctx.getValue(&LLVMArg));
666         if (SBArg == nullptr)
667           OS << "NULL";
668         else
669           SBArg->printAsOperand(OS);
670       },
671       [&] { OS << ", "; });
672   OS << ")";
673 }
dump(raw_ostream & OS) const674 void Function::dump(raw_ostream &OS) const {
675   dumpNameAndArgs(OS);
676   OS << " {\n";
677   auto *LLVMF = cast<llvm::Function>(Val);
678   interleave(
679       *LLVMF,
680       [this, &OS](const llvm::BasicBlock &LLVMBB) {
681         auto *BB = cast_or_null<BasicBlock>(Ctx.getValue(&LLVMBB));
682         if (BB == nullptr)
683           OS << "NULL";
684         else
685           OS << *BB;
686       },
687       [&OS] { OS << "\n"; });
688   OS << "}\n";
689 }
dump() const690 void Function::dump() const {
691   dump(dbgs());
692   dbgs() << "\n";
693 }
694 #endif // NDEBUG
695 
696 BasicBlock::iterator::pointer
getInstr(llvm::BasicBlock::iterator It) const697 BasicBlock::iterator::getInstr(llvm::BasicBlock::iterator It) const {
698   return cast_or_null<Instruction>(Ctx->getValue(&*It));
699 }
700 
detachLLVMValue(llvm::Value * V)701 std::unique_ptr<Value> Context::detachLLVMValue(llvm::Value *V) {
702   std::unique_ptr<Value> Erased;
703   auto It = LLVMValueToValueMap.find(V);
704   if (It != LLVMValueToValueMap.end()) {
705     auto *Val = It->second.release();
706     Erased = std::unique_ptr<Value>(Val);
707     LLVMValueToValueMap.erase(It);
708   }
709   return Erased;
710 }
711 
detach(Value * V)712 std::unique_ptr<Value> Context::detach(Value *V) {
713   assert(V->getSubclassID() != Value::ClassID::Constant &&
714          "Can't detach a constant!");
715   assert(V->getSubclassID() != Value::ClassID::User && "Can't detach a user!");
716   return detachLLVMValue(V->Val);
717 }
718 
registerValue(std::unique_ptr<Value> && VPtr)719 Value *Context::registerValue(std::unique_ptr<Value> &&VPtr) {
720   assert(VPtr->getSubclassID() != Value::ClassID::User &&
721          "Can't register a user!");
722   Value *V = VPtr.get();
723   [[maybe_unused]] auto Pair =
724       LLVMValueToValueMap.insert({VPtr->Val, std::move(VPtr)});
725   assert(Pair.second && "Already exists!");
726   return V;
727 }
728 
getOrCreateValueInternal(llvm::Value * LLVMV,llvm::User * U)729 Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
730   auto Pair = LLVMValueToValueMap.insert({LLVMV, nullptr});
731   auto It = Pair.first;
732   if (!Pair.second)
733     return It->second.get();
734 
735   if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
736     It->second = std::unique_ptr<Constant>(new Constant(C, *this));
737     auto *NewC = It->second.get();
738     for (llvm::Value *COp : C->operands())
739       getOrCreateValueInternal(COp, C);
740     return NewC;
741   }
742   if (auto *Arg = dyn_cast<llvm::Argument>(LLVMV)) {
743     It->second = std::unique_ptr<Argument>(new Argument(Arg, *this));
744     return It->second.get();
745   }
746   if (auto *BB = dyn_cast<llvm::BasicBlock>(LLVMV)) {
747     assert(isa<BlockAddress>(U) &&
748            "This won't create a SBBB, don't call this function directly!");
749     if (auto *SBBB = getValue(BB))
750       return SBBB;
751     return nullptr;
752   }
753   assert(isa<llvm::Instruction>(LLVMV) && "Expected Instruction");
754 
755   switch (cast<llvm::Instruction>(LLVMV)->getOpcode()) {
756   case llvm::Instruction::Select: {
757     auto *LLVMSel = cast<llvm::SelectInst>(LLVMV);
758     It->second = std::unique_ptr<SelectInst>(new SelectInst(LLVMSel, *this));
759     return It->second.get();
760   }
761   case llvm::Instruction::Load: {
762     auto *LLVMLd = cast<llvm::LoadInst>(LLVMV);
763     It->second = std::unique_ptr<LoadInst>(new LoadInst(LLVMLd, *this));
764     return It->second.get();
765   }
766   case llvm::Instruction::Store: {
767     auto *LLVMSt = cast<llvm::StoreInst>(LLVMV);
768     It->second = std::unique_ptr<StoreInst>(new StoreInst(LLVMSt, *this));
769     return It->second.get();
770   }
771   case llvm::Instruction::Ret: {
772     auto *LLVMRet = cast<llvm::ReturnInst>(LLVMV);
773     It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
774     return It->second.get();
775   }
776   default:
777     break;
778   }
779 
780   It->second = std::unique_ptr<OpaqueInst>(
781       new OpaqueInst(cast<llvm::Instruction>(LLVMV), *this));
782   return It->second.get();
783 }
784 
createBasicBlock(llvm::BasicBlock * LLVMBB)785 BasicBlock *Context::createBasicBlock(llvm::BasicBlock *LLVMBB) {
786   assert(getValue(LLVMBB) == nullptr && "Already exists!");
787   auto NewBBPtr = std::unique_ptr<BasicBlock>(new BasicBlock(LLVMBB, *this));
788   auto *BB = cast<BasicBlock>(registerValue(std::move(NewBBPtr)));
789   // Create SandboxIR for BB's body.
790   BB->buildBasicBlockFromLLVMIR(LLVMBB);
791   return BB;
792 }
793 
createSelectInst(llvm::SelectInst * SI)794 SelectInst *Context::createSelectInst(llvm::SelectInst *SI) {
795   auto NewPtr = std::unique_ptr<SelectInst>(new SelectInst(SI, *this));
796   return cast<SelectInst>(registerValue(std::move(NewPtr)));
797 }
798 
createLoadInst(llvm::LoadInst * LI)799 LoadInst *Context::createLoadInst(llvm::LoadInst *LI) {
800   auto NewPtr = std::unique_ptr<LoadInst>(new LoadInst(LI, *this));
801   return cast<LoadInst>(registerValue(std::move(NewPtr)));
802 }
803 
createStoreInst(llvm::StoreInst * SI)804 StoreInst *Context::createStoreInst(llvm::StoreInst *SI) {
805   auto NewPtr = std::unique_ptr<StoreInst>(new StoreInst(SI, *this));
806   return cast<StoreInst>(registerValue(std::move(NewPtr)));
807 }
808 
createReturnInst(llvm::ReturnInst * I)809 ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
810   auto NewPtr = std::unique_ptr<ReturnInst>(new ReturnInst(I, *this));
811   return cast<ReturnInst>(registerValue(std::move(NewPtr)));
812 }
813 
getValue(llvm::Value * V) const814 Value *Context::getValue(llvm::Value *V) const {
815   auto It = LLVMValueToValueMap.find(V);
816   if (It != LLVMValueToValueMap.end())
817     return It->second.get();
818   return nullptr;
819 }
820 
createFunction(llvm::Function * F)821 Function *Context::createFunction(llvm::Function *F) {
822   assert(getValue(F) == nullptr && "Already exists!");
823   auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
824   // Create arguments.
825   for (auto &Arg : F->args())
826     getOrCreateArgument(&Arg);
827   // Create BBs.
828   for (auto &BB : *F)
829     createBasicBlock(&BB);
830   auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
831   return SBF;
832 }
833 
getParent() const834 Function *BasicBlock::getParent() const {
835   auto *BB = cast<llvm::BasicBlock>(Val);
836   auto *F = BB->getParent();
837   if (F == nullptr)
838     // Detached
839     return nullptr;
840   return cast_or_null<Function>(Ctx.getValue(F));
841 }
842 
buildBasicBlockFromLLVMIR(llvm::BasicBlock * LLVMBB)843 void BasicBlock::buildBasicBlockFromLLVMIR(llvm::BasicBlock *LLVMBB) {
844   for (llvm::Instruction &IRef : reverse(*LLVMBB)) {
845     llvm::Instruction *I = &IRef;
846     Ctx.getOrCreateValue(I);
847     for (auto [OpIdx, Op] : enumerate(I->operands())) {
848       // Skip instruction's label operands
849       if (isa<llvm::BasicBlock>(Op))
850         continue;
851       // Skip metadata
852       if (isa<llvm::MetadataAsValue>(Op))
853         continue;
854       // Skip asm
855       if (isa<llvm::InlineAsm>(Op))
856         continue;
857       Ctx.getOrCreateValue(Op);
858     }
859   }
860 #if !defined(NDEBUG) && defined(SBVEC_EXPENSIVE_CHECKS)
861   verify();
862 #endif
863 }
864 
begin() const865 BasicBlock::iterator BasicBlock::begin() const {
866   llvm::BasicBlock *BB = cast<llvm::BasicBlock>(Val);
867   llvm::BasicBlock::iterator It = BB->begin();
868   if (!BB->empty()) {
869     auto *V = Ctx.getValue(&*BB->begin());
870     assert(V != nullptr && "No SandboxIR for BB->begin()!");
871     auto *I = cast<Instruction>(V);
872     unsigned Num = I->getNumOfIRInstrs();
873     assert(Num >= 1u && "Bad getNumOfIRInstrs()");
874     It = std::next(It, Num - 1);
875   }
876   return iterator(BB, It, &Ctx);
877 }
878 
getTerminator() const879 Instruction *BasicBlock::getTerminator() const {
880   auto *TerminatorV =
881       Ctx.getValue(cast<llvm::BasicBlock>(Val)->getTerminator());
882   return cast_or_null<Instruction>(TerminatorV);
883 }
884 
front() const885 Instruction &BasicBlock::front() const {
886   auto *BB = cast<llvm::BasicBlock>(Val);
887   assert(!BB->empty() && "Empty block!");
888   auto *SBI = cast<Instruction>(getContext().getValue(&*BB->begin()));
889   assert(SBI != nullptr && "Expected Instr!");
890   return *SBI;
891 }
892 
back() const893 Instruction &BasicBlock::back() const {
894   auto *BB = cast<llvm::BasicBlock>(Val);
895   assert(!BB->empty() && "Empty block!");
896   auto *SBI = cast<Instruction>(getContext().getValue(&*BB->rbegin()));
897   assert(SBI != nullptr && "Expected Instr!");
898   return *SBI;
899 }
900 
901 #ifndef NDEBUG
dump(raw_ostream & OS) const902 void BasicBlock::dump(raw_ostream &OS) const {
903   llvm::BasicBlock *BB = cast<llvm::BasicBlock>(Val);
904   const auto &Name = BB->getName();
905   OS << Name;
906   if (!Name.empty())
907     OS << ":\n";
908   // If there are Instructions in the BB that are not mapped to SandboxIR, then
909   // use a crash-proof dump.
910   if (any_of(*BB, [this](llvm::Instruction &I) {
911         return Ctx.getValue(&I) == nullptr;
912       })) {
913     OS << "<Crash-proof mode!>\n";
914     DenseSet<Instruction *> Visited;
915     for (llvm::Instruction &IRef : *BB) {
916       Value *SBV = Ctx.getValue(&IRef);
917       if (SBV == nullptr)
918         OS << IRef << " *** No SandboxIR ***\n";
919       else {
920         auto *SBI = dyn_cast<Instruction>(SBV);
921         if (SBI == nullptr) {
922           OS << IRef << " *** Not a SBInstruction!!! ***\n";
923         } else {
924           if (Visited.insert(SBI).second)
925             OS << *SBI << "\n";
926         }
927       }
928     }
929   } else {
930     for (auto &SBI : *this) {
931       SBI.dump(OS);
932       OS << "\n";
933     }
934   }
935 }
dump() const936 void BasicBlock::dump() const {
937   dump(dbgs());
938   dbgs() << "\n";
939 }
940 #endif // NDEBUG
941