xref: /freebsd/contrib/llvm-project/llvm/include/llvm/SandboxIR/Tracker.h (revision 700637cbb5e582861067a11aaca4d053546871d2)
1 //===- Tracker.h ------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is the component of SandboxIR that tracks all changes made to its
10 // state, such that we can revert the state when needed.
11 //
12 // Tracking changes
13 // ----------------
14 // The user needs to call `Tracker::save()` to enable tracking changes
// made to SandboxIR. From that point on, any change made to SandboxIR will
16 // automatically create a change tracking object and register it with the
17 // tracker. IR-change objects are subclasses of `IRChangeBase` and get
18 // registered with the `Tracker::track()` function. The change objects
19 // are saved in the order they are registered with the tracker and are stored in
20 // the `Tracker::Changes` vector. All of this is done transparently to
21 // the user.
22 //
23 // Reverting changes
24 // -----------------
25 // Calling `Tracker::revert()` will restore the state saved when
26 // `Tracker::save()` was called. Internally this goes through the
27 // change objects in `Tracker::Changes` in reverse order, calling their
28 // `IRChangeBase::revert()` function one by one.
29 //
30 // Accepting changes
31 // -----------------
// The user needs to either revert or accept changes before the tracker object
// is destroyed; this is enforced in the tracker's destructor. Accepting changes
// is the job of `Tracker::accept()`. Internally this goes through the change
// objects in `Tracker::Changes` in order, calling `IRChangeBase::accept()`.
37 //
38 //===----------------------------------------------------------------------===//
39 
40 #ifndef LLVM_SANDBOXIR_TRACKER_H
41 #define LLVM_SANDBOXIR_TRACKER_H
42 
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StableHashing.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/SandboxIR/Use.h"
#include "llvm/SandboxIR/Value.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include <memory>
#include <type_traits>
#include <utility>
53 
54 namespace llvm::sandboxir {
55 
56 class BasicBlock;
57 class CallBrInst;
58 class LoadInst;
59 class StoreInst;
60 class Instruction;
61 class Tracker;
62 class AllocaInst;
63 class CatchSwitchInst;
64 class SwitchInst;
65 class ConstantInt;
66 class ShuffleVectorInst;
67 class CmpInst;
68 class GlobalVariable;
69 
70 #ifndef NDEBUG
71 
72 /// A class that saves hashes and textual IR snapshots of functions in a
73 /// SandboxIR Context, and does hash comparison when `expectNoDiff` is called.
74 /// If hashes differ, it prints textual IR for both old and new versions to
75 /// aid debugging.
76 ///
77 /// This is used as an additional debug check when reverting changes to
78 /// SandboxIR, to verify the reverted state matches the initial state.
79 class IRSnapshotChecker {
80   Context &Ctx;
81 
82   // A snapshot of textual IR for a function, with a hash for quick comparison.
83   struct FunctionSnapshot {
84     llvm::stable_hash Hash;
85     std::string TextualIR;
86   };
87 
88   // A snapshot for each llvm::Function found in every module in the SandboxIR
89   // Context. In practice there will always be one module, but sandbox IR
90   // save/restore ops work at the Context level, so we must take the full state
91   // into account.
92   using ContextSnapshot = DenseMap<const llvm::Function *, FunctionSnapshot>;
93 
94   ContextSnapshot OrigContextSnapshot;
95 
96   // Dumps to a string the textual IR for a single Function.
97   std::string dumpIR(const llvm::Function &F) const;
98 
99   // Returns a snapshot of all the modules in the sandbox IR context.
100   ContextSnapshot takeSnapshot() const;
101 
102   // Compares two snapshots and returns true if they differ.
103   bool diff(const ContextSnapshot &Orig, const ContextSnapshot &Curr) const;
104 
105 public:
IRSnapshotChecker(Context & Ctx)106   IRSnapshotChecker(Context &Ctx) : Ctx(Ctx) {}
107 
108   /// Saves a snapshot of the current state. If there was any previous snapshot,
109   /// it will be replaced with the new one.
110   void save();
111 
112   /// Checks current state against saved state, crashes if different.
113   void expectNoDiff();
114 };
115 
116 #endif // NDEBUG
117 
118 /// The base class for IR Change classes.
119 class IRChangeBase {
120 protected:
121   friend class Tracker; // For Parent.
122 
123 public:
124   /// This runs when changes get reverted.
125   virtual void revert(Tracker &Tracker) = 0;
126   /// This runs when changes get accepted.
127   virtual void accept() = 0;
128   virtual ~IRChangeBase() = default;
129 #ifndef NDEBUG
130   virtual void dump(raw_ostream &OS) const = 0;
131   LLVM_DUMP_METHOD virtual void dump() const = 0;
132   friend raw_ostream &operator<<(raw_ostream &OS, const IRChangeBase &C) {
133     C.dump(OS);
134     return OS;
135   }
136 #endif
137 };
138 
139 /// Tracks the change of the source Value of a sandboxir::Use.
140 class UseSet : public IRChangeBase {
141   Use U;
142   Value *OrigV = nullptr;
143 
144 public:
UseSet(const Use & U)145   UseSet(const Use &U) : U(U), OrigV(U.get()) {}
revert(Tracker & Tracker)146   void revert(Tracker &Tracker) final { U.set(OrigV); }
accept()147   void accept() final {}
148 #ifndef NDEBUG
dump(raw_ostream & OS)149   void dump(raw_ostream &OS) const final { OS << "UseSet"; }
150   LLVM_DUMP_METHOD void dump() const final;
151 #endif
152 };
153 
154 class LLVM_ABI PHIRemoveIncoming : public IRChangeBase {
155   PHINode *PHI;
156   unsigned RemovedIdx;
157   Value *RemovedV;
158   BasicBlock *RemovedBB;
159 
160 public:
161   PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx);
162   void revert(Tracker &Tracker) final;
accept()163   void accept() final {}
164 #ifndef NDEBUG
dump(raw_ostream & OS)165   void dump(raw_ostream &OS) const final { OS << "PHISetIncoming"; }
166   LLVM_DUMP_METHOD void dump() const final;
167 #endif
168 };
169 
170 class LLVM_ABI PHIAddIncoming : public IRChangeBase {
171   PHINode *PHI;
172   unsigned Idx;
173 
174 public:
175   PHIAddIncoming(PHINode *PHI);
176   void revert(Tracker &Tracker) final;
accept()177   void accept() final {}
178 #ifndef NDEBUG
dump(raw_ostream & OS)179   void dump(raw_ostream &OS) const final { OS << "PHISetIncoming"; }
180   LLVM_DUMP_METHOD void dump() const final;
181 #endif
182 };
183 
184 class LLVM_ABI CmpSwapOperands : public IRChangeBase {
185   CmpInst *Cmp;
186 
187 public:
188   CmpSwapOperands(CmpInst *Cmp);
189   void revert(Tracker &Tracker) final;
accept()190   void accept() final {}
191 #ifndef NDEBUG
dump(raw_ostream & OS)192   void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; }
193   LLVM_DUMP_METHOD void dump() const final;
194 #endif
195 };
196 
197 /// Tracks swapping a Use with another Use.
198 class UseSwap : public IRChangeBase {
199   Use ThisUse;
200   Use OtherUse;
201 
202 public:
UseSwap(const Use & ThisUse,const Use & OtherUse)203   UseSwap(const Use &ThisUse, const Use &OtherUse)
204       : ThisUse(ThisUse), OtherUse(OtherUse) {
205     assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!");
206   }
revert(Tracker & Tracker)207   void revert(Tracker &Tracker) final { ThisUse.swap(OtherUse); }
accept()208   void accept() final {}
209 #ifndef NDEBUG
dump(raw_ostream & OS)210   void dump(raw_ostream &OS) const final { OS << "UseSwap"; }
211   LLVM_DUMP_METHOD void dump() const final;
212 #endif
213 };
214 
215 class LLVM_ABI EraseFromParent : public IRChangeBase {
216   /// Contains all the data we need to restore an "erased" (i.e., detached)
217   /// instruction: the instruction itself and its operands in order.
218   struct InstrAndOperands {
219     /// The operands that got dropped.
220     SmallVector<llvm::Value *> Operands;
221     /// The instruction that got "erased".
222     llvm::Instruction *LLVMI;
223   };
224   /// The instruction data is in reverse program order, which helps create the
225   /// original program order during revert().
226   SmallVector<InstrAndOperands> InstrData;
227   /// This is either the next Instruction in the stream, or the parent
228   /// BasicBlock if at the end of the BB.
229   PointerUnion<llvm::Instruction *, llvm::BasicBlock *> NextLLVMIOrBB;
230   /// We take ownership of the "erased" instruction.
231   std::unique_ptr<sandboxir::Value> ErasedIPtr;
232 
233 public:
234   EraseFromParent(std::unique_ptr<sandboxir::Value> &&IPtr);
235   void revert(Tracker &Tracker) final;
236   void accept() final;
237 #ifndef NDEBUG
dump(raw_ostream & OS)238   void dump(raw_ostream &OS) const final { OS << "EraseFromParent"; }
239   LLVM_DUMP_METHOD void dump() const final;
240   friend raw_ostream &operator<<(raw_ostream &OS, const EraseFromParent &C) {
241     C.dump(OS);
242     return OS;
243   }
244 #endif
245 };
246 
247 class LLVM_ABI RemoveFromParent : public IRChangeBase {
248   /// The instruction that is about to get removed.
249   Instruction *RemovedI = nullptr;
250   /// This is either the next instr, or the parent BB if at the end of the BB.
251   PointerUnion<Instruction *, BasicBlock *> NextInstrOrBB;
252 
253 public:
254   RemoveFromParent(Instruction *RemovedI);
255   void revert(Tracker &Tracker) final;
accept()256   void accept() final {};
getInstruction()257   Instruction *getInstruction() const { return RemovedI; }
258 #ifndef NDEBUG
dump(raw_ostream & OS)259   void dump(raw_ostream &OS) const final { OS << "RemoveFromParent"; }
260   LLVM_DUMP_METHOD void dump() const final;
261 #endif // NDEBUG
262 };
263 
264 /// This class can be used for tracking most instruction setters.
265 /// The two template arguments are:
266 /// - GetterFn: The getter member function pointer (e.g., `&Foo::get`)
267 /// - SetterFn: The setter member function pointer (e.g., `&Foo::set`)
268 /// Upon construction, it saves a copy of the original value by calling the
269 /// getter function. Revert sets the value back to the one saved, using the
270 /// setter function provided.
271 ///
272 /// Example:
273 ///  Tracker.track(std::make_unique<
274 ///                GenericSetter<&FooInst::get, &FooInst::set>>(I, Tracker));
275 ///
276 template <auto GetterFn, auto SetterFn>
277 class GenericSetter final : public IRChangeBase {
278   /// Traits for getting the class type from GetterFn type.
279   template <typename> struct GetClassTypeFromGetter;
280   template <typename RetT, typename ClassT>
281   struct GetClassTypeFromGetter<RetT (ClassT::*)() const> {
282     using ClassType = ClassT;
283   };
284   using InstrT = typename GetClassTypeFromGetter<decltype(GetterFn)>::ClassType;
285   using SavedValT = std::invoke_result_t<decltype(GetterFn), InstrT>;
286   InstrT *I;
287   SavedValT OrigVal;
288 
289 public:
290   GenericSetter(InstrT *I) : I(I), OrigVal((I->*GetterFn)()) {}
291   void revert(Tracker &Tracker) final { (I->*SetterFn)(OrigVal); }
292   void accept() final {}
293 #ifndef NDEBUG
294   void dump(raw_ostream &OS) const final { OS << "GenericSetter"; }
295   LLVM_DUMP_METHOD void dump() const final {
296     dump(dbgs());
297     dbgs() << "\n";
298   }
299 #endif
300 };
301 
302 /// Similar to GenericSetter but the setters/getters have an index as their
303 /// first argument. This is commont in cases like: getOperand(unsigned Idx)
304 template <auto GetterFn, auto SetterFn>
305 class GenericSetterWithIdx final : public IRChangeBase {
306   /// Helper for getting the class type from the getter
307   template <typename ClassT, typename RetT>
308   static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)(unsigned) const);
309   template <typename ClassT, typename RetT>
310   static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)(unsigned));
311 
312   using InstrT = decltype(getClassTypeFromGetter(GetterFn));
313   using SavedValT = std::invoke_result_t<decltype(GetterFn), InstrT, unsigned>;
314   InstrT *I;
315   SavedValT OrigVal;
316   unsigned Idx;
317 
318 public:
319   GenericSetterWithIdx(InstrT *I, unsigned Idx)
320       : I(I), OrigVal((I->*GetterFn)(Idx)), Idx(Idx) {}
321   void revert(Tracker &Tracker) final { (I->*SetterFn)(Idx, OrigVal); }
322   void accept() final {}
323 #ifndef NDEBUG
324   void dump(raw_ostream &OS) const final { OS << "GenericSetterWithIdx"; }
325   LLVM_DUMP_METHOD void dump() const final {
326     dump(dbgs());
327     dbgs() << "\n";
328   }
329 #endif
330 };
331 
332 class LLVM_ABI CatchSwitchAddHandler : public IRChangeBase {
333   CatchSwitchInst *CSI;
334   unsigned HandlerIdx;
335 
336 public:
337   CatchSwitchAddHandler(CatchSwitchInst *CSI);
338   void revert(Tracker &Tracker) final;
339   void accept() final {}
340 #ifndef NDEBUG
341   void dump(raw_ostream &OS) const final { OS << "CatchSwitchAddHandler"; }
342   LLVM_DUMP_METHOD void dump() const final {
343     dump(dbgs());
344     dbgs() << "\n";
345   }
346 #endif // NDEBUG
347 };
348 
349 class LLVM_ABI SwitchAddCase : public IRChangeBase {
350   SwitchInst *Switch;
351   ConstantInt *Val;
352 
353 public:
354   SwitchAddCase(SwitchInst *Switch, ConstantInt *Val)
355       : Switch(Switch), Val(Val) {}
356   void revert(Tracker &Tracker) final;
357   void accept() final {}
358 #ifndef NDEBUG
359   void dump(raw_ostream &OS) const final { OS << "SwitchAddCase"; }
360   LLVM_DUMP_METHOD void dump() const final;
361 #endif // NDEBUG
362 };
363 
364 class LLVM_ABI SwitchRemoveCase : public IRChangeBase {
365   SwitchInst *Switch;
366   struct Case {
367     ConstantInt *Val;
368     BasicBlock *Dest;
369   };
370   SmallVector<Case> Cases;
371 
372 public:
373   SwitchRemoveCase(SwitchInst *Switch);
374 
375   void revert(Tracker &Tracker) final;
376   void accept() final {}
377 #ifndef NDEBUG
378   void dump(raw_ostream &OS) const final { OS << "SwitchRemoveCase"; }
379   LLVM_DUMP_METHOD void dump() const final;
380 #endif // NDEBUG
381 };
382 
383 class LLVM_ABI MoveInstr : public IRChangeBase {
384   /// The instruction that moved.
385   Instruction *MovedI;
386   /// This is either the next instruction in the block, or the parent BB if at
387   /// the end of the BB.
388   PointerUnion<Instruction *, BasicBlock *> NextInstrOrBB;
389 
390 public:
391   MoveInstr(sandboxir::Instruction *I);
392   void revert(Tracker &Tracker) final;
393   void accept() final {}
394 #ifndef NDEBUG
395   void dump(raw_ostream &OS) const final { OS << "MoveInstr"; }
396   LLVM_DUMP_METHOD void dump() const final;
397 #endif // NDEBUG
398 };
399 
400 class LLVM_ABI InsertIntoBB final : public IRChangeBase {
401   Instruction *InsertedI = nullptr;
402 
403 public:
404   InsertIntoBB(Instruction *InsertedI);
405   void revert(Tracker &Tracker) final;
406   void accept() final {}
407 #ifndef NDEBUG
408   void dump(raw_ostream &OS) const final { OS << "InsertIntoBB"; }
409   LLVM_DUMP_METHOD void dump() const final;
410 #endif // NDEBUG
411 };
412 
413 class LLVM_ABI CreateAndInsertInst final : public IRChangeBase {
414   Instruction *NewI = nullptr;
415 
416 public:
417   CreateAndInsertInst(Instruction *NewI) : NewI(NewI) {}
418   void revert(Tracker &Tracker) final;
419   void accept() final {}
420 #ifndef NDEBUG
421   void dump(raw_ostream &OS) const final { OS << "CreateAndInsertInst"; }
422   LLVM_DUMP_METHOD void dump() const final;
423 #endif
424 };
425 
426 class LLVM_ABI ShuffleVectorSetMask final : public IRChangeBase {
427   ShuffleVectorInst *SVI;
428   SmallVector<int, 8> PrevMask;
429 
430 public:
431   ShuffleVectorSetMask(ShuffleVectorInst *SVI);
432   void revert(Tracker &Tracker) final;
433   void accept() final {}
434 #ifndef NDEBUG
435   void dump(raw_ostream &OS) const final { OS << "ShuffleVectorSetMask"; }
436   LLVM_DUMP_METHOD void dump() const final;
437 #endif
438 };
439 
440 /// The tracker collects all the change objects and implements the main API for
441 /// saving / reverting / accepting.
442 class Tracker {
443 public:
444   enum class TrackerState {
445     Disabled,  ///> Tracking is disabled
446     Record,    ///> Tracking changes
447     Reverting, ///> Reverting changes
448   };
449 
450 private:
451   /// The list of changes that are being tracked.
452   SmallVector<std::unique_ptr<IRChangeBase>> Changes;
453   /// The current state of the tracker.
454   TrackerState State = TrackerState::Disabled;
455   Context &Ctx;
456 
457 #ifndef NDEBUG
458   IRSnapshotChecker SnapshotChecker;
459 #endif
460 
461 public:
462 #ifndef NDEBUG
463   /// Helps catch bugs where we are creating new change objects while in the
464   /// middle of creating other change objects.
465   bool InMiddleOfCreatingChange = false;
466 #endif // NDEBUG
467 
468   explicit Tracker(Context &Ctx)
469       : Ctx(Ctx)
470 #ifndef NDEBUG
471         ,
472         SnapshotChecker(Ctx)
473 #endif
474   {
475   }
476 
477   LLVM_ABI ~Tracker();
478   Context &getContext() const { return Ctx; }
479   /// \Returns true if there are no changes tracked.
480   bool empty() const { return Changes.empty(); }
481   /// Record \p Change and take ownership. This is the main function used to
482   /// track Sandbox IR changes.
483   void track(std::unique_ptr<IRChangeBase> &&Change) {
484     assert(State == TrackerState::Record && "The tracker should be tracking!");
485 #ifndef NDEBUG
486     assert(!InMiddleOfCreatingChange &&
487            "We are in the middle of creating another change!");
488     if (isTracking())
489       InMiddleOfCreatingChange = true;
490 #endif // NDEBUG
491     Changes.push_back(std::move(Change));
492 
493 #ifndef NDEBUG
494     InMiddleOfCreatingChange = false;
495 #endif
496   }
497   /// A convenience wrapper for `track()` that constructs and tracks the Change
498   /// object if tracking is enabled. \Returns true if tracking is enabled.
499   template <typename ChangeT, typename... ArgsT>
500   bool emplaceIfTracking(ArgsT... Args) {
501     if (!isTracking())
502       return false;
503     track(std::make_unique<ChangeT>(Args...));
504     return true;
505   }
506   /// \Returns true if the tracker is recording changes.
507   bool isTracking() const { return State == TrackerState::Record; }
508   /// \Returns the current state of the tracker.
509   TrackerState getState() const { return State; }
510   /// Turns on IR tracking.
511   LLVM_ABI void save();
512   /// Stops tracking and accept changes.
513   LLVM_ABI void accept();
514   /// Stops tracking and reverts to saved state.
515   LLVM_ABI void revert();
516 
517 #ifndef NDEBUG
518   void dump(raw_ostream &OS) const;
519   LLVM_DUMP_METHOD void dump() const;
520   friend raw_ostream &operator<<(raw_ostream &OS, const Tracker &Tracker) {
521     Tracker.dump(OS);
522     return OS;
523   }
524 #endif // NDEBUG
525 };
526 
527 } // namespace llvm::sandboxir
528 
529 #endif // LLVM_SANDBOXIR_TRACKER_H
530