xref: /freebsd/contrib/llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp (revision 924226fba12cc9a228c73b956e1b7fa24c60b055)
1 //===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"
10 #include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"
11 #include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"
12 #include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 #define DEBUG_TYPE "orc"
16 
17 namespace llvm {
18 namespace orc {
19 
20 SimpleRemoteEPC::~SimpleRemoteEPC() {
21 #ifndef NDEBUG
22   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
23   assert(Disconnected && "Destroyed without disconnection");
24 #endif // NDEBUG
25 }
26 
27 Expected<tpctypes::DylibHandle>
28 SimpleRemoteEPC::loadDylib(const char *DylibPath) {
29   return DylibMgr->open(DylibPath, 0);
30 }
31 
32 Expected<std::vector<tpctypes::LookupResult>>
33 SimpleRemoteEPC::lookupSymbols(ArrayRef<LookupRequest> Request) {
34   std::vector<tpctypes::LookupResult> Result;
35 
36   for (auto &Element : Request) {
37     if (auto R = DylibMgr->lookup(Element.Handle, Element.Symbols)) {
38       Result.push_back({});
39       Result.back().reserve(R->size());
40       for (auto Addr : *R)
41         Result.back().push_back(Addr.getValue());
42     } else
43       return R.takeError();
44   }
45   return std::move(Result);
46 }
47 
48 Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,
49                                              ArrayRef<std::string> Args) {
50   int64_t Result = 0;
51   if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(
52           RunAsMainAddr, Result, ExecutorAddr(MainFnAddr), Args))
53     return std::move(Err);
54   return Result;
55 }
56 
57 void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,
58                                        IncomingWFRHandler OnComplete,
59                                        ArrayRef<char> ArgBuffer) {
60   uint64_t SeqNo;
61   {
62     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
63     SeqNo = getNextSeqNo();
64     assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");
65     PendingCallWrapperResults[SeqNo] = std::move(OnComplete);
66   }
67 
68   if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,
69                              WrapperFnAddr, ArgBuffer)) {
70     IncomingWFRHandler H;
71 
72     // We just registered OnComplete, but there may be a race between this
73     // thread returning from sendMessage and handleDisconnect being called from
74     // the transport's listener thread. If handleDisconnect gets there first
75     // then it will have failed 'H' for us. If we get there first (or if
76     // handleDisconnect already ran) then we need to take care of it.
77     {
78       std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
79       auto I = PendingCallWrapperResults.find(SeqNo);
80       if (I != PendingCallWrapperResults.end()) {
81         H = std::move(I->second);
82         PendingCallWrapperResults.erase(I);
83       }
84     }
85 
86     if (H)
87       H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
88 
89     getExecutionSession().reportError(std::move(Err));
90   }
91 }
92 
93 Error SimpleRemoteEPC::disconnect() {
94   T->disconnect();
95   D->shutdown();
96   std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);
97   DisconnectCV.wait(Lock, [this] { return Disconnected; });
98   return std::move(DisconnectErr);
99 }
100 
101 Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>
102 SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
103                                ExecutorAddr TagAddr,
104                                SimpleRemoteEPCArgBytesVector ArgBytes) {
105 
106   LLVM_DEBUG({
107     dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";
108     switch (OpC) {
109     case SimpleRemoteEPCOpcode::Setup:
110       dbgs() << "Setup";
111       assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");
112       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Setup?");
113       break;
114     case SimpleRemoteEPCOpcode::Hangup:
115       dbgs() << "Hangup";
116       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
117       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
118       break;
119     case SimpleRemoteEPCOpcode::Result:
120       dbgs() << "Result";
121       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
122       break;
123     case SimpleRemoteEPCOpcode::CallWrapper:
124       dbgs() << "CallWrapper";
125       break;
126     }
127     dbgs() << ", seqno = " << SeqNo
128            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
129            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
130            << " bytes\n";
131   });
132 
133   using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;
134   if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))
135     return make_error<StringError>("Unexpected opcode",
136                                    inconvertibleErrorCode());
137 
138   switch (OpC) {
139   case SimpleRemoteEPCOpcode::Setup:
140     if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))
141       return std::move(Err);
142     break;
143   case SimpleRemoteEPCOpcode::Hangup:
144     T->disconnect();
145     if (auto Err = handleHangup(std::move(ArgBytes)))
146       return std::move(Err);
147     return EndSession;
148   case SimpleRemoteEPCOpcode::Result:
149     if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))
150       return std::move(Err);
151     break;
152   case SimpleRemoteEPCOpcode::CallWrapper:
153     handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));
154     break;
155   }
156   return ContinueSession;
157 }
158 
159 void SimpleRemoteEPC::handleDisconnect(Error Err) {
160   LLVM_DEBUG({
161     dbgs() << "SimpleRemoteEPC::handleDisconnect: "
162            << (Err ? "failure" : "success") << "\n";
163   });
164 
165   PendingCallWrapperResultsMap TmpPending;
166 
167   {
168     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
169     std::swap(TmpPending, PendingCallWrapperResults);
170   }
171 
172   for (auto &KV : TmpPending)
173     KV.second(
174         shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));
175 
176   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
177   DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));
178   Disconnected = true;
179   DisconnectCV.notify_all();
180 }
181 
182 Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>
183 SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {
184   EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;
185   if (auto Err = SREPC.getBootstrapSymbols(
186           {{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},
187            {SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},
188            {SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},
189            {SAs.Deallocate,
190             rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))
191     return std::move(Err);
192 
193   return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);
194 }
195 
196 Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>
197 SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {
198   return nullptr;
199 }
200 
201 Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,
202                                    ExecutorAddr TagAddr,
203                                    ArrayRef<char> ArgBytes) {
204   assert(OpC != SimpleRemoteEPCOpcode::Setup &&
205          "SimpleRemoteEPC sending Setup message? That's the wrong direction.");
206 
207   LLVM_DEBUG({
208     dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";
209     switch (OpC) {
210     case SimpleRemoteEPCOpcode::Hangup:
211       dbgs() << "Hangup";
212       assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");
213       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Hangup?");
214       break;
215     case SimpleRemoteEPCOpcode::Result:
216       dbgs() << "Result";
217       assert(TagAddr.getValue() == 0 && "Non-zero TagAddr for Result?");
218       break;
219     case SimpleRemoteEPCOpcode::CallWrapper:
220       dbgs() << "CallWrapper";
221       break;
222     default:
223       llvm_unreachable("Invalid opcode");
224     }
225     dbgs() << ", seqno = " << SeqNo
226            << ", tag-addr = " << formatv("{0:x}", TagAddr.getValue())
227            << ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())
228            << " bytes\n";
229   });
230   auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);
231   LLVM_DEBUG({
232     if (Err)
233       dbgs() << "  \\--> SimpleRemoteEPC::sendMessage failed\n";
234   });
235   return Err;
236 }
237 
238 Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,
239                                    SimpleRemoteEPCArgBytesVector ArgBytes) {
240   if (SeqNo != 0)
241     return make_error<StringError>("Setup packet SeqNo not zero",
242                                    inconvertibleErrorCode());
243 
244   if (TagAddr)
245     return make_error<StringError>("Setup packet TagAddr not zero",
246                                    inconvertibleErrorCode());
247 
248   std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
249   auto I = PendingCallWrapperResults.find(0);
250   assert(PendingCallWrapperResults.size() == 1 &&
251          I != PendingCallWrapperResults.end() &&
252          "Setup message handler not connectly set up");
253   auto SetupMsgHandler = std::move(I->second);
254   PendingCallWrapperResults.erase(I);
255 
256   auto WFR =
257       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
258   SetupMsgHandler(std::move(WFR));
259   return Error::success();
260 }
261 
262 Error SimpleRemoteEPC::setup(Setup S) {
263   using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;
264 
265   std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;
266   auto EIF = EIP.get_future();
267 
268   // Prepare a handler for the setup packet.
269   PendingCallWrapperResults[0] =
270     RunInPlace()(
271       [&](shared::WrapperFunctionResult SetupMsgBytes) {
272         if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {
273           EIP.set_value(
274               make_error<StringError>(ErrMsg, inconvertibleErrorCode()));
275           return;
276         }
277         using SPSSerialize =
278             shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;
279         shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());
280         SimpleRemoteEPCExecutorInfo EI;
281         if (SPSSerialize::deserialize(IB, EI))
282           EIP.set_value(EI);
283         else
284           EIP.set_value(make_error<StringError>(
285               "Could not deserialize setup message", inconvertibleErrorCode()));
286       });
287 
288   // Start the transport.
289   if (auto Err = T->start())
290     return Err;
291 
292   // Wait for setup packet to arrive.
293   auto EI = EIF.get();
294   if (!EI) {
295     T->disconnect();
296     return EI.takeError();
297   }
298 
299   LLVM_DEBUG({
300     dbgs() << "SimpleRemoteEPC received setup message:\n"
301            << "  Triple: " << EI->TargetTriple << "\n"
302            << "  Page size: " << EI->PageSize << "\n"
303            << "  Bootstrap symbols:\n";
304     for (const auto &KV : EI->BootstrapSymbols)
305       dbgs() << "    " << KV.first() << ": "
306              << formatv("{0:x16}", KV.second.getValue()) << "\n";
307   });
308   TargetTriple = Triple(EI->TargetTriple);
309   PageSize = EI->PageSize;
310   BootstrapSymbols = std::move(EI->BootstrapSymbols);
311 
312   if (auto Err = getBootstrapSymbols(
313           {{JDI.JITDispatchContext, ExecutorSessionObjectName},
314            {JDI.JITDispatchFunction, DispatchFnName},
315            {RunAsMainAddr, rt::RunAsMainWrapperName}}))
316     return Err;
317 
318   if (auto DM =
319           EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))
320     DylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));
321   else
322     return DM.takeError();
323 
324   // Set a default CreateMemoryManager if none is specified.
325   if (!S.CreateMemoryManager)
326     S.CreateMemoryManager = createDefaultMemoryManager;
327 
328   if (auto MemMgr = S.CreateMemoryManager(*this)) {
329     OwnedMemMgr = std::move(*MemMgr);
330     this->MemMgr = OwnedMemMgr.get();
331   } else
332     return MemMgr.takeError();
333 
334   // Set a default CreateMemoryAccess if none is specified.
335   if (!S.CreateMemoryAccess)
336     S.CreateMemoryAccess = createDefaultMemoryAccess;
337 
338   if (auto MemAccess = S.CreateMemoryAccess(*this)) {
339     OwnedMemAccess = std::move(*MemAccess);
340     this->MemAccess = OwnedMemAccess.get();
341   } else
342     return MemAccess.takeError();
343 
344   return Error::success();
345 }
346 
347 Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,
348                                     SimpleRemoteEPCArgBytesVector ArgBytes) {
349   IncomingWFRHandler SendResult;
350 
351   if (TagAddr)
352     return make_error<StringError>("Unexpected TagAddr in result message",
353                                    inconvertibleErrorCode());
354 
355   {
356     std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);
357     auto I = PendingCallWrapperResults.find(SeqNo);
358     if (I == PendingCallWrapperResults.end())
359       return make_error<StringError>("No call for sequence number " +
360                                          Twine(SeqNo),
361                                      inconvertibleErrorCode());
362     SendResult = std::move(I->second);
363     PendingCallWrapperResults.erase(I);
364     releaseSeqNo(SeqNo);
365   }
366 
367   auto WFR =
368       shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
369   SendResult(std::move(WFR));
370   return Error::success();
371 }
372 
373 void SimpleRemoteEPC::handleCallWrapper(
374     uint64_t RemoteSeqNo, ExecutorAddr TagAddr,
375     SimpleRemoteEPCArgBytesVector ArgBytes) {
376   assert(ES && "No ExecutionSession attached");
377   D->dispatch(makeGenericNamedTask(
378       [this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {
379         ES->runJITDispatchHandler(
380             [this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {
381               if (auto Err =
382                       sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,
383                                   ExecutorAddr(), {WFR.data(), WFR.size()}))
384                 getExecutionSession().reportError(std::move(Err));
385             },
386             TagAddr.getValue(), ArgBytes);
387       },
388       "callWrapper task"));
389 }
390 
391 Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {
392   using namespace llvm::orc::shared;
393   auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());
394   if (const char *ErrMsg = WFR.getOutOfBandError())
395     return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
396 
397   detail::SPSSerializableError Info;
398   SPSInputBuffer IB(WFR.data(), WFR.size());
399   if (!SPSArgList<SPSError>::deserialize(IB, Info))
400     return make_error<StringError>("Could not deserialize hangup info",
401                                    inconvertibleErrorCode());
402   return fromSPSSerializable(std::move(Info));
403 }
404 
405 } // end namespace orc
406 } // end namespace llvm
407