1 //==-- X86LoadValueInjectionLoadHardening.cpp - LVI load hardening for x86 --=// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 /// 9 /// Description: This pass finds Load Value Injection (LVI) gadgets consisting 10 /// of a load from memory (i.e., SOURCE), and any operation that may transmit 11 /// the value loaded from memory over a covert channel, or use the value loaded 12 /// from memory to determine a branch/call target (i.e., SINK). After finding 13 /// all such gadgets in a given function, the pass minimally inserts LFENCE 14 /// instructions in such a manner that the following property is satisfied: for 15 /// all SOURCE+SINK pairs, all paths in the CFG from SOURCE to SINK contain at 16 /// least one LFENCE instruction. The algorithm that implements this minimal 17 /// insertion is influenced by an academic paper that minimally inserts memory 18 /// fences for high-performance concurrent programs: 19 /// http://www.cs.ucr.edu/~lesani/companion/oopsla15/OOPSLA15.pdf 20 /// The algorithm implemented in this pass is as follows: 21 /// 1. Build a condensed CFG (i.e., a GadgetGraph) consisting only of the 22 /// following components: 23 /// - SOURCE instructions (also includes function arguments) 24 /// - SINK instructions 25 /// - Basic block entry points 26 /// - Basic block terminators 27 /// - LFENCE instructions 28 /// 2. Analyze the GadgetGraph to determine which SOURCE+SINK pairs (i.e., 29 /// gadgets) are already mitigated by existing LFENCEs. If all gadgets have been 30 /// mitigated, go to step 6. 31 /// 3. Use a heuristic or plugin to approximate minimal LFENCE insertion. 32 /// 4. Insert one LFENCE along each CFG edge that was cut in step 3. 33 /// 5. Go to step 2. 34 /// 6. If any LFENCEs were inserted, return `true` from runOnMachineFunction() 35 /// to tell LLVM that the function was modified. 36 /// 37 //===----------------------------------------------------------------------===// 38 39 #include "ImmutableGraph.h" 40 #include "X86.h" 41 #include "X86Subtarget.h" 42 #include "X86TargetMachine.h" 43 #include "llvm/ADT/DenseMap.h" 44 #include "llvm/ADT/STLExtras.h" 45 #include "llvm/ADT/SmallSet.h" 46 #include "llvm/ADT/Statistic.h" 47 #include "llvm/ADT/StringRef.h" 48 #include "llvm/CodeGen/MachineBasicBlock.h" 49 #include "llvm/CodeGen/MachineDominanceFrontier.h" 50 #include "llvm/CodeGen/MachineDominators.h" 51 #include "llvm/CodeGen/MachineFunction.h" 52 #include "llvm/CodeGen/MachineFunctionPass.h" 53 #include "llvm/CodeGen/MachineInstr.h" 54 #include "llvm/CodeGen/MachineInstrBuilder.h" 55 #include "llvm/CodeGen/MachineLoopInfo.h" 56 #include "llvm/CodeGen/MachineRegisterInfo.h" 57 #include "llvm/CodeGen/RDFGraph.h" 58 #include "llvm/CodeGen/RDFLiveness.h" 59 #include "llvm/InitializePasses.h" 60 #include "llvm/Support/CommandLine.h" 61 #include "llvm/Support/DOTGraphTraits.h" 62 #include "llvm/Support/Debug.h" 63 #include "llvm/Support/DynamicLibrary.h" 64 #include "llvm/Support/GraphWriter.h" 65 #include "llvm/Support/raw_ostream.h" 66 67 using namespace llvm; 68 69 #define PASS_KEY "x86-lvi-load" 70 #define DEBUG_TYPE PASS_KEY 71 72 STATISTIC(NumFences, "Number of LFENCEs inserted for LVI mitigation"); 73 STATISTIC(NumFunctionsConsidered, "Number of functions analyzed"); 74 STATISTIC(NumFunctionsMitigated, "Number of functions for which mitigations " 75 "were deployed"); 76 STATISTIC(NumGadgets, "Number of LVI gadgets detected during analysis"); 77 78 static cl::opt<std::string> OptimizePluginPath( 79 PASS_KEY "-opt-plugin", 80 cl::desc("Specify a plugin to optimize LFENCE insertion"), cl::Hidden); 81 82 static cl::opt<bool> NoConditionalBranches( 83 PASS_KEY "-no-cbranch", 84 cl::desc("Don't treat conditional branches as disclosure gadgets. This " 85 "may improve performance, at the cost of security."), 86 cl::init(false), cl::Hidden); 87 88 static cl::opt<bool> EmitDot( 89 PASS_KEY "-dot", 90 cl::desc( 91 "For each function, emit a dot graph depicting potential LVI gadgets"), 92 cl::init(false), cl::Hidden); 93 94 static cl::opt<bool> EmitDotOnly( 95 PASS_KEY "-dot-only", 96 cl::desc("For each function, emit a dot graph depicting potential LVI " 97 "gadgets, and do not insert any fences"), 98 cl::init(false), cl::Hidden); 99 100 static cl::opt<bool> EmitDotVerify( 101 PASS_KEY "-dot-verify", 102 cl::desc("For each function, emit a dot graph to stdout depicting " 103 "potential LVI gadgets, used for testing purposes only"), 104 cl::init(false), cl::Hidden); 105 106 static llvm::sys::DynamicLibrary OptimizeDL; 107 typedef int (*OptimizeCutT)(unsigned int *Nodes, unsigned int NodesSize, 108 unsigned int *Edges, int *EdgeValues, 109 int *CutEdges /* out */, unsigned int EdgesSize); 110 static OptimizeCutT OptimizeCut = nullptr; 111 112 namespace { 113 114 struct MachineGadgetGraph : ImmutableGraph<MachineInstr *, int> { 115 static constexpr int GadgetEdgeSentinel = -1; 116 static constexpr MachineInstr *const ArgNodeSentinel = nullptr; 117 118 using GraphT = ImmutableGraph<MachineInstr *, int>; 119 using Node = typename GraphT::Node; 120 using Edge = typename GraphT::Edge; 121 using size_type = typename GraphT::size_type; 122 MachineGadgetGraph(std::unique_ptr<Node[]> Nodes, 123 std::unique_ptr<Edge[]> Edges, size_type NodesSize, 124 size_type EdgesSize, int NumFences = 0, int NumGadgets = 0) 125 : GraphT(std::move(Nodes), std::move(Edges), NodesSize, EdgesSize), 126 NumFences(NumFences), NumGadgets(NumGadgets) {} 127 static inline bool isCFGEdge(const Edge &E) { 128 return E.getValue() != GadgetEdgeSentinel; 129 } 130 static inline bool isGadgetEdge(const Edge &E) { 131 return E.getValue() == GadgetEdgeSentinel; 132 } 133 int NumFences; 134 int NumGadgets; 135 }; 136 137 class X86LoadValueInjectionLoadHardeningPass : public MachineFunctionPass { 138 public: 139 X86LoadValueInjectionLoadHardeningPass() : MachineFunctionPass(ID) {} 140 141 StringRef getPassName() const override { 142 return "X86 Load Value Injection (LVI) Load Hardening"; 143 } 144 void getAnalysisUsage(AnalysisUsage &AU) const override; 145 bool runOnMachineFunction(MachineFunction &MF) override; 146 147 static char ID; 148 149 private: 150 using GraphBuilder = ImmutableGraphBuilder<MachineGadgetGraph>; 151 using Edge = MachineGadgetGraph::Edge; 152 using Node = MachineGadgetGraph::Node; 153 using EdgeSet = MachineGadgetGraph::EdgeSet; 154 using NodeSet = MachineGadgetGraph::NodeSet; 155 156 const X86Subtarget *STI = nullptr; 157 const TargetInstrInfo *TII = nullptr; 158 const TargetRegisterInfo *TRI = nullptr; 159 160 std::unique_ptr<MachineGadgetGraph> 161 getGadgetGraph(MachineFunction &MF, const MachineLoopInfo &MLI, 162 const MachineDominatorTree &MDT, 163 const MachineDominanceFrontier &MDF) const; 164 int hardenLoadsWithPlugin(MachineFunction &MF, 165 std::unique_ptr<MachineGadgetGraph> Graph) const; 166 int hardenLoadsWithHeuristic(MachineFunction &MF, 167 std::unique_ptr<MachineGadgetGraph> Graph) const; 168 int elimMitigatedEdgesAndNodes(MachineGadgetGraph &G, 169 EdgeSet &ElimEdges /* in, out */, 170 NodeSet &ElimNodes /* in, out */) const; 171 std::unique_ptr<MachineGadgetGraph> 172 trimMitigatedEdges(std::unique_ptr<MachineGadgetGraph> Graph) const; 173 int insertFences(MachineFunction &MF, MachineGadgetGraph &G, 174 EdgeSet &CutEdges /* in, out */) const; 175 bool instrUsesRegToAccessMemory(const MachineInstr &I, unsigned Reg) const; 176 bool instrUsesRegToBranch(const MachineInstr &I, unsigned Reg) const; 177 inline bool isFence(const MachineInstr *MI) const { 178 return MI && (MI->getOpcode() == X86::LFENCE || 179 (STI->useLVIControlFlowIntegrity() && MI->isCall())); 180 } 181 }; 182 183 } // end anonymous namespace 184 185 namespace llvm { 186 187 template <> 188 struct GraphTraits<MachineGadgetGraph *> 189 : GraphTraits<ImmutableGraph<MachineInstr *, int> *> {}; 190 191 template <> 192 struct DOTGraphTraits<MachineGadgetGraph *> : DefaultDOTGraphTraits { 193 using GraphType = MachineGadgetGraph; 194 using Traits = llvm::GraphTraits<GraphType *>; 195 using NodeRef = typename Traits::NodeRef; 196 using EdgeRef = typename Traits::EdgeRef; 197 using ChildIteratorType = typename Traits::ChildIteratorType; 198 using ChildEdgeIteratorType = typename Traits::ChildEdgeIteratorType; 199 200 DOTGraphTraits(bool IsSimple = false) : DefaultDOTGraphTraits(IsSimple) {} 201 202 std::string getNodeLabel(NodeRef Node, GraphType *) { 203 if (Node->getValue() == MachineGadgetGraph::ArgNodeSentinel) 204 return "ARGS"; 205 206 std::string Str; 207 raw_string_ostream OS(Str); 208 OS << *Node->getValue(); 209 return OS.str(); 210 } 211 212 static std::string getNodeAttributes(NodeRef Node, GraphType *) { 213 MachineInstr *MI = Node->getValue(); 214 if (MI == MachineGadgetGraph::ArgNodeSentinel) 215 return "color = blue"; 216 if (MI->getOpcode() == X86::LFENCE) 217 return "color = green"; 218 return ""; 219 } 220 221 static std::string getEdgeAttributes(NodeRef, ChildIteratorType E, 222 GraphType *) { 223 int EdgeVal = (*E.getCurrent()).getValue(); 224 return EdgeVal >= 0 ? "label = " + std::to_string(EdgeVal) 225 : "color = red, style = \"dashed\""; 226 } 227 }; 228 229 } // end namespace llvm 230 231 constexpr MachineInstr *MachineGadgetGraph::ArgNodeSentinel; 232 constexpr int MachineGadgetGraph::GadgetEdgeSentinel; 233 234 char X86LoadValueInjectionLoadHardeningPass::ID = 0; 235 236 void X86LoadValueInjectionLoadHardeningPass::getAnalysisUsage( 237 AnalysisUsage &AU) const { 238 MachineFunctionPass::getAnalysisUsage(AU); 239 AU.addRequired<MachineLoopInfoWrapperPass>(); 240 AU.addRequired<MachineDominatorTreeWrapperPass>(); 241 AU.addRequired<MachineDominanceFrontier>(); 242 AU.setPreservesCFG(); 243 } 244 245 static void writeGadgetGraph(raw_ostream &OS, MachineFunction &MF, 246 MachineGadgetGraph *G) { 247 WriteGraph(OS, G, /*ShortNames*/ false, 248 "Speculative gadgets for \"" + MF.getName() + "\" function"); 249 } 250 251 bool X86LoadValueInjectionLoadHardeningPass::runOnMachineFunction( 252 MachineFunction &MF) { 253 LLVM_DEBUG(dbgs() << "***** " << getPassName() << " : " << MF.getName() 254 << " *****\n"); 255 STI = &MF.getSubtarget<X86Subtarget>(); 256 if (!STI->useLVILoadHardening()) 257 return false; 258 259 // FIXME: support 32-bit 260 if (!STI->is64Bit()) 261 report_fatal_error("LVI load hardening is only supported on 64-bit", false); 262 263 // Don't skip functions with the "optnone" attr but participate in opt-bisect. 264 const Function &F = MF.getFunction(); 265 if (!F.hasOptNone() && skipFunction(F)) 266 return false; 267 268 ++NumFunctionsConsidered; 269 TII = STI->getInstrInfo(); 270 TRI = STI->getRegisterInfo(); 271 LLVM_DEBUG(dbgs() << "Building gadget graph...\n"); 272 const auto &MLI = getAnalysis<MachineLoopInfoWrapperPass>().getLI(); 273 const auto &MDT = getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); 274 const auto &MDF = getAnalysis<MachineDominanceFrontier>(); 275 std::unique_ptr<MachineGadgetGraph> Graph = getGadgetGraph(MF, MLI, MDT, MDF); 276 LLVM_DEBUG(dbgs() << "Building gadget graph... Done\n"); 277 if (Graph == nullptr) 278 return false; // didn't find any gadgets 279 280 if (EmitDotVerify) { 281 writeGadgetGraph(outs(), MF, Graph.get()); 282 return false; 283 } 284 285 if (EmitDot || EmitDotOnly) { 286 LLVM_DEBUG(dbgs() << "Emitting gadget graph...\n"); 287 std::error_code FileError; 288 std::string FileName = "lvi."; 289 FileName += MF.getName(); 290 FileName += ".dot"; 291 raw_fd_ostream FileOut(FileName, FileError); 292 if (FileError) 293 errs() << FileError.message(); 294 writeGadgetGraph(FileOut, MF, Graph.get()); 295 FileOut.close(); 296 LLVM_DEBUG(dbgs() << "Emitting gadget graph... Done\n"); 297 if (EmitDotOnly) 298 return false; 299 } 300 301 int FencesInserted; 302 if (!OptimizePluginPath.empty()) { 303 if (!OptimizeDL.isValid()) { 304 std::string ErrorMsg; 305 OptimizeDL = llvm::sys::DynamicLibrary::getPermanentLibrary( 306 OptimizePluginPath.c_str(), &ErrorMsg); 307 if (!ErrorMsg.empty()) 308 report_fatal_error(Twine("Failed to load opt plugin: \"") + ErrorMsg + 309 "\""); 310 OptimizeCut = (OptimizeCutT)OptimizeDL.getAddressOfSymbol("optimize_cut"); 311 if (!OptimizeCut) 312 report_fatal_error("Invalid optimization plugin"); 313 } 314 FencesInserted = hardenLoadsWithPlugin(MF, std::move(Graph)); 315 } else { // Use the default greedy heuristic 316 FencesInserted = hardenLoadsWithHeuristic(MF, std::move(Graph)); 317 } 318 319 if (FencesInserted > 0) 320 ++NumFunctionsMitigated; 321 NumFences += FencesInserted; 322 return (FencesInserted > 0); 323 } 324 325 std::unique_ptr<MachineGadgetGraph> 326 X86LoadValueInjectionLoadHardeningPass::getGadgetGraph( 327 MachineFunction &MF, const MachineLoopInfo &MLI, 328 const MachineDominatorTree &MDT, 329 const MachineDominanceFrontier &MDF) const { 330 using namespace rdf; 331 332 // Build the Register Dataflow Graph using the RDF framework 333 DataFlowGraph DFG{MF, *TII, *TRI, MDT, MDF}; 334 DFG.build(); 335 Liveness L{MF.getRegInfo(), DFG}; 336 L.computePhiInfo(); 337 338 GraphBuilder Builder; 339 using GraphIter = typename GraphBuilder::BuilderNodeRef; 340 DenseMap<MachineInstr *, GraphIter> NodeMap; 341 int FenceCount = 0, GadgetCount = 0; 342 auto MaybeAddNode = [&NodeMap, &Builder](MachineInstr *MI) { 343 auto Ref = NodeMap.find(MI); 344 if (Ref == NodeMap.end()) { 345 auto I = Builder.addVertex(MI); 346 NodeMap[MI] = I; 347 return std::pair<GraphIter, bool>{I, true}; 348 } 349 return std::pair<GraphIter, bool>{Ref->getSecond(), false}; 350 }; 351 352 // The `Transmitters` map memoizes transmitters found for each def. If a def 353 // has not yet been analyzed, then it will not appear in the map. If a def 354 // has been analyzed and was determined not to have any transmitters, then 355 // its list of transmitters will be empty. 356 DenseMap<NodeId, std::vector<NodeId>> Transmitters; 357 358 // Analyze all machine instructions to find gadgets and LFENCEs, adding 359 // each interesting value to `Nodes` 360 auto AnalyzeDef = [&](NodeAddr<DefNode *> SourceDef) { 361 SmallSet<NodeId, 8> UsesVisited, DefsVisited; 362 std::function<void(NodeAddr<DefNode *>)> AnalyzeDefUseChain = 363 [&](NodeAddr<DefNode *> Def) { 364 if (Transmitters.contains(Def.Id)) 365 return; // Already analyzed `Def` 366 367 // Use RDF to find all the uses of `Def` 368 rdf::NodeSet Uses; 369 RegisterRef DefReg = Def.Addr->getRegRef(DFG); 370 for (auto UseID : L.getAllReachedUses(DefReg, Def)) { 371 auto Use = DFG.addr<UseNode *>(UseID); 372 if (Use.Addr->getFlags() & NodeAttrs::PhiRef) { // phi node 373 NodeAddr<PhiNode *> Phi = Use.Addr->getOwner(DFG); 374 for (const auto& I : L.getRealUses(Phi.Id)) { 375 if (DFG.getPRI().alias(RegisterRef(I.first), DefReg)) { 376 for (const auto &UA : I.second) 377 Uses.emplace(UA.first); 378 } 379 } 380 } else { // not a phi node 381 Uses.emplace(UseID); 382 } 383 } 384 385 // For each use of `Def`, we want to know whether: 386 // (1) The use can leak the Def'ed value, 387 // (2) The use can further propagate the Def'ed value to more defs 388 for (auto UseID : Uses) { 389 if (!UsesVisited.insert(UseID).second) 390 continue; // Already visited this use of `Def` 391 392 auto Use = DFG.addr<UseNode *>(UseID); 393 assert(!(Use.Addr->getFlags() & NodeAttrs::PhiRef)); 394 MachineOperand &UseMO = Use.Addr->getOp(); 395 MachineInstr &UseMI = *UseMO.getParent(); 396 assert(UseMO.isReg()); 397 398 // We naively assume that an instruction propagates any loaded 399 // uses to all defs unless the instruction is a call, in which 400 // case all arguments will be treated as gadget sources during 401 // analysis of the callee function. 402 if (UseMI.isCall()) 403 continue; 404 405 // Check whether this use can transmit (leak) its value. 406 if (instrUsesRegToAccessMemory(UseMI, UseMO.getReg()) || 407 (!NoConditionalBranches && 408 instrUsesRegToBranch(UseMI, UseMO.getReg()))) { 409 Transmitters[Def.Id].push_back(Use.Addr->getOwner(DFG).Id); 410 if (UseMI.mayLoad()) 411 continue; // Found a transmitting load -- no need to continue 412 // traversing its defs (i.e., this load will become 413 // a new gadget source anyways). 414 } 415 416 // Check whether the use propagates to more defs. 417 NodeAddr<InstrNode *> Owner{Use.Addr->getOwner(DFG)}; 418 rdf::NodeList AnalyzedChildDefs; 419 for (const auto &ChildDef : 420 Owner.Addr->members_if(DataFlowGraph::IsDef, DFG)) { 421 if (!DefsVisited.insert(ChildDef.Id).second) 422 continue; // Already visited this def 423 if (Def.Addr->getAttrs() & NodeAttrs::Dead) 424 continue; 425 if (Def.Id == ChildDef.Id) 426 continue; // `Def` uses itself (e.g., increment loop counter) 427 428 AnalyzeDefUseChain(ChildDef); 429 430 // `Def` inherits all of its child defs' transmitters. 431 for (auto TransmitterId : Transmitters[ChildDef.Id]) 432 Transmitters[Def.Id].push_back(TransmitterId); 433 } 434 } 435 436 // Note that this statement adds `Def.Id` to the map if no 437 // transmitters were found for `Def`. 438 auto &DefTransmitters = Transmitters[Def.Id]; 439 440 // Remove duplicate transmitters 441 llvm::sort(DefTransmitters); 442 DefTransmitters.erase(llvm::unique(DefTransmitters), 443 DefTransmitters.end()); 444 }; 445 446 // Find all of the transmitters 447 AnalyzeDefUseChain(SourceDef); 448 auto &SourceDefTransmitters = Transmitters[SourceDef.Id]; 449 if (SourceDefTransmitters.empty()) 450 return; // No transmitters for `SourceDef` 451 452 MachineInstr *Source = SourceDef.Addr->getFlags() & NodeAttrs::PhiRef 453 ? MachineGadgetGraph::ArgNodeSentinel 454 : SourceDef.Addr->getOp().getParent(); 455 auto GadgetSource = MaybeAddNode(Source); 456 // Each transmitter is a sink for `SourceDef`. 457 for (auto TransmitterId : SourceDefTransmitters) { 458 MachineInstr *Sink = DFG.addr<StmtNode *>(TransmitterId).Addr->getCode(); 459 auto GadgetSink = MaybeAddNode(Sink); 460 // Add the gadget edge to the graph. 461 Builder.addEdge(MachineGadgetGraph::GadgetEdgeSentinel, 462 GadgetSource.first, GadgetSink.first); 463 ++GadgetCount; 464 } 465 }; 466 467 LLVM_DEBUG(dbgs() << "Analyzing def-use chains to find gadgets\n"); 468 // Analyze function arguments 469 NodeAddr<BlockNode *> EntryBlock = DFG.getFunc().Addr->getEntryBlock(DFG); 470 for (NodeAddr<PhiNode *> ArgPhi : 471 EntryBlock.Addr->members_if(DataFlowGraph::IsPhi, DFG)) { 472 NodeList Defs = ArgPhi.Addr->members_if(DataFlowGraph::IsDef, DFG); 473 llvm::for_each(Defs, AnalyzeDef); 474 } 475 // Analyze every instruction in MF 476 for (NodeAddr<BlockNode *> BA : DFG.getFunc().Addr->members(DFG)) { 477 for (NodeAddr<StmtNode *> SA : 478 BA.Addr->members_if(DataFlowGraph::IsCode<NodeAttrs::Stmt>, DFG)) { 479 MachineInstr *MI = SA.Addr->getCode(); 480 if (isFence(MI)) { 481 MaybeAddNode(MI); 482 ++FenceCount; 483 } else if (MI->mayLoad()) { 484 NodeList Defs = SA.Addr->members_if(DataFlowGraph::IsDef, DFG); 485 llvm::for_each(Defs, AnalyzeDef); 486 } 487 } 488 } 489 LLVM_DEBUG(dbgs() << "Found " << FenceCount << " fences\n"); 490 LLVM_DEBUG(dbgs() << "Found " << GadgetCount << " gadgets\n"); 491 if (GadgetCount == 0) 492 return nullptr; 493 NumGadgets += GadgetCount; 494 495 // Traverse CFG to build the rest of the graph 496 SmallSet<MachineBasicBlock *, 8> BlocksVisited; 497 std::function<void(MachineBasicBlock *, GraphIter, unsigned)> TraverseCFG = 498 [&](MachineBasicBlock *MBB, GraphIter GI, unsigned ParentDepth) { 499 unsigned LoopDepth = MLI.getLoopDepth(MBB); 500 if (!MBB->empty()) { 501 // Always add the first instruction in each block 502 auto NI = MBB->begin(); 503 auto BeginBB = MaybeAddNode(&*NI); 504 Builder.addEdge(ParentDepth, GI, BeginBB.first); 505 if (!BlocksVisited.insert(MBB).second) 506 return; 507 508 // Add any instructions within the block that are gadget components 509 GI = BeginBB.first; 510 while (++NI != MBB->end()) { 511 auto Ref = NodeMap.find(&*NI); 512 if (Ref != NodeMap.end()) { 513 Builder.addEdge(LoopDepth, GI, Ref->getSecond()); 514 GI = Ref->getSecond(); 515 } 516 } 517 518 // Always add the terminator instruction, if one exists 519 auto T = MBB->getFirstTerminator(); 520 if (T != MBB->end()) { 521 auto EndBB = MaybeAddNode(&*T); 522 if (EndBB.second) 523 Builder.addEdge(LoopDepth, GI, EndBB.first); 524 GI = EndBB.first; 525 } 526 } 527 for (MachineBasicBlock *Succ : MBB->successors()) 528 TraverseCFG(Succ, GI, LoopDepth); 529 }; 530 // ArgNodeSentinel is a pseudo-instruction that represents MF args in the 531 // GadgetGraph 532 GraphIter ArgNode = MaybeAddNode(MachineGadgetGraph::ArgNodeSentinel).first; 533 TraverseCFG(&MF.front(), ArgNode, 0); 534 std::unique_ptr<MachineGadgetGraph> G{Builder.get(FenceCount, GadgetCount)}; 535 LLVM_DEBUG(dbgs() << "Found " << G->nodes_size() << " nodes\n"); 536 return G; 537 } 538 539 // Returns the number of remaining gadget edges that could not be eliminated 540 int X86LoadValueInjectionLoadHardeningPass::elimMitigatedEdgesAndNodes( 541 MachineGadgetGraph &G, EdgeSet &ElimEdges /* in, out */, 542 NodeSet &ElimNodes /* in, out */) const { 543 if (G.NumFences > 0) { 544 // Eliminate fences and CFG edges that ingress and egress the fence, as 545 // they are trivially mitigated. 546 for (const Edge &E : G.edges()) { 547 const Node *Dest = E.getDest(); 548 if (isFence(Dest->getValue())) { 549 ElimNodes.insert(*Dest); 550 ElimEdges.insert(E); 551 for (const Edge &DE : Dest->edges()) 552 ElimEdges.insert(DE); 553 } 554 } 555 } 556 557 // Find and eliminate gadget edges that have been mitigated. 558 int RemainingGadgets = 0; 559 NodeSet ReachableNodes{G}; 560 for (const Node &RootN : G.nodes()) { 561 if (llvm::none_of(RootN.edges(), MachineGadgetGraph::isGadgetEdge)) 562 continue; // skip this node if it isn't a gadget source 563 564 // Find all of the nodes that are CFG-reachable from RootN using DFS 565 ReachableNodes.clear(); 566 std::function<void(const Node *, bool)> FindReachableNodes = 567 [&](const Node *N, bool FirstNode) { 568 if (!FirstNode) 569 ReachableNodes.insert(*N); 570 for (const Edge &E : N->edges()) { 571 const Node *Dest = E.getDest(); 572 if (MachineGadgetGraph::isCFGEdge(E) && !ElimEdges.contains(E) && 573 !ReachableNodes.contains(*Dest)) 574 FindReachableNodes(Dest, false); 575 } 576 }; 577 FindReachableNodes(&RootN, true); 578 579 // Any gadget whose sink is unreachable has been mitigated 580 for (const Edge &E : RootN.edges()) { 581 if (MachineGadgetGraph::isGadgetEdge(E)) { 582 if (ReachableNodes.contains(*E.getDest())) { 583 // This gadget's sink is reachable 584 ++RemainingGadgets; 585 } else { // This gadget's sink is unreachable, and therefore mitigated 586 ElimEdges.insert(E); 587 } 588 } 589 } 590 } 591 return RemainingGadgets; 592 } 593 594 std::unique_ptr<MachineGadgetGraph> 595 X86LoadValueInjectionLoadHardeningPass::trimMitigatedEdges( 596 std::unique_ptr<MachineGadgetGraph> Graph) const { 597 NodeSet ElimNodes{*Graph}; 598 EdgeSet ElimEdges{*Graph}; 599 int RemainingGadgets = 600 elimMitigatedEdgesAndNodes(*Graph, ElimEdges, ElimNodes); 601 if (ElimEdges.empty() && ElimNodes.empty()) { 602 Graph->NumFences = 0; 603 Graph->NumGadgets = RemainingGadgets; 604 } else { 605 Graph = GraphBuilder::trim(*Graph, ElimNodes, ElimEdges, 0 /* NumFences */, 606 RemainingGadgets); 607 } 608 return Graph; 609 } 610 611 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithPlugin( 612 MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const { 613 int FencesInserted = 0; 614 615 do { 616 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); 617 Graph = trimMitigatedEdges(std::move(Graph)); 618 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); 619 if (Graph->NumGadgets == 0) 620 break; 621 622 LLVM_DEBUG(dbgs() << "Cutting edges...\n"); 623 EdgeSet CutEdges{*Graph}; 624 auto Nodes = std::make_unique<unsigned int[]>(Graph->nodes_size() + 625 1 /* terminator node */); 626 auto Edges = std::make_unique<unsigned int[]>(Graph->edges_size()); 627 auto EdgeCuts = std::make_unique<int[]>(Graph->edges_size()); 628 auto EdgeValues = std::make_unique<int[]>(Graph->edges_size()); 629 for (const Node &N : Graph->nodes()) { 630 Nodes[Graph->getNodeIndex(N)] = Graph->getEdgeIndex(*N.edges_begin()); 631 } 632 Nodes[Graph->nodes_size()] = Graph->edges_size(); // terminator node 633 for (const Edge &E : Graph->edges()) { 634 Edges[Graph->getEdgeIndex(E)] = Graph->getNodeIndex(*E.getDest()); 635 EdgeValues[Graph->getEdgeIndex(E)] = E.getValue(); 636 } 637 OptimizeCut(Nodes.get(), Graph->nodes_size(), Edges.get(), EdgeValues.get(), 638 EdgeCuts.get(), Graph->edges_size()); 639 for (int I = 0; I < Graph->edges_size(); ++I) 640 if (EdgeCuts[I]) 641 CutEdges.set(I); 642 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); 643 LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n"); 644 645 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n"); 646 FencesInserted += insertFences(MF, *Graph, CutEdges); 647 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n"); 648 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n"); 649 650 Graph = GraphBuilder::trim(*Graph, NodeSet{*Graph}, CutEdges); 651 } while (true); 652 653 return FencesInserted; 654 } 655 656 int X86LoadValueInjectionLoadHardeningPass::hardenLoadsWithHeuristic( 657 MachineFunction &MF, std::unique_ptr<MachineGadgetGraph> Graph) const { 658 // If `MF` does not have any fences, then no gadgets would have been 659 // mitigated at this point. 660 if (Graph->NumFences > 0) { 661 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths...\n"); 662 Graph = trimMitigatedEdges(std::move(Graph)); 663 LLVM_DEBUG(dbgs() << "Eliminating mitigated paths... Done\n"); 664 } 665 666 if (Graph->NumGadgets == 0) 667 return 0; 668 669 LLVM_DEBUG(dbgs() << "Cutting edges...\n"); 670 EdgeSet CutEdges{*Graph}; 671 672 // Begin by collecting all ingress CFG edges for each node 673 DenseMap<const Node *, SmallVector<const Edge *, 2>> IngressEdgeMap; 674 for (const Edge &E : Graph->edges()) 675 if (MachineGadgetGraph::isCFGEdge(E)) 676 IngressEdgeMap[E.getDest()].push_back(&E); 677 678 // For each gadget edge, make cuts that guarantee the gadget will be 679 // mitigated. A computationally efficient way to achieve this is to either: 680 // (a) cut all egress CFG edges from the gadget source, or 681 // (b) cut all ingress CFG edges to the gadget sink. 682 // 683 // Moreover, the algorithm tries not to make a cut into a loop by preferring 684 // to make a (b)-type cut if the gadget source resides at a greater loop depth 685 // than the gadget sink, or an (a)-type cut otherwise. 686 for (const Node &N : Graph->nodes()) { 687 for (const Edge &E : N.edges()) { 688 if (!MachineGadgetGraph::isGadgetEdge(E)) 689 continue; 690 691 SmallVector<const Edge *, 2> EgressEdges; 692 SmallVector<const Edge *, 2> &IngressEdges = IngressEdgeMap[E.getDest()]; 693 for (const Edge &EgressEdge : N.edges()) 694 if (MachineGadgetGraph::isCFGEdge(EgressEdge)) 695 EgressEdges.push_back(&EgressEdge); 696 697 int EgressCutCost = 0, IngressCutCost = 0; 698 for (const Edge *EgressEdge : EgressEdges) 699 if (!CutEdges.contains(*EgressEdge)) 700 EgressCutCost += EgressEdge->getValue(); 701 for (const Edge *IngressEdge : IngressEdges) 702 if (!CutEdges.contains(*IngressEdge)) 703 IngressCutCost += IngressEdge->getValue(); 704 705 auto &EdgesToCut = 706 IngressCutCost < EgressCutCost ? IngressEdges : EgressEdges; 707 for (const Edge *E : EdgesToCut) 708 CutEdges.insert(*E); 709 } 710 } 711 LLVM_DEBUG(dbgs() << "Cutting edges... Done\n"); 712 LLVM_DEBUG(dbgs() << "Cut " << CutEdges.count() << " edges\n"); 713 714 LLVM_DEBUG(dbgs() << "Inserting LFENCEs...\n"); 715 int FencesInserted = insertFences(MF, *Graph, CutEdges); 716 LLVM_DEBUG(dbgs() << "Inserting LFENCEs... Done\n"); 717 LLVM_DEBUG(dbgs() << "Inserted " << FencesInserted << " fences\n"); 718 719 return FencesInserted; 720 } 721 722 int X86LoadValueInjectionLoadHardeningPass::insertFences( 723 MachineFunction &MF, MachineGadgetGraph &G, 724 EdgeSet &CutEdges /* in, out */) const { 725 int FencesInserted = 0; 726 for (const Node &N : G.nodes()) { 727 for (const Edge &E : N.edges()) { 728 if (CutEdges.contains(E)) { 729 MachineInstr *MI = N.getValue(), *Prev; 730 MachineBasicBlock *MBB; // Insert an LFENCE in this MBB 731 MachineBasicBlock::iterator InsertionPt; // ...at this point 732 if (MI == MachineGadgetGraph::ArgNodeSentinel) { 733 // insert LFENCE at beginning of entry block 734 MBB = &MF.front(); 735 InsertionPt = MBB->begin(); 736 Prev = nullptr; 737 } else if (MI->isBranch()) { // insert the LFENCE before the branch 738 MBB = MI->getParent(); 739 InsertionPt = MI; 740 Prev = MI->getPrevNode(); 741 // Remove all egress CFG edges from this branch because the inserted 742 // LFENCE prevents gadgets from crossing the branch. 743 for (const Edge &E : N.edges()) { 744 if (MachineGadgetGraph::isCFGEdge(E)) 745 CutEdges.insert(E); 746 } 747 } else { // insert the LFENCE after the instruction 748 MBB = MI->getParent(); 749 InsertionPt = MI->getNextNode() ? MI->getNextNode() : MBB->end(); 750 Prev = InsertionPt == MBB->end() 751 ? (MBB->empty() ? nullptr : &MBB->back()) 752 : InsertionPt->getPrevNode(); 753 } 754 // Ensure this insertion is not redundant (two LFENCEs in sequence). 755 if ((InsertionPt == MBB->end() || !isFence(&*InsertionPt)) && 756 (!Prev || !isFence(Prev))) { 757 BuildMI(*MBB, InsertionPt, DebugLoc(), TII->get(X86::LFENCE)); 758 ++FencesInserted; 759 } 760 } 761 } 762 } 763 return FencesInserted; 764 } 765 766 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToAccessMemory( 767 const MachineInstr &MI, unsigned Reg) const { 768 if (!MI.mayLoadOrStore() || MI.getOpcode() == X86::MFENCE || 769 MI.getOpcode() == X86::SFENCE || MI.getOpcode() == X86::LFENCE) 770 return false; 771 772 const int MemRefBeginIdx = X86::getFirstAddrOperandIdx(MI); 773 if (MemRefBeginIdx < 0) { 774 LLVM_DEBUG(dbgs() << "Warning: unable to obtain memory operand for loading " 775 "instruction:\n"; 776 MI.print(dbgs()); dbgs() << '\n';); 777 return false; 778 } 779 780 const MachineOperand &BaseMO = 781 MI.getOperand(MemRefBeginIdx + X86::AddrBaseReg); 782 const MachineOperand &IndexMO = 783 MI.getOperand(MemRefBeginIdx + X86::AddrIndexReg); 784 return (BaseMO.isReg() && BaseMO.getReg() != X86::NoRegister && 785 TRI->regsOverlap(BaseMO.getReg(), Reg)) || 786 (IndexMO.isReg() && IndexMO.getReg() != X86::NoRegister && 787 TRI->regsOverlap(IndexMO.getReg(), Reg)); 788 } 789 790 bool X86LoadValueInjectionLoadHardeningPass::instrUsesRegToBranch( 791 const MachineInstr &MI, unsigned Reg) const { 792 if (!MI.isConditionalBranch()) 793 return false; 794 for (const MachineOperand &Use : MI.uses()) 795 if (Use.isReg() && Use.getReg() == Reg) 796 return true; 797 return false; 798 } 799 800 INITIALIZE_PASS_BEGIN(X86LoadValueInjectionLoadHardeningPass, PASS_KEY, 801 "X86 LVI load hardening", false, false) 802 INITIALIZE_PASS_DEPENDENCY(MachineLoopInfoWrapperPass) 803 INITIALIZE_PASS_DEPENDENCY(MachineDominatorTreeWrapperPass) 804 INITIALIZE_PASS_DEPENDENCY(MachineDominanceFrontier) 805 INITIALIZE_PASS_END(X86LoadValueInjectionLoadHardeningPass, PASS_KEY, 806 "X86 LVI load hardening", false, false) 807 808 FunctionPass *llvm::createX86LoadValueInjectionLoadHardeningPass() { 809 return new X86LoadValueInjectionLoadHardeningPass(); 810 } 811