1 //===- Tracker.cpp --------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/SandboxIR/Tracker.h" 10 #include "llvm/ADT/STLExtras.h" 11 #include "llvm/IR/BasicBlock.h" 12 #include "llvm/IR/Instruction.h" 13 #include "llvm/IR/StructuralHash.h" 14 #include "llvm/SandboxIR/Instruction.h" 15 16 using namespace llvm::sandboxir; 17 18 #ifndef NDEBUG 19 20 std::string IRSnapshotChecker::dumpIR(const llvm::Function &F) const { 21 std::string Result; 22 raw_string_ostream SS(Result); 23 F.print(SS, /*AssemblyAnnotationWriter=*/nullptr); 24 return Result; 25 } 26 27 IRSnapshotChecker::ContextSnapshot IRSnapshotChecker::takeSnapshot() const { 28 ContextSnapshot Result; 29 for (const auto &Entry : Ctx.LLVMModuleToModuleMap) 30 for (const auto &F : *Entry.first) { 31 FunctionSnapshot Snapshot; 32 Snapshot.Hash = StructuralHash(F, /*DetailedHash=*/true); 33 Snapshot.TextualIR = dumpIR(F); 34 Result[&F] = Snapshot; 35 } 36 return Result; 37 } 38 39 bool IRSnapshotChecker::diff(const ContextSnapshot &Orig, 40 const ContextSnapshot &Curr) const { 41 bool DifferenceFound = false; 42 for (const auto &[F, OrigFS] : Orig) { 43 auto CurrFSIt = Curr.find(F); 44 if (CurrFSIt == Curr.end()) { 45 DifferenceFound = true; 46 dbgs() << "Function " << F->getName() << " not found in current IR.\n"; 47 dbgs() << OrigFS.TextualIR << "\n"; 48 continue; 49 } 50 const FunctionSnapshot &CurrFS = CurrFSIt->second; 51 if (OrigFS.Hash != CurrFS.Hash) { 52 DifferenceFound = true; 53 dbgs() << "Found IR difference in Function " << F->getName() << "\n"; 54 dbgs() << "Original:\n" << OrigFS.TextualIR << "\n"; 55 dbgs() << "Current:\n" << CurrFS.TextualIR << "\n"; 56 } 57 } 58 // Check that Curr doesn't contain any new functions. 59 for (const auto &[F, CurrFS] : Curr) { 60 if (!Orig.contains(F)) { 61 DifferenceFound = true; 62 dbgs() << "Function " << F->getName() 63 << " found in current IR but not in original snapshot.\n"; 64 dbgs() << CurrFS.TextualIR << "\n"; 65 } 66 } 67 return DifferenceFound; 68 } 69 70 void IRSnapshotChecker::save() { OrigContextSnapshot = takeSnapshot(); } 71 72 void IRSnapshotChecker::expectNoDiff() { 73 ContextSnapshot CurrContextSnapshot = takeSnapshot(); 74 if (diff(OrigContextSnapshot, CurrContextSnapshot)) { 75 llvm_unreachable( 76 "Original and current IR differ! Probably a checkpointing bug."); 77 } 78 } 79 80 void UseSet::dump() const { 81 dump(dbgs()); 82 dbgs() << "\n"; 83 } 84 85 void UseSwap::dump() const { 86 dump(dbgs()); 87 dbgs() << "\n"; 88 } 89 #endif // NDEBUG 90 91 PHIRemoveIncoming::PHIRemoveIncoming(PHINode *PHI, unsigned RemovedIdx) 92 : PHI(PHI), RemovedIdx(RemovedIdx) { 93 RemovedV = PHI->getIncomingValue(RemovedIdx); 94 RemovedBB = PHI->getIncomingBlock(RemovedIdx); 95 } 96 97 void PHIRemoveIncoming::revert(Tracker &Tracker) { 98 // Special case: if the PHI is now empty, as we don't need to care about the 99 // order of the incoming values. 100 unsigned NumIncoming = PHI->getNumIncomingValues(); 101 if (NumIncoming == 0) { 102 PHI->addIncoming(RemovedV, RemovedBB); 103 return; 104 } 105 // Shift all incoming values by one starting from the end until `Idx`. 106 // Start by adding a copy of the last incoming values. 107 unsigned LastIdx = NumIncoming - 1; 108 PHI->addIncoming(PHI->getIncomingValue(LastIdx), 109 PHI->getIncomingBlock(LastIdx)); 110 for (unsigned Idx = LastIdx; Idx > RemovedIdx; --Idx) { 111 auto *PrevV = PHI->getIncomingValue(Idx - 1); 112 auto *PrevBB = PHI->getIncomingBlock(Idx - 1); 113 PHI->setIncomingValue(Idx, PrevV); 114 PHI->setIncomingBlock(Idx, PrevBB); 115 } 116 PHI->setIncomingValue(RemovedIdx, RemovedV); 117 PHI->setIncomingBlock(RemovedIdx, RemovedBB); 118 } 119 120 #ifndef NDEBUG 121 void PHIRemoveIncoming::dump() const { 122 dump(dbgs()); 123 dbgs() << "\n"; 124 } 125 #endif // NDEBUG 126 127 PHIAddIncoming::PHIAddIncoming(PHINode *PHI) 128 : PHI(PHI), Idx(PHI->getNumIncomingValues()) {} 129 130 void PHIAddIncoming::revert(Tracker &Tracker) { PHI->removeIncomingValue(Idx); } 131 132 #ifndef NDEBUG 133 void PHIAddIncoming::dump() const { 134 dump(dbgs()); 135 dbgs() << "\n"; 136 } 137 #endif // NDEBUG 138 139 Tracker::~Tracker() { 140 assert(Changes.empty() && "You must accept or revert changes!"); 141 } 142 143 EraseFromParent::EraseFromParent(std::unique_ptr<sandboxir::Value> &&ErasedIPtr) 144 : ErasedIPtr(std::move(ErasedIPtr)) { 145 auto *I = cast<Instruction>(this->ErasedIPtr.get()); 146 auto LLVMInstrs = I->getLLVMInstrs(); 147 // Iterate in reverse program order. 148 for (auto *LLVMI : reverse(LLVMInstrs)) { 149 SmallVector<llvm::Value *> Operands; 150 Operands.reserve(LLVMI->getNumOperands()); 151 for (auto [OpNum, Use] : enumerate(LLVMI->operands())) 152 Operands.push_back(Use.get()); 153 InstrData.push_back({Operands, LLVMI}); 154 } 155 assert(is_sorted(InstrData, 156 [](const auto &D0, const auto &D1) { 157 return D0.LLVMI->comesBefore(D1.LLVMI); 158 }) && 159 "Expected reverse program order!"); 160 auto *BotLLVMI = cast<llvm::Instruction>(I->Val); 161 if (BotLLVMI->getNextNode() != nullptr) 162 NextLLVMIOrBB = BotLLVMI->getNextNode(); 163 else 164 NextLLVMIOrBB = BotLLVMI->getParent(); 165 } 166 167 void EraseFromParent::accept() { 168 for (const auto &IData : InstrData) 169 IData.LLVMI->deleteValue(); 170 } 171 172 void EraseFromParent::revert(Tracker &Tracker) { 173 // Place the bottom-most instruction first. 174 auto [Operands, BotLLVMI] = InstrData[0]; 175 if (auto *NextLLVMI = dyn_cast<llvm::Instruction *>(NextLLVMIOrBB)) { 176 BotLLVMI->insertBefore(NextLLVMI->getIterator()); 177 } else { 178 auto *LLVMBB = cast<llvm::BasicBlock *>(NextLLVMIOrBB); 179 BotLLVMI->insertInto(LLVMBB, LLVMBB->end()); 180 } 181 for (auto [OpNum, Op] : enumerate(Operands)) 182 BotLLVMI->setOperand(OpNum, Op); 183 184 // Go over the rest of the instructions and stack them on top. 185 for (auto [Operands, LLVMI] : drop_begin(InstrData)) { 186 LLVMI->insertBefore(BotLLVMI->getIterator()); 187 for (auto [OpNum, Op] : enumerate(Operands)) 188 LLVMI->setOperand(OpNum, Op); 189 BotLLVMI = LLVMI; 190 } 191 Tracker.getContext().registerValue(std::move(ErasedIPtr)); 192 } 193 194 #ifndef NDEBUG 195 void EraseFromParent::dump() const { 196 dump(dbgs()); 197 dbgs() << "\n"; 198 } 199 #endif // NDEBUG 200 201 RemoveFromParent::RemoveFromParent(Instruction *RemovedI) : RemovedI(RemovedI) { 202 if (auto *NextI = RemovedI->getNextNode()) 203 NextInstrOrBB = NextI; 204 else 205 NextInstrOrBB = RemovedI->getParent(); 206 } 207 208 void RemoveFromParent::revert(Tracker &Tracker) { 209 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { 210 RemovedI->insertBefore(NextI); 211 } else { 212 auto *BB = cast<BasicBlock *>(NextInstrOrBB); 213 RemovedI->insertInto(BB, BB->end()); 214 } 215 } 216 217 #ifndef NDEBUG 218 void RemoveFromParent::dump() const { 219 dump(dbgs()); 220 dbgs() << "\n"; 221 } 222 #endif 223 224 CatchSwitchAddHandler::CatchSwitchAddHandler(CatchSwitchInst *CSI) 225 : CSI(CSI), HandlerIdx(CSI->getNumHandlers()) {} 226 227 void CatchSwitchAddHandler::revert(Tracker &Tracker) { 228 // TODO: This should ideally use sandboxir::CatchSwitchInst::removeHandler() 229 // once it gets implemented. 230 auto *LLVMCSI = cast<llvm::CatchSwitchInst>(CSI->Val); 231 LLVMCSI->removeHandler(LLVMCSI->handler_begin() + HandlerIdx); 232 } 233 234 SwitchRemoveCase::SwitchRemoveCase(SwitchInst *Switch) : Switch(Switch) { 235 for (const auto &C : Switch->cases()) 236 Cases.push_back({C.getCaseValue(), C.getCaseSuccessor()}); 237 } 238 239 void SwitchRemoveCase::revert(Tracker &Tracker) { 240 // SwitchInst::removeCase doesn't provide any guarantees about the order of 241 // cases after removal. In order to preserve the original ordering, we save 242 // all of them and, when reverting, clear them all then insert them in the 243 // desired order. This still relies on the fact that `addCase` will insert 244 // them at the end, but it is documented to invalidate `case_end()` so it's 245 // probably okay. 246 unsigned NumCases = Switch->getNumCases(); 247 for (unsigned I = 0; I < NumCases; ++I) 248 Switch->removeCase(Switch->case_begin()); 249 for (auto &Case : Cases) 250 Switch->addCase(Case.Val, Case.Dest); 251 } 252 253 #ifndef NDEBUG 254 void SwitchRemoveCase::dump() const { 255 dump(dbgs()); 256 dbgs() << "\n"; 257 } 258 #endif // NDEBUG 259 260 void SwitchAddCase::revert(Tracker &Tracker) { 261 auto It = Switch->findCaseValue(Val); 262 Switch->removeCase(It); 263 } 264 265 #ifndef NDEBUG 266 void SwitchAddCase::dump() const { 267 dump(dbgs()); 268 dbgs() << "\n"; 269 } 270 #endif // NDEBUG 271 272 MoveInstr::MoveInstr(Instruction *MovedI) : MovedI(MovedI) { 273 if (auto *NextI = MovedI->getNextNode()) 274 NextInstrOrBB = NextI; 275 else 276 NextInstrOrBB = MovedI->getParent(); 277 } 278 279 void MoveInstr::revert(Tracker &Tracker) { 280 if (auto *NextI = dyn_cast<Instruction *>(NextInstrOrBB)) { 281 MovedI->moveBefore(NextI); 282 } else { 283 auto *BB = cast<BasicBlock *>(NextInstrOrBB); 284 MovedI->moveBefore(*BB, BB->end()); 285 } 286 } 287 288 #ifndef NDEBUG 289 void MoveInstr::dump() const { 290 dump(dbgs()); 291 dbgs() << "\n"; 292 } 293 #endif 294 295 void InsertIntoBB::revert(Tracker &Tracker) { InsertedI->removeFromParent(); } 296 297 InsertIntoBB::InsertIntoBB(Instruction *InsertedI) : InsertedI(InsertedI) {} 298 299 #ifndef NDEBUG 300 void InsertIntoBB::dump() const { 301 dump(dbgs()); 302 dbgs() << "\n"; 303 } 304 #endif 305 306 void CreateAndInsertInst::revert(Tracker &Tracker) { NewI->eraseFromParent(); } 307 308 #ifndef NDEBUG 309 void CreateAndInsertInst::dump() const { 310 dump(dbgs()); 311 dbgs() << "\n"; 312 } 313 #endif 314 315 ShuffleVectorSetMask::ShuffleVectorSetMask(ShuffleVectorInst *SVI) 316 : SVI(SVI), PrevMask(SVI->getShuffleMask()) {} 317 318 void ShuffleVectorSetMask::revert(Tracker &Tracker) { 319 SVI->setShuffleMask(PrevMask); 320 } 321 322 #ifndef NDEBUG 323 void ShuffleVectorSetMask::dump() const { 324 dump(dbgs()); 325 dbgs() << "\n"; 326 } 327 #endif 328 329 CmpSwapOperands::CmpSwapOperands(CmpInst *Cmp) : Cmp(Cmp) {} 330 331 void CmpSwapOperands::revert(Tracker &Tracker) { Cmp->swapOperands(); } 332 #ifndef NDEBUG 333 void CmpSwapOperands::dump() const { 334 dump(dbgs()); 335 dbgs() << "\n"; 336 } 337 #endif 338 339 void Tracker::save() { 340 State = TrackerState::Record; 341 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) 342 SnapshotChecker.save(); 343 #endif 344 } 345 346 void Tracker::revert() { 347 assert(State == TrackerState::Record && "Forgot to save()!"); 348 State = TrackerState::Reverting; 349 for (auto &Change : reverse(Changes)) 350 Change->revert(*this); 351 Changes.clear(); 352 #if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) 353 SnapshotChecker.expectNoDiff(); 354 #endif 355 State = TrackerState::Disabled; 356 } 357 358 void Tracker::accept() { 359 assert(State == TrackerState::Record && "Forgot to save()!"); 360 State = TrackerState::Disabled; 361 for (auto &Change : Changes) 362 Change->accept(); 363 Changes.clear(); 364 } 365 366 #ifndef NDEBUG 367 void Tracker::dump(raw_ostream &OS) const { 368 for (auto [Idx, ChangePtr] : enumerate(Changes)) { 369 OS << Idx << ". "; 370 ChangePtr->dump(OS); 371 OS << "\n"; 372 } 373 } 374 void Tracker::dump() const { 375 dump(dbgs()); 376 dbgs() << "\n"; 377 } 378 #endif // NDEBUG 379