//===- Tracker.h ------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file is the component of SandboxIR that tracks all changes made to its
// state, such that we can revert the state when needed.
//
// Tracking changes
// ----------------
// The user needs to call `Tracker::save()` to enable tracking changes
// made to SandboxIR. From that point on, any change made to SandboxIR will
// automatically create a change tracking object and register it with the
// tracker. IR-change objects are subclasses of `IRChangeBase` and get
// registered with the `Tracker::track()` function. The change objects
// are saved in the order they are registered with the tracker and are stored in
// the `Tracker::Changes` vector. All of this is done transparently to
// the user.
//
// Reverting changes
// -----------------
// Calling `Tracker::revert()` will restore the state saved when
// `Tracker::save()` was called. Internally this goes through the
// change objects in `Tracker::Changes` in reverse order, calling their
// `IRChangeBase::revert()` function one by one.
//
// Accepting changes
// -----------------
// The user needs to either revert or accept changes before the tracker object
// is destroyed. This is enforced in the tracker's destructor.
// Accepting changes is the job of `Tracker::accept()`. Internally this will go
// through the change objects in `Tracker::Changes` in order, calling
// `IRChangeBase::accept()`.
37 // 38 //===----------------------------------------------------------------------===// 39 40 #ifndef LLVM_SANDBOXIR_TRACKER_H 41 #define LLVM_SANDBOXIR_TRACKER_H 42 43 #include "llvm/ADT/PointerUnion.h" 44 #include "llvm/ADT/SmallVector.h" 45 #include "llvm/ADT/StableHashing.h" 46 #include "llvm/IR/IRBuilder.h" 47 #include "llvm/IR/Instruction.h" 48 #include "llvm/SandboxIR/Use.h" 49 #include "llvm/SandboxIR/Value.h" 50 #include "llvm/Support/Compiler.h" 51 #include "llvm/Support/Debug.h" 52 #include <memory> 53 54 namespace llvm::sandboxir { 55 56 class BasicBlock; 57 class CallBrInst; 58 class LoadInst; 59 class StoreInst; 60 class Instruction; 61 class Tracker; 62 class AllocaInst; 63 class CatchSwitchInst; 64 class SwitchInst; 65 class ConstantInt; 66 class ShuffleVectorInst; 67 class CmpInst; 68 class GlobalVariable; 69 70 #ifndef NDEBUG 71 72 /// A class that saves hashes and textual IR snapshots of functions in a 73 /// SandboxIR Context, and does hash comparison when `expectNoDiff` is called. 74 /// If hashes differ, it prints textual IR for both old and new versions to 75 /// aid debugging. 76 /// 77 /// This is used as an additional debug check when reverting changes to 78 /// SandboxIR, to verify the reverted state matches the initial state. 79 class IRSnapshotChecker { 80 Context &Ctx; 81 82 // A snapshot of textual IR for a function, with a hash for quick comparison. 83 struct FunctionSnapshot { 84 llvm::stable_hash Hash; 85 std::string TextualIR; 86 }; 87 88 // A snapshot for each llvm::Function found in every module in the SandboxIR 89 // Context. In practice there will always be one module, but sandbox IR 90 // save/restore ops work at the Context level, so we must take the full state 91 // into account. 92 using ContextSnapshot = DenseMap<const llvm::Function *, FunctionSnapshot>; 93 94 ContextSnapshot OrigContextSnapshot; 95 96 // Dumps to a string the textual IR for a single Function. 
97 std::string dumpIR(const llvm::Function &F) const; 98 99 // Returns a snapshot of all the modules in the sandbox IR context. 100 ContextSnapshot takeSnapshot() const; 101 102 // Compares two snapshots and returns true if they differ. 103 bool diff(const ContextSnapshot &Orig, const ContextSnapshot &Curr) const; 104 105 public: IRSnapshotChecker(Context & Ctx)106 IRSnapshotChecker(Context &Ctx) : Ctx(Ctx) {} 107 108 /// Saves a snapshot of the current state. If there was any previous snapshot, 109 /// it will be replaced with the new one. 110 void save(); 111 112 /// Checks current state against saved state, crashes if different. 113 void expectNoDiff(); 114 }; 115 116 #endif // NDEBUG 117 118 /// The base class for IR Change classes. 119 class IRChangeBase { 120 protected: 121 friend class Tracker; // For Parent. 122 123 public: 124 /// This runs when changes get reverted. 125 virtual void revert(Tracker &Tracker) = 0; 126 /// This runs when changes get accepted. 127 virtual void accept() = 0; 128 virtual ~IRChangeBase() = default; 129 #ifndef NDEBUG 130 virtual void dump(raw_ostream &OS) const = 0; 131 LLVM_DUMP_METHOD virtual void dump() const = 0; 132 friend raw_ostream &operator<<(raw_ostream &OS, const IRChangeBase &C) { 133 C.dump(OS); 134 return OS; 135 } 136 #endif 137 }; 138 139 /// Tracks the change of the source Value of a sandboxir::Use. 
140 class UseSet : public IRChangeBase { 141 Use U; 142 Value *OrigV = nullptr; 143 144 public: UseSet(const Use & U)145 UseSet(const Use &U) : U(U), OrigV(U.get()) {} revert(Tracker & Tracker)146 void revert(Tracker &Tracker) final { U.set(OrigV); } accept()147 void accept() final {} 148 #ifndef NDEBUG dump(raw_ostream & OS)149 void dump(raw_ostream &OS) const final { OS << "UseSet"; } 150 LLVM_DUMP_METHOD void dump() const final; 151 #endif 152 }; 153 154 class LLVM_ABI PHIRemoveIncoming : public IRChangeBase { 155 PHINode *PHI; 156 unsigned RemovedIdx; 157 Value *RemovedV; 158 BasicBlock *RemovedBB; 159 160 public: 161 PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx); 162 void revert(Tracker &Tracker) final; accept()163 void accept() final {} 164 #ifndef NDEBUG dump(raw_ostream & OS)165 void dump(raw_ostream &OS) const final { OS << "PHISetIncoming"; } 166 LLVM_DUMP_METHOD void dump() const final; 167 #endif 168 }; 169 170 class LLVM_ABI PHIAddIncoming : public IRChangeBase { 171 PHINode *PHI; 172 unsigned Idx; 173 174 public: 175 PHIAddIncoming(PHINode *PHI); 176 void revert(Tracker &Tracker) final; accept()177 void accept() final {} 178 #ifndef NDEBUG dump(raw_ostream & OS)179 void dump(raw_ostream &OS) const final { OS << "PHISetIncoming"; } 180 LLVM_DUMP_METHOD void dump() const final; 181 #endif 182 }; 183 184 class LLVM_ABI CmpSwapOperands : public IRChangeBase { 185 CmpInst *Cmp; 186 187 public: 188 CmpSwapOperands(CmpInst *Cmp); 189 void revert(Tracker &Tracker) final; accept()190 void accept() final {} 191 #ifndef NDEBUG dump(raw_ostream & OS)192 void dump(raw_ostream &OS) const final { OS << "CmpSwapOperands"; } 193 LLVM_DUMP_METHOD void dump() const final; 194 #endif 195 }; 196 197 /// Tracks swapping a Use with another Use. 
198 class UseSwap : public IRChangeBase { 199 Use ThisUse; 200 Use OtherUse; 201 202 public: UseSwap(const Use & ThisUse,const Use & OtherUse)203 UseSwap(const Use &ThisUse, const Use &OtherUse) 204 : ThisUse(ThisUse), OtherUse(OtherUse) { 205 assert(ThisUse.getUser() == OtherUse.getUser() && "Expected same user!"); 206 } revert(Tracker & Tracker)207 void revert(Tracker &Tracker) final { ThisUse.swap(OtherUse); } accept()208 void accept() final {} 209 #ifndef NDEBUG dump(raw_ostream & OS)210 void dump(raw_ostream &OS) const final { OS << "UseSwap"; } 211 LLVM_DUMP_METHOD void dump() const final; 212 #endif 213 }; 214 215 class LLVM_ABI EraseFromParent : public IRChangeBase { 216 /// Contains all the data we need to restore an "erased" (i.e., detached) 217 /// instruction: the instruction itself and its operands in order. 218 struct InstrAndOperands { 219 /// The operands that got dropped. 220 SmallVector<llvm::Value *> Operands; 221 /// The instruction that got "erased". 222 llvm::Instruction *LLVMI; 223 }; 224 /// The instruction data is in reverse program order, which helps create the 225 /// original program order during revert(). 226 SmallVector<InstrAndOperands> InstrData; 227 /// This is either the next Instruction in the stream, or the parent 228 /// BasicBlock if at the end of the BB. 229 PointerUnion<llvm::Instruction *, llvm::BasicBlock *> NextLLVMIOrBB; 230 /// We take ownership of the "erased" instruction. 
231 std::unique_ptr<sandboxir::Value> ErasedIPtr; 232 233 public: 234 EraseFromParent(std::unique_ptr<sandboxir::Value> &&IPtr); 235 void revert(Tracker &Tracker) final; 236 void accept() final; 237 #ifndef NDEBUG dump(raw_ostream & OS)238 void dump(raw_ostream &OS) const final { OS << "EraseFromParent"; } 239 LLVM_DUMP_METHOD void dump() const final; 240 friend raw_ostream &operator<<(raw_ostream &OS, const EraseFromParent &C) { 241 C.dump(OS); 242 return OS; 243 } 244 #endif 245 }; 246 247 class LLVM_ABI RemoveFromParent : public IRChangeBase { 248 /// The instruction that is about to get removed. 249 Instruction *RemovedI = nullptr; 250 /// This is either the next instr, or the parent BB if at the end of the BB. 251 PointerUnion<Instruction *, BasicBlock *> NextInstrOrBB; 252 253 public: 254 RemoveFromParent(Instruction *RemovedI); 255 void revert(Tracker &Tracker) final; accept()256 void accept() final {}; getInstruction()257 Instruction *getInstruction() const { return RemovedI; } 258 #ifndef NDEBUG dump(raw_ostream & OS)259 void dump(raw_ostream &OS) const final { OS << "RemoveFromParent"; } 260 LLVM_DUMP_METHOD void dump() const final; 261 #endif // NDEBUG 262 }; 263 264 /// This class can be used for tracking most instruction setters. 265 /// The two template arguments are: 266 /// - GetterFn: The getter member function pointer (e.g., `&Foo::get`) 267 /// - SetterFn: The setter member function pointer (e.g., `&Foo::set`) 268 /// Upon construction, it saves a copy of the original value by calling the 269 /// getter function. Revert sets the value back to the one saved, using the 270 /// setter function provided. 271 /// 272 /// Example: 273 /// Tracker.track(std::make_unique< 274 /// GenericSetter<&FooInst::get, &FooInst::set>>(I, Tracker)); 275 /// 276 template <auto GetterFn, auto SetterFn> 277 class GenericSetter final : public IRChangeBase { 278 /// Traits for getting the class type from GetterFn type. 
279 template <typename> struct GetClassTypeFromGetter; 280 template <typename RetT, typename ClassT> 281 struct GetClassTypeFromGetter<RetT (ClassT::*)() const> { 282 using ClassType = ClassT; 283 }; 284 using InstrT = typename GetClassTypeFromGetter<decltype(GetterFn)>::ClassType; 285 using SavedValT = std::invoke_result_t<decltype(GetterFn), InstrT>; 286 InstrT *I; 287 SavedValT OrigVal; 288 289 public: 290 GenericSetter(InstrT *I) : I(I), OrigVal((I->*GetterFn)()) {} 291 void revert(Tracker &Tracker) final { (I->*SetterFn)(OrigVal); } 292 void accept() final {} 293 #ifndef NDEBUG 294 void dump(raw_ostream &OS) const final { OS << "GenericSetter"; } 295 LLVM_DUMP_METHOD void dump() const final { 296 dump(dbgs()); 297 dbgs() << "\n"; 298 } 299 #endif 300 }; 301 302 /// Similar to GenericSetter but the setters/getters have an index as their 303 /// first argument. This is commont in cases like: getOperand(unsigned Idx) 304 template <auto GetterFn, auto SetterFn> 305 class GenericSetterWithIdx final : public IRChangeBase { 306 /// Helper for getting the class type from the getter 307 template <typename ClassT, typename RetT> 308 static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)(unsigned) const); 309 template <typename ClassT, typename RetT> 310 static ClassT getClassTypeFromGetter(RetT (ClassT::*Fn)(unsigned)); 311 312 using InstrT = decltype(getClassTypeFromGetter(GetterFn)); 313 using SavedValT = std::invoke_result_t<decltype(GetterFn), InstrT, unsigned>; 314 InstrT *I; 315 SavedValT OrigVal; 316 unsigned Idx; 317 318 public: 319 GenericSetterWithIdx(InstrT *I, unsigned Idx) 320 : I(I), OrigVal((I->*GetterFn)(Idx)), Idx(Idx) {} 321 void revert(Tracker &Tracker) final { (I->*SetterFn)(Idx, OrigVal); } 322 void accept() final {} 323 #ifndef NDEBUG 324 void dump(raw_ostream &OS) const final { OS << "GenericSetterWithIdx"; } 325 LLVM_DUMP_METHOD void dump() const final { 326 dump(dbgs()); 327 dbgs() << "\n"; 328 } 329 #endif 330 }; 331 332 class LLVM_ABI 
CatchSwitchAddHandler : public IRChangeBase { 333 CatchSwitchInst *CSI; 334 unsigned HandlerIdx; 335 336 public: 337 CatchSwitchAddHandler(CatchSwitchInst *CSI); 338 void revert(Tracker &Tracker) final; 339 void accept() final {} 340 #ifndef NDEBUG 341 void dump(raw_ostream &OS) const final { OS << "CatchSwitchAddHandler"; } 342 LLVM_DUMP_METHOD void dump() const final { 343 dump(dbgs()); 344 dbgs() << "\n"; 345 } 346 #endif // NDEBUG 347 }; 348 349 class LLVM_ABI SwitchAddCase : public IRChangeBase { 350 SwitchInst *Switch; 351 ConstantInt *Val; 352 353 public: 354 SwitchAddCase(SwitchInst *Switch, ConstantInt *Val) 355 : Switch(Switch), Val(Val) {} 356 void revert(Tracker &Tracker) final; 357 void accept() final {} 358 #ifndef NDEBUG 359 void dump(raw_ostream &OS) const final { OS << "SwitchAddCase"; } 360 LLVM_DUMP_METHOD void dump() const final; 361 #endif // NDEBUG 362 }; 363 364 class LLVM_ABI SwitchRemoveCase : public IRChangeBase { 365 SwitchInst *Switch; 366 struct Case { 367 ConstantInt *Val; 368 BasicBlock *Dest; 369 }; 370 SmallVector<Case> Cases; 371 372 public: 373 SwitchRemoveCase(SwitchInst *Switch); 374 375 void revert(Tracker &Tracker) final; 376 void accept() final {} 377 #ifndef NDEBUG 378 void dump(raw_ostream &OS) const final { OS << "SwitchRemoveCase"; } 379 LLVM_DUMP_METHOD void dump() const final; 380 #endif // NDEBUG 381 }; 382 383 class LLVM_ABI MoveInstr : public IRChangeBase { 384 /// The instruction that moved. 385 Instruction *MovedI; 386 /// This is either the next instruction in the block, or the parent BB if at 387 /// the end of the BB. 
388 PointerUnion<Instruction *, BasicBlock *> NextInstrOrBB; 389 390 public: 391 MoveInstr(sandboxir::Instruction *I); 392 void revert(Tracker &Tracker) final; 393 void accept() final {} 394 #ifndef NDEBUG 395 void dump(raw_ostream &OS) const final { OS << "MoveInstr"; } 396 LLVM_DUMP_METHOD void dump() const final; 397 #endif // NDEBUG 398 }; 399 400 class LLVM_ABI InsertIntoBB final : public IRChangeBase { 401 Instruction *InsertedI = nullptr; 402 403 public: 404 InsertIntoBB(Instruction *InsertedI); 405 void revert(Tracker &Tracker) final; 406 void accept() final {} 407 #ifndef NDEBUG 408 void dump(raw_ostream &OS) const final { OS << "InsertIntoBB"; } 409 LLVM_DUMP_METHOD void dump() const final; 410 #endif // NDEBUG 411 }; 412 413 class LLVM_ABI CreateAndInsertInst final : public IRChangeBase { 414 Instruction *NewI = nullptr; 415 416 public: 417 CreateAndInsertInst(Instruction *NewI) : NewI(NewI) {} 418 void revert(Tracker &Tracker) final; 419 void accept() final {} 420 #ifndef NDEBUG 421 void dump(raw_ostream &OS) const final { OS << "CreateAndInsertInst"; } 422 LLVM_DUMP_METHOD void dump() const final; 423 #endif 424 }; 425 426 class LLVM_ABI ShuffleVectorSetMask final : public IRChangeBase { 427 ShuffleVectorInst *SVI; 428 SmallVector<int, 8> PrevMask; 429 430 public: 431 ShuffleVectorSetMask(ShuffleVectorInst *SVI); 432 void revert(Tracker &Tracker) final; 433 void accept() final {} 434 #ifndef NDEBUG 435 void dump(raw_ostream &OS) const final { OS << "ShuffleVectorSetMask"; } 436 LLVM_DUMP_METHOD void dump() const final; 437 #endif 438 }; 439 440 /// The tracker collects all the change objects and implements the main API for 441 /// saving / reverting / accepting. 442 class Tracker { 443 public: 444 enum class TrackerState { 445 Disabled, ///> Tracking is disabled 446 Record, ///> Tracking changes 447 Reverting, ///> Reverting changes 448 }; 449 450 private: 451 /// The list of changes that are being tracked. 
452 SmallVector<std::unique_ptr<IRChangeBase>> Changes; 453 /// The current state of the tracker. 454 TrackerState State = TrackerState::Disabled; 455 Context &Ctx; 456 457 #ifndef NDEBUG 458 IRSnapshotChecker SnapshotChecker; 459 #endif 460 461 public: 462 #ifndef NDEBUG 463 /// Helps catch bugs where we are creating new change objects while in the 464 /// middle of creating other change objects. 465 bool InMiddleOfCreatingChange = false; 466 #endif // NDEBUG 467 468 explicit Tracker(Context &Ctx) 469 : Ctx(Ctx) 470 #ifndef NDEBUG 471 , 472 SnapshotChecker(Ctx) 473 #endif 474 { 475 } 476 477 LLVM_ABI ~Tracker(); 478 Context &getContext() const { return Ctx; } 479 /// \Returns true if there are no changes tracked. 480 bool empty() const { return Changes.empty(); } 481 /// Record \p Change and take ownership. This is the main function used to 482 /// track Sandbox IR changes. 483 void track(std::unique_ptr<IRChangeBase> &&Change) { 484 assert(State == TrackerState::Record && "The tracker should be tracking!"); 485 #ifndef NDEBUG 486 assert(!InMiddleOfCreatingChange && 487 "We are in the middle of creating another change!"); 488 if (isTracking()) 489 InMiddleOfCreatingChange = true; 490 #endif // NDEBUG 491 Changes.push_back(std::move(Change)); 492 493 #ifndef NDEBUG 494 InMiddleOfCreatingChange = false; 495 #endif 496 } 497 /// A convenience wrapper for `track()` that constructs and tracks the Change 498 /// object if tracking is enabled. \Returns true if tracking is enabled. 499 template <typename ChangeT, typename... ArgsT> 500 bool emplaceIfTracking(ArgsT... Args) { 501 if (!isTracking()) 502 return false; 503 track(std::make_unique<ChangeT>(Args...)); 504 return true; 505 } 506 /// \Returns true if the tracker is recording changes. 507 bool isTracking() const { return State == TrackerState::Record; } 508 /// \Returns the current state of the tracker. 509 TrackerState getState() const { return State; } 510 /// Turns on IR tracking. 
511 LLVM_ABI void save(); 512 /// Stops tracking and accept changes. 513 LLVM_ABI void accept(); 514 /// Stops tracking and reverts to saved state. 515 LLVM_ABI void revert(); 516 517 #ifndef NDEBUG 518 void dump(raw_ostream &OS) const; 519 LLVM_DUMP_METHOD void dump() const; 520 friend raw_ostream &operator<<(raw_ostream &OS, const Tracker &Tracker) { 521 Tracker.dump(OS); 522 return OS; 523 } 524 #endif // NDEBUG 525 }; 526 527 } // namespace llvm::sandboxir 528 529 #endif // LLVM_SANDBOXIR_TRACKER_H 530