xref: /freebsd/contrib/llvm-project/llvm/lib/SandboxIR/Tracker.cpp (revision 2c2ec6bbc9cc7762a250ffe903bda6c2e44d25ff)
1 //===- Tracker.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/SandboxIR/Tracker.h"
10 #include "llvm/ADT/STLExtras.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/Instruction.h"
13 #include "llvm/IR/StructuralHash.h"
14 #include "llvm/SandboxIR/Instruction.h"
15 
16 using namespace llvm::sandboxir;
17 
18 #ifndef NDEBUG
19 
20 std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const {
21   std::string Result;
22   raw_string_ostream SS(Result);
23   F.print(SS, /*AssemblyAnnotationWriter=*/nullptr);
24   return Result;
25 }
26 
27 IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const {
28   ContextSnapshot Result;
29   for (const auto &Entry : Ctx.LLVMModuleToModuleMap)
30     for (const auto &F : *Entry.first) {
31       FunctionSnapshot Snapshot;
32       Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true);
33       Snapshot.TextualIR = dumpIR(F);
34       Result[&F] = Snapshot;
35     }
36   return Result;
37 }
38 
39 bool IRSnapshotChecker::diff(const ContextSnapshot &Orig,
40                              const ContextSnapshot &Curr) const {
41   bool DifferenceFound = false;
42   for (const auto &[F, OrigFS] : Orig) {
43     auto CurrFSIt = Curr.find(F);
44     if (CurrFSIt == Curr.end()) {
45       DifferenceFound = true;
46       dbgs() << "Function " << F->getName() << " not found in current IR.\n";
47       dbgs() << OrigFS.TextualIR << "\n";
48       continue;
49     }
50     const FunctionSnapshot &CurrFS = CurrFSIt->second;
51     if (OrigFS.Hash != CurrFS.Hash) {
52       DifferenceFound = true;
53       dbgs() << "Found IR difference in Function " << F->getName() << "\n";
54       dbgs() << "Original:\n" << OrigFS.TextualIR << "\n";
55       dbgs() << "Current:\n" << CurrFS.TextualIR << "\n";
56     }
57   }
58   // Check that Curr doesn't contain any new functions.
59   for (const auto &[F, CurrFS] : Curr) {
60     if (!Orig.contains(F)) {
61       DifferenceFound = true;
62       dbgs() << "Function " << F->getName()
63              << " found in current IR but not in original snapshot.\n";
64       dbgs() << CurrFS.TextualIR << "\n";
65     }
66   }
67   return DifferenceFound;
68 }
69 
70 void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); }
71 
72 void IRSnapshotChecker::expectNoDiff() {
73   ContextSnapshot CurrContextSnapshot = takeSnapshot();
74   if (diff(OrigContextSnapshot, CurrContextSnapshot)) {
75     llvm_unreachable(
76         "Original and current IR differ! Probably a checkpointing bug.");
77   }
78 }
79 
80 void UseSet::dump() const {
81   dump(dbgs());
82   dbgs() << "\n";
83 }
84 
85 void UseSwap::dump() const {
86   dump(dbgs());
87   dbgs() << "\n";
88 }
89 #endif // NDEBUG
90 
91 PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx)
92     : PHI(PHI), RemovedIdx(RemovedIdx) {
93   RemovedV = PHI->getIncomingValue(RemovedIdx);
94   RemovedBB = PHI->getIncomingBlock(RemovedIdx);
95 }
96 
97 void PHIRemoveIncoming::revert(Tracker &Tracker) {
98   // Special case: if the PHI is now empty, as we don't need to care about the
99   // order of the incoming values.
100   unsigned NumIncoming = PHI->getNumIncomingValues();
101   if (NumIncoming == 0) {
102     PHI->addIncoming(RemovedV, RemovedBB);
103     return;
104   }
105   // Shift all incoming values by one starting from the end until `Idx`.
106   // Start by adding a copy of the last incoming values.
107   unsigned LastIdx = NumIncoming - 1;
108   PHI->addIncoming(PHI->getIncomingValue(LastIdx),
109                    PHI->getIncomingBlock(LastIdx));
110   for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) {
111     auto *PrevV = PHI->getIncomingValue(Idx - 1);
112     auto *PrevBB = PHI->getIncomingBlock(Idx - 1);
113     PHI->setIncomingValue(Idx, PrevV);
114     PHI->setIncomingBlock(Idx, PrevBB);
115   }
116   PHI->setIncomingValue(RemovedIdx, RemovedV);
117   PHI->setIncomingBlock(RemovedIdx, RemovedBB);
118 }
119 
120 #ifndef NDEBUG
121 void PHIRemoveIncoming::dump() const {
122   dump(dbgs());
123   dbgs() << "\n";
124 }
125 #endif // NDEBUG
126 
127 PHIAddIncoming::PHIAddIncoming(PHINode *PHI)
128     : PHI(PHI), Idx(PHI->getNumIncomingValues()) {}
129 
130 void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); }
131 
132 #ifndef NDEBUG
133 void PHIAddIncoming::dump() const {
134   dump(dbgs());
135   dbgs() << "\n";
136 }
137 #endif // NDEBUG
138 
139 Tracker::~Tracker() {
140   assert(Changes.empty() && "You must accept or revert changes!");
141 }
142 
143 EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr)
144     : ErasedIPtr(std::move(ErasedIPtr)) {
145   auto *I = cast<Instruction>(this->ErasedIPtr.get());
146   auto LLVMInstrs = I->getLLVMInstrs();
147   // Iterate in reverse program order.
148   for (auto *LLVMI : reverse(LLVMInstrs)) {
149     SmallVector<llvm::Value *> Operands;
150     Operands.reserve(LLVMI->getNumOperands());
151     for (auto [OpNum, Use] : enumerate(LLVMI->operands()))
152       Operands.push_back(Use.get());
153     InstrData.push_back({Operands, LLVMI});
154   }
155   assert(is_sorted(InstrData,
156                    [](const auto &D0, const auto &D1) {
157                      return D0.LLVMI->comesBefore(D1.LLVMI);
158                    }) &&
159          "Expected reverse program order!");
160   auto *BotLLVMI = cast<llvm::Instruction>(I->Val);
161   if (BotLLVMI->getNextNode() != nullptr)
162     NextLLVMIOrBB = BotLLVMI->getNextNode();
163   else
164     NextLLVMIOrBB = BotLLVMI->getParent();
165 }
166 
167 void EraseFromParent::accept() {
168   for (const auto &IData : InstrData)
169     IData.LLVMI->deleteValue();
170 }
171 
172 void EraseFromParent::revert(Tracker &Tracker) {
173   // Place the bottom-most instruction first.
174   auto [Operands, BotLLVMI] = InstrData[0];
175   if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) {
176     BotLLVMI->insertBefore(NextLLVMI->getIterator());
177   } else {
178     auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB);
179     BotLLVMI->insertInto(LLVMBB, LLVMBB->end());
180   }
181   for (auto [OpNum, Op] : enumerate(Operands))
182     BotLLVMI->setOperand(OpNum, Op);
183 
184   // Go over the rest of the instructions and stack them on top.
185   for (auto [Operands, LLVMI] : drop_begin(InstrData)) {
186     LLVMI->insertBefore(BotLLVMI->getIterator());
187     for (auto [OpNum, Op] : enumerate(Operands))
188       LLVMI->setOperand(OpNum, Op);
189     BotLLVMI = LLVMI;
190   }
191   Tracker.getContext().registerValue(std::move(ErasedIPtr));
192 }
193 
194 #ifndef NDEBUG
195 void EraseFromParent::dump() const {
196   dump(dbgs());
197   dbgs() << "\n";
198 }
199 #endif // NDEBUG
200 
201 RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) {
202   if (auto *NextI = RemovedI->getNextNode())
203     NextInstrOrBB = NextI;
204   else
205     NextInstrOrBB = RemovedI->getParent();
206 }
207 
208 void RemoveFromParent::revert(Tracker &Tracker) {
209   if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
210     RemovedI->insertBefore(NextI);
211   } else {
212     auto *BB = cast<BasicBlock *>(NextInstrOrBB);
213     RemovedI->insertInto(BB, BB->end());
214   }
215 }
216 
217 #ifndef NDEBUG
218 void RemoveFromParent::dump() const {
219   dump(dbgs());
220   dbgs() << "\n";
221 }
222 #endif
223 
224 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI)
225     : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {}
226 
227 void CatchSwitchAddHandler::revert(Tracker &Tracker) {
228   // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler()
229   // once it gets implemented.
230   auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val);
231   LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx);
232 }
233 
234 SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) {
235   for (const auto &C : Switch->cases())
236     Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()});
237 }
238 
239 void SwitchRemoveCase::revert(Tracker &Tracker) {
240   // SwitchInst::removeCase doesn't provide any guarantees about the order of
241   // cases after removal. In order to preserve the original ordering, we save
242   // all of them and, when reverting, clear them all then insert them in the
243   // desired order. This still relies on the fact that `addCase` will insert
244   // them at the end, but it is documented to invalidate `case_end()` so it's
245   // probably okay.
246   unsigned NumCases = Switch->getNumCases();
247   for (unsigned I = 0; I < NumCases; ++I)
248     Switch->removeCase(Switch->case_begin());
249   for (auto &Case : Cases)
250     Switch->addCase(Case.Val, Case.Dest);
251 }
252 
253 #ifndef NDEBUG
254 void SwitchRemoveCase::dump() const {
255   dump(dbgs());
256   dbgs() << "\n";
257 }
258 #endif // NDEBUG
259 
260 void SwitchAddCase::revert(Tracker &Tracker) {
261   auto It = Switch->findCaseValue(Val);
262   Switch->removeCase(It);
263 }
264 
265 #ifndef NDEBUG
266 void SwitchAddCase::dump() const {
267   dump(dbgs());
268   dbgs() << "\n";
269 }
270 #endif // NDEBUG
271 
272 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) {
273   if (auto *NextI = MovedI->getNextNode())
274     NextInstrOrBB = NextI;
275   else
276     NextInstrOrBB = MovedI->getParent();
277 }
278 
279 void MoveInstr::revert(Tracker &Tracker) {
280   if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) {
281     MovedI->moveBefore(NextI);
282   } else {
283     auto *BB = cast<BasicBlock *>(NextInstrOrBB);
284     MovedI->moveBefore(*BB, BB->end());
285   }
286 }
287 
288 #ifndef NDEBUG
289 void MoveInstr::dump() const {
290   dump(dbgs());
291   dbgs() << "\n";
292 }
293 #endif
294 
295 void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); }
296 
297 InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {}
298 
299 #ifndef NDEBUG
300 void InsertIntoBB::dump() const {
301   dump(dbgs());
302   dbgs() << "\n";
303 }
304 #endif
305 
306 void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); }
307 
308 #ifndef NDEBUG
309 void CreateAndInsertInst::dump() const {
310   dump(dbgs());
311   dbgs() << "\n";
312 }
313 #endif
314 
315 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI)
316     : SVI(SVI), PrevMask(SVI->getShuffleMask()) {}
317 
318 void ShuffleVectorSetMask::revert(Tracker &Tracker) {
319   SVI->setShuffleMask(PrevMask);
320 }
321 
322 #ifndef NDEBUG
323 void ShuffleVectorSetMask::dump() const {
324   dump(dbgs());
325   dbgs() << "\n";
326 }
327 #endif
328 
329 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {}
330 
331 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); }
332 #ifndef NDEBUG
333 void CmpSwapOperands::dump() const {
334   dump(dbgs());
335   dbgs() << "\n";
336 }
337 #endif
338 
339 void Tracker::save() {
340   State = TrackerState::Record;
341 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
342   SnapshotChecker.save();
343 #endif
344 }
345 
346 void Tracker::revert() {
347   assert(State == TrackerState::Record && "Forgot to save()!");
348   State = TrackerState::Reverting;
349   for (auto &Change : reverse(Changes))
350     Change->revert(*this);
351   Changes.clear();
352 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS)
353   SnapshotChecker.expectNoDiff();
354 #endif
355   State = TrackerState::Disabled;
356 }
357 
358 void Tracker::accept() {
359   assert(State == TrackerState::Record && "Forgot to save()!");
360   State = TrackerState::Disabled;
361   for (auto &Change : Changes)
362     Change->accept();
363   Changes.clear();
364 }
365 
366 #ifndef NDEBUG
367 void Tracker::dump(raw_ostream &OS) const {
368   for (auto [Idx, ChangePtr] : enumerate(Changes)) {
369     OS << Idx << ". ";
370     ChangePtr->dump(OS);
371     OS << "\n";
372   }
373 }
374 void Tracker::dump() const {
375   dump(dbgs());
376   dbgs() << "\n";
377 }
378 #endif // NDEBUG
379