xref: /freebsd/contrib/llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 770cf0a5f02dc8983a89c6568d741fbc25baa999)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 #define DEBUG_TYPE "orc"
15 
16 namespace llvm {
17 namespace orc {
18 
19 SimpleRemoteEPC::~SimpleRemoteEPC() {
20 #ifndef NDEBUG
21   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
22   assert(Disconnected && "Destroyed without disconnection");
23 #endif // NDEBUG
24 }
25 
26 Expected<tpctypes::DylibHandle>
27 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
28   return EPCDylibMgr->open(DylibPath, 0);
29 }
30 
31 /// Async helper to chain together calls to DylibMgr::lookupAsync to fulfill all
32 /// all the requests.
33 /// FIXME: The dylib manager should support multiple LookupRequests natively.
34 static void
35 lookupSymbolsAsyncHelper(EPCGenericDylibManager &DylibMgr,
36                          ArrayRef<DylibManager::LookupRequest> Request,
37                          std::vector<tpctypes::LookupResult> Result,
38                          DylibManager::SymbolLookupCompleteFn Complete) {
39   if (Request.empty())
40     return Complete(std::move(Result));
41 
42   auto &Element = Request.front();
43   DylibMgr.lookupAsync(Element.Handle, Element.Symbols,
44                        [&DylibMgr, Request, Complete = std::move(Complete),
45                         Result = std::move(Result)](auto R) mutable {
46                          if (!R)
47                            return Complete(R.takeError());
48                          Result.push_back({});
49                          Result.back().reserve(R->size());
50                          llvm::append_range(Result.back(), *R);
51 
52                          lookupSymbolsAsyncHelper(
53                              DylibMgr, Request.drop_front(), std::move(Result),
54                              std::move(Complete));
55                        });
56 }
57 
58 void SimpleRemoteEPC::lookupSymbolsAsync(ArrayRef<LookupRequest> Request,
59                                          SymbolLookupCompleteFn Complete) {
60   lookupSymbolsAsyncHelper(*EPCDylibMgr, Request, {}, std::move(Complete));
61 }
62 
63 Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
64                                              ArrayRef<std::string> Args) {
65   int64_t Result = 0;
66   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
67           RunAsMainAddr, Result, MainFnAddr, Args))
68     return std::move(Err);
69   return Result;
70 }
71 
72 Expected<int32_t> SimpleRemoteEPC::runAsVoidFunction(ExecutorAddr VoidFnAddr) {
73   int32_t Result = 0;
74   if (auto Err = callSPSWrapper<rt::SPSRunAsVoidFunctionSignature>(
75           RunAsVoidFunctionAddr, Result, VoidFnAddr))
76     return std::move(Err);
77   return Result;
78 }
79 
80 Expected<int32_t> SimpleRemoteEPC::runAsIntFunction(ExecutorAddr IntFnAddr,
81                                                     int Arg) {
82   int32_t Result = 0;
83   if (auto Err = callSPSWrapper<rt::SPSRunAsIntFunctionSignature>(
84           RunAsIntFunctionAddr, Result, IntFnAddr, Arg))
85     return std::move(Err);
86   return Result;
87 }
88 
89 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
90                                        IncomingWFRHandler OnComplete,
91                                        ArrayRef<char> ArgBuffer) {
92   uint64_t SeqNo;
93   {
94     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
95     SeqNo = getNextSeqNo();
96     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
97     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
98   }
99 
100   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
101                              WrapperFnAddr, ArgBuffer)) {
102     IncomingWFRHandler H;
103 
104     // We just registered OnComplete, but there may be a race between this
105     // thread returning from sendMessage and handleDisconnect being called from
106     // the transport's listener thread. If handleDisconnect gets there first
107     // then it will have failed 'H' for us. If we get there first (or if
108     // handleDisconnect already ran) then we need to take care of it.
109     {
110       std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
111       auto I = PendingCallWrapperResults.find(SeqNo);
112       if (I != PendingCallWrapperResults.end()) {
113         H = std::move(I->second);
114         PendingCallWrapperResults.erase(I);
115       }
116     }
117 
118     if (H)
119       H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
120 
121     getExecutionSession().reportError(std::move(Err));
122   }
123 }
124 
125 Error SimpleRemoteEPC::disconnect() {
126   T->disconnect();
127   D->shutdown();
128   std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
129   DisconnectCV.wait(Lock, [this] { return Disconnected; });
130   return std::move(DisconnectErr);
131 }
132 
133 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
134 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
135                                ExecutorAddr TagAddr,
136                                SimpleRemoteEPCArgBytesVector ArgBytes) {
137 
138   LLVM_DEBUG({
139     dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
140     switch (OpC) {
141     case SimpleRemoteEPCOpcode::Setup:
142       dbgs() << "Setup";
143       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
144       assert(!TagAddr && "Non-zero TagAddr for Setup?");
145       break;
146     case SimpleRemoteEPCOpcode::Hangup:
147       dbgs() << "Hangup";
148       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
149       assert(!TagAddr && "Non-zero TagAddr for Hangup?");
150       break;
151     case SimpleRemoteEPCOpcode::Result:
152       dbgs() << "Result";
153       assert(!TagAddr && "Non-zero TagAddr for Result?");
154       break;
155     case SimpleRemoteEPCOpcode::CallWrapper:
156       dbgs() << "CallWrapper";
157       break;
158     }
159     dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr
160            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
161            << " bytes\n";
162   });
163 
164   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
165   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
166     return make_error<StringError>("Unexpected opcode",
167                                    inconvertibleErrorCode());
168 
169   switch (OpC) {
170   case SimpleRemoteEPCOpcode::Setup:
171     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
172       return std::move(Err);
173     break;
174   case SimpleRemoteEPCOpcode::Hangup:
175     T->disconnect();
176     if (auto Err = handleHangup(std::move(ArgBytes)))
177       return std::move(Err);
178     return EndSession;
179   case SimpleRemoteEPCOpcode::Result:
180     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
181       return std::move(Err);
182     break;
183   case SimpleRemoteEPCOpcode::CallWrapper:
184     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
185     break;
186   }
187   return ContinueSession;
188 }
189 
190 void SimpleRemoteEPC::handleDisconnect(Error Err) {
191   LLVM_DEBUG({
192     dbgs() << "SimpleRemoteEPC::handleDisconnect: "
193            << (Err ? "failure" : "success") << "\n";
194   });
195 
196   PendingCallWrapperResultsMap TmpPending;
197 
198   {
199     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
200     std::swap(TmpPending, PendingCallWrapperResults);
201   }
202 
203   for (auto &KV : TmpPending)
204     KV.second(
205         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
206 
207   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
208   DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
209   Disconnected = true;
210   DisconnectCV.notify_all();
211 }
212 
213 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
214 SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {
215   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
216   if (auto Err = SREPC.getBootstrapSymbols(
217           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
218            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
219            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
220            {SAs.Deallocate,
221             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
222     return std::move(Err);
223 
224   return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);
225 }
226 
227 Expected<std::unique_ptr<MemoryAccess>>
228 SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {
229   EPCGenericMemoryAccess::FuncAddrs FAs;
230   if (auto Err = SREPC.getBootstrapSymbols(
231           {{FAs.WriteUInt8s, rt::MemoryWriteUInt8sWrapperName},
232            {FAs.WriteUInt16s, rt::MemoryWriteUInt16sWrapperName},
233            {FAs.WriteUInt32s, rt::MemoryWriteUInt32sWrapperName},
234            {FAs.WriteUInt64s, rt::MemoryWriteUInt64sWrapperName},
235            {FAs.WriteBuffers, rt::MemoryWriteBuffersWrapperName},
236            {FAs.WritePointers, rt::MemoryWritePointersWrapperName},
237            {FAs.ReadUInt8s, rt::MemoryReadUInt8sWrapperName},
238            {FAs.ReadUInt16s, rt::MemoryReadUInt16sWrapperName},
239            {FAs.ReadUInt32s, rt::MemoryReadUInt32sWrapperName},
240            {FAs.ReadUInt64s, rt::MemoryReadUInt64sWrapperName},
241            {FAs.ReadBuffers, rt::MemoryReadBuffersWrapperName},
242            {FAs.ReadStrings, rt::MemoryReadStringsWrapperName}}))
243     return std::move(Err);
244 
245   return std::make_unique<EPCGenericMemoryAccess>(SREPC, FAs);
246 }
247 
248 Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
249                                    ExecutorAddr TagAddr,
250                                    ArrayRef<char> ArgBytes) {
251   assert(OpC != SimpleRemoteEPCOpcode::Setup &&
252          "SimpleRemoteEPC sending Setup message? That's the wrong direction.");
253 
254   LLVM_DEBUG({
255     dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
256     switch (OpC) {
257     case SimpleRemoteEPCOpcode::Hangup:
258       dbgs() << "Hangup";
259       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
260       assert(!TagAddr && "Non-zero TagAddr for Hangup?");
261       break;
262     case SimpleRemoteEPCOpcode::Result:
263       dbgs() << "Result";
264       assert(!TagAddr && "Non-zero TagAddr for Result?");
265       break;
266     case SimpleRemoteEPCOpcode::CallWrapper:
267       dbgs() << "CallWrapper";
268       break;
269     default:
270       llvm_unreachable("Invalid opcode");
271     }
272     dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr
273            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
274            << " bytes\n";
275   });
276   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
277   LLVM_DEBUG({
278     if (Err)
279       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
280   });
281   return Err;
282 }
283 
284 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
285                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
286   if (SeqNo != 0)
287     return make_error<StringError>("Setup packet SeqNo not zero",
288                                    inconvertibleErrorCode());
289 
290   if (TagAddr)
291     return make_error<StringError>("Setup packet TagAddr not zero",
292                                    inconvertibleErrorCode());
293 
294   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
295   auto I = PendingCallWrapperResults.find(0);
296   assert(PendingCallWrapperResults.size() == 1 &&
297          I != PendingCallWrapperResults.end() &&
298          "Setup message handler not connectly set up");
299   auto SetupMsgHandler = std::move(I->second);
300   PendingCallWrapperResults.erase(I);
301 
302   auto WFR =
303       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
304   SetupMsgHandler(std::move(WFR));
305   return Error::success();
306 }
307 
308 Error SimpleRemoteEPC::setup(Setup S) {
309   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
310 
311   std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
312   auto EIF = EIP.get_future();
313 
314   // Prepare a handler for the setup packet.
315   PendingCallWrapperResults[0] =
316     RunInPlace()(
317       [&](shared::WrapperFunctionResult SetupMsgBytes) {
318         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
319           EIP.set_value(
320               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
321           return;
322         }
323         using SPSSerialize =
324             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
325         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
326         SimpleRemoteEPCExecutorInfo EI;
327         if (SPSSerialize::deserialize(IB, EI))
328           EIP.set_value(EI);
329         else
330           EIP.set_value(make_error<StringError>(
331               "Could not deserialize setup message", inconvertibleErrorCode()));
332       });
333 
334   // Start the transport.
335   if (auto Err = T->start())
336     return Err;
337 
338   // Wait for setup packet to arrive.
339   auto EI = EIF.get();
340   if (!EI) {
341     T->disconnect();
342     return EI.takeError();
343   }
344 
345   LLVM_DEBUG({
346     dbgs() << "SimpleRemoteEPC received setup message:\n"
347            << "  Triple: " << EI->TargetTriple << "\n"
348            << "  Page size: " << EI->PageSize << "\n"
349            << "  Bootstrap map" << (EI->BootstrapMap.empty() ? " empty" : ":")
350            << "\n";
351     for (const auto &KV : EI->BootstrapMap)
352       dbgs() << "    " << KV.first() << ": " << KV.second.size()
353              << "-byte SPS encoded buffer\n";
354     dbgs() << "  Bootstrap symbols"
355            << (EI->BootstrapSymbols.empty() ? " empty" : ":") << "\n";
356     for (const auto &KV : EI->BootstrapSymbols)
357       dbgs() << "    " << KV.first() << ": " << KV.second << "\n";
358   });
359   TargetTriple = Triple(EI->TargetTriple);
360   PageSize = EI->PageSize;
361   BootstrapMap = std::move(EI->BootstrapMap);
362   BootstrapSymbols = std::move(EI->BootstrapSymbols);
363 
364   if (auto Err = getBootstrapSymbols(
365           {{JDI.JITDispatchContext, ExecutorSessionObjectName},
366            {JDI.JITDispatchFunction, DispatchFnName},
367            {RunAsMainAddr, rt::RunAsMainWrapperName},
368            {RunAsVoidFunctionAddr, rt::RunAsVoidFunctionWrapperName},
369            {RunAsIntFunctionAddr, rt::RunAsIntFunctionWrapperName}}))
370     return Err;
371 
372   if (auto DM =
373           EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
374     EPCDylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
375   else
376     return DM.takeError();
377 
378   // Set a default CreateMemoryManager if none is specified.
379   if (!S.CreateMemoryManager)
380     S.CreateMemoryManager = createDefaultMemoryManager;
381 
382   if (auto MemMgr = S.CreateMemoryManager(*this)) {
383     OwnedMemMgr = std::move(*MemMgr);
384     this->MemMgr = OwnedMemMgr.get();
385   } else
386     return MemMgr.takeError();
387 
388   // Set a default CreateMemoryAccess if none is specified.
389   if (!S.CreateMemoryAccess)
390     S.CreateMemoryAccess = createDefaultMemoryAccess;
391 
392   if (auto MemAccess = S.CreateMemoryAccess(*this)) {
393     OwnedMemAccess = std::move(*MemAccess);
394     this->MemAccess = OwnedMemAccess.get();
395   } else
396     return MemAccess.takeError();
397 
398   return Error::success();
399 }
400 
401 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
402                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
403   IncomingWFRHandler SendResult;
404 
405   if (TagAddr)
406     return make_error<StringError>("Unexpected TagAddr in result message",
407                                    inconvertibleErrorCode());
408 
409   {
410     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
411     auto I = PendingCallWrapperResults.find(SeqNo);
412     if (I == PendingCallWrapperResults.end())
413       return make_error<StringError>("No call for sequence number " +
414                                          Twine(SeqNo),
415                                      inconvertibleErrorCode());
416     SendResult = std::move(I->second);
417     PendingCallWrapperResults.erase(I);
418     releaseSeqNo(SeqNo);
419   }
420 
421   auto WFR =
422       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
423   SendResult(std::move(WFR));
424   return Error::success();
425 }
426 
427 void SimpleRemoteEPC::handleCallWrapper(
428     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
429     SimpleRemoteEPCArgBytesVector ArgBytes) {
430   assert(ES && "No ExecutionSession attached");
431   D->dispatch(makeGenericNamedTask(
432       [this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
433         ES->runJITDispatchHandler(
434             [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
435               if (auto Err =
436                       sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
437                                   ExecutorAddr(), {WFR.data(), WFR.size()}))
438                 getExecutionSession().reportError(std::move(Err));
439             },
440             TagAddr, ArgBytes);
441       },
442       "callWrapper task"));
443 }
444 
445 Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
446   using namespace llvm::orc::shared;
447   auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
448   if (const char *ErrMsg = WFR.getOutOfBandError())
449     return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
450 
451   detail::SPSSerializableError Info;
452   SPSInputBuffer IB(WFR.data(), WFR.size());
453   if (!SPSArgList<SPSError>::deserialize(IB, Info))
454     return make_error<StringError>("Could not deserialize hangup info",
455                                    inconvertibleErrorCode());
456   return fromSPSSerializable(std::move(Info));
457 }
458 
459 } // end namespace orc
460 } // end namespace llvm
461