xref: /freebsd/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyLateEHPrepare.cpp (revision fe815331bb40604ba31312acf7e4619674631777)
1 //=== WebAssemblyLateEHPrepare.cpp - WebAssembly Exception Preparation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 ///
9 /// \file
10 /// \brief Does various transformations for exception handling.
11 ///
12 //===----------------------------------------------------------------------===//
13 
14 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
15 #include "WebAssembly.h"
16 #include "WebAssemblySubtarget.h"
17 #include "WebAssemblyUtilities.h"
18 #include "llvm/ADT/SmallSet.h"
19 #include "llvm/CodeGen/MachineInstrBuilder.h"
20 #include "llvm/CodeGen/WasmEHFuncInfo.h"
21 #include "llvm/MC/MCAsmInfo.h"
22 #include "llvm/Support/Debug.h"
23 #include "llvm/Target/TargetMachine.h"
24 using namespace llvm;
25 
26 #define DEBUG_TYPE "wasm-late-eh-prepare"
27 
28 namespace {
29 class WebAssemblyLateEHPrepare final : public MachineFunctionPass {
30   StringRef getPassName() const override {
31     return "WebAssembly Late Prepare Exception";
32   }
33 
34   bool runOnMachineFunction(MachineFunction &MF) override;
35   void recordCatchRetBBs(MachineFunction &MF);
36   bool addCatches(MachineFunction &MF);
37   bool replaceFuncletReturns(MachineFunction &MF);
38   bool removeUnnecessaryUnreachables(MachineFunction &MF);
39   bool addExceptionExtraction(MachineFunction &MF);
40   bool restoreStackPointer(MachineFunction &MF);
41 
42   MachineBasicBlock *getMatchingEHPad(MachineInstr *MI);
43   SmallSet<MachineBasicBlock *, 8> CatchRetBBs;
44 
45 public:
46   static char ID; // Pass identification, replacement for typeid
47   WebAssemblyLateEHPrepare() : MachineFunctionPass(ID) {}
48 };
49 } // end anonymous namespace
50 
// Definition of the static pass ID declared in the class. The pass
// constructor hands this object to MachineFunctionPass.
char WebAssemblyLateEHPrepare::ID = 0;
// Registers the pass under the DEBUG_TYPE name ("wasm-late-eh-prepare") with
// the description string below.
INITIALIZE_PASS(WebAssemblyLateEHPrepare, DEBUG_TYPE,
                "WebAssembly Late Exception Preparation", false, false)
54 
55 FunctionPass *llvm::createWebAssemblyLateEHPrepare() {
56   return new WebAssemblyLateEHPrepare();
57 }
58 
59 // Returns the nearest EH pad that dominates this instruction. This does not use
60 // dominator analysis; it just does BFS on its predecessors until arriving at an
61 // EH pad. This assumes valid EH scopes so the first EH pad it arrives in all
62 // possible search paths should be the same.
63 // Returns nullptr in case it does not find any EH pad in the search, or finds
64 // multiple different EH pads.
65 MachineBasicBlock *
66 WebAssemblyLateEHPrepare::getMatchingEHPad(MachineInstr *MI) {
67   MachineFunction *MF = MI->getParent()->getParent();
68   SmallVector<MachineBasicBlock *, 2> WL;
69   SmallPtrSet<MachineBasicBlock *, 2> Visited;
70   WL.push_back(MI->getParent());
71   MachineBasicBlock *EHPad = nullptr;
72   while (!WL.empty()) {
73     MachineBasicBlock *MBB = WL.pop_back_val();
74     if (Visited.count(MBB))
75       continue;
76     Visited.insert(MBB);
77     if (MBB->isEHPad()) {
78       if (EHPad && EHPad != MBB)
79         return nullptr;
80       EHPad = MBB;
81       continue;
82     }
83     if (MBB == &MF->front())
84       return nullptr;
85     for (auto *Pred : MBB->predecessors())
86       if (!CatchRetBBs.count(Pred)) // We don't go into child scopes
87         WL.push_back(Pred);
88   }
89   return EHPad;
90 }
91 
92 // Erase the specified BBs if the BB does not have any remaining predecessors,
93 // and also all its dead children.
94 template <typename Container>
95 static void eraseDeadBBsAndChildren(const Container &MBBs) {
96   SmallVector<MachineBasicBlock *, 8> WL(MBBs.begin(), MBBs.end());
97   while (!WL.empty()) {
98     MachineBasicBlock *MBB = WL.pop_back_val();
99     if (!MBB->pred_empty())
100       continue;
101     SmallVector<MachineBasicBlock *, 4> Succs(MBB->succ_begin(),
102                                               MBB->succ_end());
103     WL.append(MBB->succ_begin(), MBB->succ_end());
104     for (auto *Succ : Succs)
105       MBB->removeSuccessor(Succ);
106     MBB->eraseFromParent();
107   }
108 }
109 
110 bool WebAssemblyLateEHPrepare::runOnMachineFunction(MachineFunction &MF) {
111   LLVM_DEBUG(dbgs() << "********** Late EH Prepare **********\n"
112                        "********** Function: "
113                     << MF.getName() << '\n');
114 
115   if (MF.getTarget().getMCAsmInfo()->getExceptionHandlingType() !=
116       ExceptionHandling::Wasm)
117     return false;
118 
119   bool Changed = false;
120   if (MF.getFunction().hasPersonalityFn()) {
121     recordCatchRetBBs(MF);
122     Changed |= addCatches(MF);
123     Changed |= replaceFuncletReturns(MF);
124   }
125   Changed |= removeUnnecessaryUnreachables(MF);
126   if (MF.getFunction().hasPersonalityFn()) {
127     Changed |= addExceptionExtraction(MF);
128     Changed |= restoreStackPointer(MF);
129   }
130   return Changed;
131 }
132 
133 // Record which BB ends with 'CATCHRET' instruction, because this will be
134 // replaced with BRs later. This set of 'CATCHRET' BBs is necessary in
135 // 'getMatchingEHPad' function.
136 void WebAssemblyLateEHPrepare::recordCatchRetBBs(MachineFunction &MF) {
137   CatchRetBBs.clear();
138   for (auto &MBB : MF) {
139     auto Pos = MBB.getFirstTerminator();
140     if (Pos == MBB.end())
141       continue;
142     MachineInstr *TI = &*Pos;
143     if (TI->getOpcode() == WebAssembly::CATCHRET)
144       CatchRetBBs.insert(&MBB);
145   }
146 }
147 
148 // Add catch instruction to beginning of catchpads and cleanuppads.
149 bool WebAssemblyLateEHPrepare::addCatches(MachineFunction &MF) {
150   bool Changed = false;
151   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
152   MachineRegisterInfo &MRI = MF.getRegInfo();
153   for (auto &MBB : MF) {
154     if (MBB.isEHPad()) {
155       Changed = true;
156       auto InsertPos = MBB.begin();
157       if (InsertPos->isEHLabel()) // EH pad starts with an EH label
158         ++InsertPos;
159       Register DstReg = MRI.createVirtualRegister(&WebAssembly::EXNREFRegClass);
160       BuildMI(MBB, InsertPos, MBB.begin()->getDebugLoc(),
161               TII.get(WebAssembly::CATCH), DstReg);
162     }
163   }
164   return Changed;
165 }
166 
167 bool WebAssemblyLateEHPrepare::replaceFuncletReturns(MachineFunction &MF) {
168   bool Changed = false;
169   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
170 
171   for (auto &MBB : MF) {
172     auto Pos = MBB.getFirstTerminator();
173     if (Pos == MBB.end())
174       continue;
175     MachineInstr *TI = &*Pos;
176 
177     switch (TI->getOpcode()) {
178     case WebAssembly::CATCHRET: {
179       // Replace a catchret with a branch
180       MachineBasicBlock *TBB = TI->getOperand(0).getMBB();
181       if (!MBB.isLayoutSuccessor(TBB))
182         BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::BR))
183             .addMBB(TBB);
184       TI->eraseFromParent();
185       Changed = true;
186       break;
187     }
188     case WebAssembly::CLEANUPRET:
189     case WebAssembly::RETHROW_IN_CATCH: {
190       // Replace a cleanupret/rethrow_in_catch with a rethrow
191       auto *EHPad = getMatchingEHPad(TI);
192       auto CatchPos = EHPad->begin();
193       if (CatchPos->isEHLabel()) // EH pad starts with an EH label
194         ++CatchPos;
195       MachineInstr *Catch = &*CatchPos;
196       Register ExnReg = Catch->getOperand(0).getReg();
197       BuildMI(MBB, TI, TI->getDebugLoc(), TII.get(WebAssembly::RETHROW))
198           .addReg(ExnReg);
199       TI->eraseFromParent();
200       Changed = true;
201       break;
202     }
203     }
204   }
205   return Changed;
206 }
207 
208 bool WebAssemblyLateEHPrepare::removeUnnecessaryUnreachables(
209     MachineFunction &MF) {
210   bool Changed = false;
211   for (auto &MBB : MF) {
212     for (auto &MI : MBB) {
213       if (MI.getOpcode() != WebAssembly::THROW &&
214           MI.getOpcode() != WebAssembly::RETHROW)
215         continue;
216       Changed = true;
217 
218       // The instruction after the throw should be an unreachable or a branch to
219       // another BB that should eventually lead to an unreachable. Delete it
220       // because throw itself is a terminator, and also delete successors if
221       // any.
222       MBB.erase(std::next(MI.getIterator()), MBB.end());
223       SmallVector<MachineBasicBlock *, 8> Succs(MBB.succ_begin(),
224                                                 MBB.succ_end());
225       for (auto *Succ : Succs)
226         if (!Succ->isEHPad())
227           MBB.removeSuccessor(Succ);
228       eraseDeadBBsAndChildren(Succs);
229     }
230   }
231 
232   return Changed;
233 }
234 
235 // Wasm uses 'br_on_exn' instruction to check the tag of an exception. It takes
236 // exnref type object returned by 'catch', and branches to the destination if it
237 // matches a given tag. We currently use __cpp_exception symbol to represent the
238 // tag for all C++ exceptions.
239 //
240 // block $l (result i32)
241 //   ...
242 //   ;; exnref $e is on the stack at this point
243 //   br_on_exn $l $e ;; branch to $l with $e's arguments
244 //   ...
245 // end
246 // ;; Here we expect the extracted values are on top of the wasm value stack
247 // ... Handle exception using values ...
248 //
249 // br_on_exn takes an exnref object and branches if it matches the given tag.
250 // There can be multiple br_on_exn instructions if we want to match for another
251 // tag, but for now we only test for __cpp_exception tag, and if it does not
252 // match, i.e., it is a foreign exception, we rethrow it.
253 //
254 // In the destination BB that's the target of br_on_exn, extracted exception
255 // values (in C++'s case a single i32, which represents an exception pointer)
256 // are placed on top of the wasm stack. Because we can't model wasm stack in
257 // LLVM instruction, we use 'extract_exception' pseudo instruction to retrieve
258 // it. The pseudo instruction will be deleted later.
259 bool WebAssemblyLateEHPrepare::addExceptionExtraction(MachineFunction &MF) {
260   const auto &TII = *MF.getSubtarget<WebAssemblySubtarget>().getInstrInfo();
261   MachineRegisterInfo &MRI = MF.getRegInfo();
262   auto *EHInfo = MF.getWasmEHFuncInfo();
263   SmallVector<MachineInstr *, 16> ExtractInstrs;
264   SmallVector<MachineInstr *, 8> ToDelete;
265   for (auto &MBB : MF) {
266     for (auto &MI : MBB) {
267       if (MI.getOpcode() == WebAssembly::EXTRACT_EXCEPTION_I32) {
268         if (MI.getOperand(0).isDead())
269           ToDelete.push_back(&MI);
270         else
271           ExtractInstrs.push_back(&MI);
272       }
273     }
274   }
275   bool Changed = !ToDelete.empty() || !ExtractInstrs.empty();
276   for (auto *MI : ToDelete)
277     MI->eraseFromParent();
278   if (ExtractInstrs.empty())
279     return Changed;
280 
281   // Find terminate pads.
282   SmallSet<MachineBasicBlock *, 8> TerminatePads;
283   for (auto &MBB : MF) {
284     for (auto &MI : MBB) {
285       if (MI.isCall()) {
286         const MachineOperand &CalleeOp = MI.getOperand(0);
287         if (CalleeOp.isGlobal() && CalleeOp.getGlobal()->getName() ==
288                                        WebAssembly::ClangCallTerminateFn)
289           TerminatePads.insert(getMatchingEHPad(&MI));
290       }
291     }
292   }
293 
294   for (auto *Extract : ExtractInstrs) {
295     MachineBasicBlock *EHPad = getMatchingEHPad(Extract);
296     assert(EHPad && "No matching EH pad for extract_exception");
297     auto CatchPos = EHPad->begin();
298     if (CatchPos->isEHLabel()) // EH pad starts with an EH label
299       ++CatchPos;
300     MachineInstr *Catch = &*CatchPos;
301 
302     if (Catch->getNextNode() != Extract)
303       EHPad->insert(Catch->getNextNode(), Extract->removeFromParent());
304 
305     // - Before:
306     // ehpad:
307     //   %exnref:exnref = catch
308     //   %exn:i32 = extract_exception
309     //   ... use exn ...
310     //
311     // - After:
312     // ehpad:
313     //   %exnref:exnref = catch
314     //   br_on_exn %thenbb, $__cpp_exception, %exnref
315     //   br %elsebb
316     // elsebb:
317     //   rethrow
318     // thenbb:
319     //   %exn:i32 = extract_exception
320     //   ... use exn ...
321     Register ExnReg = Catch->getOperand(0).getReg();
322     auto *ThenMBB = MF.CreateMachineBasicBlock();
323     auto *ElseMBB = MF.CreateMachineBasicBlock();
324     MF.insert(std::next(MachineFunction::iterator(EHPad)), ElseMBB);
325     MF.insert(std::next(MachineFunction::iterator(ElseMBB)), ThenMBB);
326     ThenMBB->splice(ThenMBB->end(), EHPad, Extract, EHPad->end());
327     ThenMBB->transferSuccessors(EHPad);
328     EHPad->addSuccessor(ThenMBB);
329     EHPad->addSuccessor(ElseMBB);
330 
331     DebugLoc DL = Extract->getDebugLoc();
332     const char *CPPExnSymbol = MF.createExternalSymbolName("__cpp_exception");
333     BuildMI(EHPad, DL, TII.get(WebAssembly::BR_ON_EXN))
334         .addMBB(ThenMBB)
335         .addExternalSymbol(CPPExnSymbol)
336         .addReg(ExnReg);
337     BuildMI(EHPad, DL, TII.get(WebAssembly::BR)).addMBB(ElseMBB);
338 
339     // When this is a terminate pad with __clang_call_terminate() call, we don't
340     // rethrow it anymore and call __clang_call_terminate() with a nullptr
341     // argument, which will call std::terminate().
342     //
343     // - Before:
344     // ehpad:
345     //   %exnref:exnref = catch
346     //   %exn:i32 = extract_exception
347     //   call @__clang_call_terminate(%exn)
348     //   unreachable
349     //
350     // - After:
351     // ehpad:
352     //   %exnref:exnref = catch
353     //   br_on_exn %thenbb, $__cpp_exception, %exnref
354     //   br %elsebb
355     // elsebb:
356     //   call @__clang_call_terminate(0)
357     //   unreachable
358     // thenbb:
359     //   %exn:i32 = extract_exception
360     //   call @__clang_call_terminate(%exn)
361     //   unreachable
362     if (TerminatePads.count(EHPad)) {
363       Function *ClangCallTerminateFn =
364           MF.getFunction().getParent()->getFunction(
365               WebAssembly::ClangCallTerminateFn);
366       assert(ClangCallTerminateFn &&
367              "There is no __clang_call_terminate() function");
368       Register Reg = MRI.createVirtualRegister(&WebAssembly::I32RegClass);
369       BuildMI(ElseMBB, DL, TII.get(WebAssembly::CONST_I32), Reg).addImm(0);
370       BuildMI(ElseMBB, DL, TII.get(WebAssembly::CALL))
371           .addGlobalAddress(ClangCallTerminateFn)
372           .addReg(Reg);
373       BuildMI(ElseMBB, DL, TII.get(WebAssembly::UNREACHABLE));
374 
375     } else {
376       BuildMI(ElseMBB, DL, TII.get(WebAssembly::RETHROW)).addReg(ExnReg);
377       if (EHInfo->hasEHPadUnwindDest(EHPad))
378         ElseMBB->addSuccessor(EHInfo->getEHPadUnwindDest(EHPad));
379     }
380   }
381 
382   return true;
383 }
384 
385 // After the stack is unwound due to a thrown exception, the __stack_pointer
386 // global can point to an invalid address. This inserts instructions that
387 // restore __stack_pointer global.
388 bool WebAssemblyLateEHPrepare::restoreStackPointer(MachineFunction &MF) {
389   const auto *FrameLowering = static_cast<const WebAssemblyFrameLowering *>(
390       MF.getSubtarget().getFrameLowering());
391   if (!FrameLowering->needsPrologForEH(MF))
392     return false;
393   bool Changed = false;
394 
395   for (auto &MBB : MF) {
396     if (!MBB.isEHPad())
397       continue;
398     Changed = true;
399 
400     // Insert __stack_pointer restoring instructions at the beginning of each EH
401     // pad, after the catch instruction. Here it is safe to assume that SP32
402     // holds the latest value of __stack_pointer, because the only exception for
403     // this case is when a function uses the red zone, but that only happens
404     // with leaf functions, and we don't restore __stack_pointer in leaf
405     // functions anyway.
406     auto InsertPos = MBB.begin();
407     if (InsertPos->isEHLabel()) // EH pad starts with an EH label
408       ++InsertPos;
409     if (InsertPos->getOpcode() == WebAssembly::CATCH)
410       ++InsertPos;
411     FrameLowering->writeSPToGlobal(FrameLowering->getSPReg(MF), MF, MBB,
412                                    InsertPos, MBB.begin()->getDebugLoc());
413   }
414   return Changed;
415 }
416