xref: /freebsd/contrib/llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 7ef62cebc2f965b0f640263e179276928885e33d)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 #define DEBUG_TYPE "orc"
16 
17 namespace llvm {
18 namespace orc {
19 
20 SimpleRemoteEPC::~SimpleRemoteEPC() {
21 #ifndef NDEBUG
22   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
23   assert(Disconnected && "Destroyed without disconnection");
24 #endif // NDEBUG
25 }
26 
27 Expected<tpctypes::DylibHandle>
28 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
29   return DylibMgr->open(DylibPath, 0);
30 }
31 
32 Expected<std::vector<tpctypes::LookupResult>>
33 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
34   std::vector<tpctypes::LookupResult> Result;
35 
36   for (auto &Element : Request) {
37     if (auto R = DylibMgr->lookup(Element.Handle, Element.Symbols)) {
38       Result.push_back({});
39       Result.back().reserve(R->size());
40       for (auto Addr : *R)
41         Result.back().push_back(Addr);
42     } else
43       return R.takeError();
44   }
45   return std::move(Result);
46 }
47 
48 Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
49                                              ArrayRef<std::string> Args) {
50   int64_t Result = 0;
51   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
52           RunAsMainAddr, Result, ExecutorAddr(MainFnAddr), Args))
53     return std::move(Err);
54   return Result;
55 }
56 
57 Expected<int32_t> SimpleRemoteEPC::runAsVoidFunction(ExecutorAddr VoidFnAddr) {
58   int32_t Result = 0;
59   if (auto Err = callSPSWrapper<rt::SPSRunAsVoidFunctionSignature>(
60           RunAsVoidFunctionAddr, Result, ExecutorAddr(VoidFnAddr)))
61     return std::move(Err);
62   return Result;
63 }
64 
65 Expected<int32_t> SimpleRemoteEPC::runAsIntFunction(ExecutorAddr IntFnAddr,
66                                                     int Arg) {
67   int32_t Result = 0;
68   if (auto Err = callSPSWrapper<rt::SPSRunAsIntFunctionSignature>(
69           RunAsIntFunctionAddr, Result, ExecutorAddr(IntFnAddr), Arg))
70     return std::move(Err);
71   return Result;
72 }
73 
74 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
75                                        IncomingWFRHandler OnComplete,
76                                        ArrayRef<char> ArgBuffer) {
77   uint64_t SeqNo;
78   {
79     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
80     SeqNo = getNextSeqNo();
81     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
82     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
83   }
84 
85   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
86                              WrapperFnAddr, ArgBuffer)) {
87     IncomingWFRHandler H;
88 
89     // We just registered OnComplete, but there may be a race between this
90     // thread returning from sendMessage and handleDisconnect being called from
91     // the transport's listener thread. If handleDisconnect gets there first
92     // then it will have failed 'H' for us. If we get there first (or if
93     // handleDisconnect already ran) then we need to take care of it.
94     {
95       std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
96       auto I = PendingCallWrapperResults.find(SeqNo);
97       if (I != PendingCallWrapperResults.end()) {
98         H = std::move(I->second);
99         PendingCallWrapperResults.erase(I);
100       }
101     }
102 
103     if (H)
104       H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
105 
106     getExecutionSession().reportError(std::move(Err));
107   }
108 }
109 
110 Error SimpleRemoteEPC::disconnect() {
111   T->disconnect();
112   D->shutdown();
113   std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
114   DisconnectCV.wait(Lock, [this] { return Disconnected; });
115   return std::move(DisconnectErr);
116 }
117 
118 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
119 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
120                                ExecutorAddr TagAddr,
121                                SimpleRemoteEPCArgBytesVector ArgBytes) {
122 
123   LLVM_DEBUG({
124     dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
125     switch (OpC) {
126     case SimpleRemoteEPCOpcode::Setup:
127       dbgs() << "Setup";
128       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
129       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
130       break;
131     case SimpleRemoteEPCOpcode::Hangup:
132       dbgs() << "Hangup";
133       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
134       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
135       break;
136     case SimpleRemoteEPCOpcode::Result:
137       dbgs() << "Result";
138       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
139       break;
140     case SimpleRemoteEPCOpcode::CallWrapper:
141       dbgs() << "CallWrapper";
142       break;
143     }
144     dbgs() << ", seqno = " << SeqNo
145            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
146            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
147            << " bytes\n";
148   });
149 
150   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
151   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
152     return make_error<StringError>("Unexpected opcode",
153                                    inconvertibleErrorCode());
154 
155   switch (OpC) {
156   case SimpleRemoteEPCOpcode::Setup:
157     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
158       return std::move(Err);
159     break;
160   case SimpleRemoteEPCOpcode::Hangup:
161     T->disconnect();
162     if (auto Err = handleHangup(std::move(ArgBytes)))
163       return std::move(Err);
164     return EndSession;
165   case SimpleRemoteEPCOpcode::Result:
166     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
167       return std::move(Err);
168     break;
169   case SimpleRemoteEPCOpcode::CallWrapper:
170     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
171     break;
172   }
173   return ContinueSession;
174 }
175 
176 void SimpleRemoteEPC::handleDisconnect(Error Err) {
177   LLVM_DEBUG({
178     dbgs() << "SimpleRemoteEPC::handleDisconnect: "
179            << (Err ? "failure" : "success") << "\n";
180   });
181 
182   PendingCallWrapperResultsMap TmpPending;
183 
184   {
185     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
186     std::swap(TmpPending, PendingCallWrapperResults);
187   }
188 
189   for (auto &KV : TmpPending)
190     KV.second(
191         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
192 
193   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
194   DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
195   Disconnected = true;
196   DisconnectCV.notify_all();
197 }
198 
199 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
200 SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {
201   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
202   if (auto Err = SREPC.getBootstrapSymbols(
203           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
204            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
205            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
206            {SAs.Deallocate,
207             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
208     return std::move(Err);
209 
210   return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);
211 }
212 
213 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
214 SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {
215   return nullptr;
216 }
217 
218 Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
219                                    ExecutorAddr TagAddr,
220                                    ArrayRef<char> ArgBytes) {
221   assert(OpC != SimpleRemoteEPCOpcode::Setup &&
222          "SimpleRemoteEPC sending Setup message? That's the wrong direction.");
223 
224   LLVM_DEBUG({
225     dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
226     switch (OpC) {
227     case SimpleRemoteEPCOpcode::Hangup:
228       dbgs() << "Hangup";
229       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
230       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
231       break;
232     case SimpleRemoteEPCOpcode::Result:
233       dbgs() << "Result";
234       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
235       break;
236     case SimpleRemoteEPCOpcode::CallWrapper:
237       dbgs() << "CallWrapper";
238       break;
239     default:
240       llvm_unreachable("Invalid opcode");
241     }
242     dbgs() << ", seqno = " << SeqNo
243            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
244            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
245            << " bytes\n";
246   });
247   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
248   LLVM_DEBUG({
249     if (Err)
250       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
251   });
252   return Err;
253 }
254 
255 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
256                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
257   if (SeqNo != 0)
258     return make_error<StringError>("Setup packet SeqNo not zero",
259                                    inconvertibleErrorCode());
260 
261   if (TagAddr)
262     return make_error<StringError>("Setup packet TagAddr not zero",
263                                    inconvertibleErrorCode());
264 
265   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
266   auto I = PendingCallWrapperResults.find(0);
267   assert(PendingCallWrapperResults.size() == 1 &&
268          I != PendingCallWrapperResults.end() &&
269          "Setup message handler not connectly set up");
270   auto SetupMsgHandler = std::move(I->second);
271   PendingCallWrapperResults.erase(I);
272 
273   auto WFR =
274       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
275   SetupMsgHandler(std::move(WFR));
276   return Error::success();
277 }
278 
279 Error SimpleRemoteEPC::setup(Setup S) {
280   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
281 
282   std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
283   auto EIF = EIP.get_future();
284 
285   // Prepare a handler for the setup packet.
286   PendingCallWrapperResults[0] =
287     RunInPlace()(
288       [&](shared::WrapperFunctionResult SetupMsgBytes) {
289         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
290           EIP.set_value(
291               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
292           return;
293         }
294         using SPSSerialize =
295             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
296         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
297         SimpleRemoteEPCExecutorInfo EI;
298         if (SPSSerialize::deserialize(IB, EI))
299           EIP.set_value(EI);
300         else
301           EIP.set_value(make_error<StringError>(
302               "Could not deserialize setup message", inconvertibleErrorCode()));
303       });
304 
305   // Start the transport.
306   if (auto Err = T->start())
307     return Err;
308 
309   // Wait for setup packet to arrive.
310   auto EI = EIF.get();
311   if (!EI) {
312     T->disconnect();
313     return EI.takeError();
314   }
315 
316   LLVM_DEBUG({
317     dbgs() << "SimpleRemoteEPC received setup message:\n"
318            << "  Triple: " << EI->TargetTriple << "\n"
319            << "  Page size: " << EI->PageSize << "\n"
320            << "  Bootstrap symbols:\n";
321     for (const auto &KV : EI->BootstrapSymbols)
322       dbgs() << "    " << KV.first() << ": "
323              << formatv("{0:x16}", KV.second.getValue()) << "\n";
324   });
325   TargetTriple = Triple(EI->TargetTriple);
326   PageSize = EI->PageSize;
327   BootstrapSymbols = std::move(EI->BootstrapSymbols);
328 
329   if (auto Err = getBootstrapSymbols(
330           {{JDI.JITDispatchContext, ExecutorSessionObjectName},
331            {JDI.JITDispatchFunction, DispatchFnName},
332            {RunAsMainAddr, rt::RunAsMainWrapperName},
333            {RunAsVoidFunctionAddr, rt::RunAsVoidFunctionWrapperName},
334            {RunAsIntFunctionAddr, rt::RunAsIntFunctionWrapperName}}))
335     return Err;
336 
337   if (auto DM =
338           EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
339     DylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
340   else
341     return DM.takeError();
342 
343   // Set a default CreateMemoryManager if none is specified.
344   if (!S.CreateMemoryManager)
345     S.CreateMemoryManager = createDefaultMemoryManager;
346 
347   if (auto MemMgr = S.CreateMemoryManager(*this)) {
348     OwnedMemMgr = std::move(*MemMgr);
349     this->MemMgr = OwnedMemMgr.get();
350   } else
351     return MemMgr.takeError();
352 
353   // Set a default CreateMemoryAccess if none is specified.
354   if (!S.CreateMemoryAccess)
355     S.CreateMemoryAccess = createDefaultMemoryAccess;
356 
357   if (auto MemAccess = S.CreateMemoryAccess(*this)) {
358     OwnedMemAccess = std::move(*MemAccess);
359     this->MemAccess = OwnedMemAccess.get();
360   } else
361     return MemAccess.takeError();
362 
363   return Error::success();
364 }
365 
366 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
367                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
368   IncomingWFRHandler SendResult;
369 
370   if (TagAddr)
371     return make_error<StringError>("Unexpected TagAddr in result message",
372                                    inconvertibleErrorCode());
373 
374   {
375     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
376     auto I = PendingCallWrapperResults.find(SeqNo);
377     if (I == PendingCallWrapperResults.end())
378       return make_error<StringError>("No call for sequence number " +
379                                          Twine(SeqNo),
380                                      inconvertibleErrorCode());
381     SendResult = std::move(I->second);
382     PendingCallWrapperResults.erase(I);
383     releaseSeqNo(SeqNo);
384   }
385 
386   auto WFR =
387       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
388   SendResult(std::move(WFR));
389   return Error::success();
390 }
391 
392 void SimpleRemoteEPC::handleCallWrapper(
393     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
394     SimpleRemoteEPCArgBytesVector ArgBytes) {
395   assert(ES && "No ExecutionSession attached");
396   D->dispatch(makeGenericNamedTask(
397       [this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
398         ES->runJITDispatchHandler(
399             [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
400               if (auto Err =
401                       sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
402                                   ExecutorAddr(), {WFR.data(), WFR.size()}))
403                 getExecutionSession().reportError(std::move(Err));
404             },
405             TagAddr.getValue(), ArgBytes);
406       },
407       "callWrapper task"));
408 }
409 
410 Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
411   using namespace llvm::orc::shared;
412   auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
413   if (const char *ErrMsg = WFR.getOutOfBandError())
414     return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
415 
416   detail::SPSSerializableError Info;
417   SPSInputBuffer IB(WFR.data(), WFR.size());
418   if (!SPSArgList<SPSError>::deserialize(IB, Info))
419     return make_error<StringError>("Could not deserialize hangup info",
420                                    inconvertibleErrorCode());
421   return fromSPSSerializable(std::move(Info));
422 }
423 
424 } // end namespace orc
425 } // end namespace llvm
426