xref: /freebsd/contrib/llvm-project/llvm/include/llvm/ExecutionEngine/Orc/Shared/WrapperFunctionUtils.h (revision 0fca6ea1d4eea4c934cfff25ac9ee8ad6fe95583)
1 //===- WrapperFunctionUtils.h - Utilities for wrapper functions -*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // A buffer for serialized results.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
14 #define LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
15 
16 #include "llvm/ExecutionEngine/Orc/Shared/ExecutorAddress.h"
17 #include "llvm/ExecutionEngine/Orc/Shared/SimplePackedSerialization.h"
18 #include "llvm/Support/Error.h"
19 
20 #include <type_traits>
21 
22 namespace llvm {
23 namespace orc {
24 namespace shared {
25 
26 // Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
27 union CWrapperFunctionResultDataUnion {
28   char *ValuePtr;
29   char Value[sizeof(ValuePtr)];
30 };
31 
32 // Must be kept in-sync with compiler-rt/lib/orc/c-api.h.
33 typedef struct {
34   CWrapperFunctionResultDataUnion Data;
35   size_t Size;
36 } CWrapperFunctionResult;
37 
38 /// C++ wrapper function result: Same as CWrapperFunctionResult but
39 /// auto-releases memory.
40 class WrapperFunctionResult {
41 public:
42   /// Create a default WrapperFunctionResult.
WrapperFunctionResult()43   WrapperFunctionResult() { init(R); }
44 
45   /// Create a WrapperFunctionResult by taking ownership of a
46   /// CWrapperFunctionResult.
47   ///
48   /// Warning: This should only be used by clients writing wrapper-function
49   /// caller utilities (like TargetProcessControl).
WrapperFunctionResult(CWrapperFunctionResult R)50   WrapperFunctionResult(CWrapperFunctionResult R) : R(R) {
51     // Reset R.
52     init(R);
53   }
54 
55   WrapperFunctionResult(const WrapperFunctionResult &) = delete;
56   WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
57 
WrapperFunctionResult(WrapperFunctionResult && Other)58   WrapperFunctionResult(WrapperFunctionResult &&Other) {
59     init(R);
60     std::swap(R, Other.R);
61   }
62 
63   WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
64     WrapperFunctionResult Tmp(std::move(Other));
65     std::swap(R, Tmp.R);
66     return *this;
67   }
68 
~WrapperFunctionResult()69   ~WrapperFunctionResult() {
70     if ((R.Size > sizeof(R.Data.Value)) ||
71         (R.Size == 0 && R.Data.ValuePtr != nullptr))
72       free(R.Data.ValuePtr);
73   }
74 
75   /// Release ownership of the contained CWrapperFunctionResult.
76   /// Warning: Do not use -- this method will be removed in the future. It only
77   /// exists to temporarily support some code that will eventually be moved to
78   /// the ORC runtime.
release()79   CWrapperFunctionResult release() {
80     CWrapperFunctionResult Tmp;
81     init(Tmp);
82     std::swap(R, Tmp);
83     return Tmp;
84   }
85 
86   /// Get a pointer to the data contained in this instance.
data()87   char *data() {
88     assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
89            "Cannot get data for out-of-band error value");
90     return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
91   }
92 
93   /// Get a const pointer to the data contained in this instance.
data()94   const char *data() const {
95     assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
96            "Cannot get data for out-of-band error value");
97     return R.Size > sizeof(R.Data.Value) ? R.Data.ValuePtr : R.Data.Value;
98   }
99 
100   /// Returns the size of the data contained in this instance.
size()101   size_t size() const {
102     assert((R.Size != 0 || R.Data.ValuePtr == nullptr) &&
103            "Cannot get data for out-of-band error value");
104     return R.Size;
105   }
106 
107   /// Returns true if this value is equivalent to a default-constructed
108   /// WrapperFunctionResult.
empty()109   bool empty() const { return R.Size == 0 && R.Data.ValuePtr == nullptr; }
110 
111   /// Create a WrapperFunctionResult with the given size and return a pointer
112   /// to the underlying memory.
allocate(size_t Size)113   static WrapperFunctionResult allocate(size_t Size) {
114     // Reset.
115     WrapperFunctionResult WFR;
116     WFR.R.Size = Size;
117     if (WFR.R.Size > sizeof(WFR.R.Data.Value))
118       WFR.R.Data.ValuePtr = (char *)malloc(WFR.R.Size);
119     return WFR;
120   }
121 
122   /// Copy from the given char range.
copyFrom(const char * Source,size_t Size)123   static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
124     auto WFR = allocate(Size);
125     memcpy(WFR.data(), Source, Size);
126     return WFR;
127   }
128 
129   /// Copy from the given null-terminated string (includes the null-terminator).
copyFrom(const char * Source)130   static WrapperFunctionResult copyFrom(const char *Source) {
131     return copyFrom(Source, strlen(Source) + 1);
132   }
133 
134   /// Copy from the given std::string (includes the null terminator).
copyFrom(const std::string & Source)135   static WrapperFunctionResult copyFrom(const std::string &Source) {
136     return copyFrom(Source.c_str());
137   }
138 
139   /// Create an out-of-band error by copying the given string.
createOutOfBandError(const char * Msg)140   static WrapperFunctionResult createOutOfBandError(const char *Msg) {
141     // Reset.
142     WrapperFunctionResult WFR;
143     char *Tmp = (char *)malloc(strlen(Msg) + 1);
144     strcpy(Tmp, Msg);
145     WFR.R.Data.ValuePtr = Tmp;
146     return WFR;
147   }
148 
149   /// Create an out-of-band error by copying the given string.
createOutOfBandError(const std::string & Msg)150   static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
151     return createOutOfBandError(Msg.c_str());
152   }
153 
154   /// If this value is an out-of-band error then this returns the error message,
155   /// otherwise returns nullptr.
getOutOfBandError()156   const char *getOutOfBandError() const {
157     return R.Size == 0 ? R.Data.ValuePtr : nullptr;
158   }
159 
160 private:
init(CWrapperFunctionResult & R)161   static void init(CWrapperFunctionResult &R) {
162     R.Data.ValuePtr = nullptr;
163     R.Size = 0;
164   }
165 
166   CWrapperFunctionResult R;
167 };
168 
169 namespace detail {
170 
171 template <typename SPSArgListT, typename... ArgTs>
172 WrapperFunctionResult
serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args)173 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
174   auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...));
175   SPSOutputBuffer OB(Result.data(), Result.size());
176   if (!SPSArgListT::serialize(OB, Args...))
177     return WrapperFunctionResult::createOutOfBandError(
178         "Error serializing arguments to blob in call");
179   return Result;
180 }
181 
/// Invokes a handler with the tuple elements selected by an index sequence
/// and returns exactly what the handler returns (reference results are
/// preserved via decltype(auto)).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename FnT, typename TupleT, std::size_t... Idx>
  static decltype(auto) call(FnT &&Fn, TupleT &Tup,
                             std::index_sequence<Idx...>) {
    // Every selected element is passed to the handler as an lvalue.
    return std::forward<FnT>(Fn)(std::get<Idx>(Tup)...);
  }
};
190 
191 template <> class WrapperFunctionHandlerCaller<void> {
192 public:
193   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
call(HandlerT && H,ArgTupleT & Args,std::index_sequence<I...>)194   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
195                        std::index_sequence<I...>) {
196     std::forward<HandlerT>(H)(std::get<I>(Args)...);
197     return SPSEmpty();
198   }
199 };
200 
/// Primary template, used for callable objects such as lambdas and functors.
/// It re-dispatches on the type of the object's operator(), which is then
/// matched by one of the partial specializations for function-like types.
template <typename CallableT, template <typename> class SerializerT,
          typename... TagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          SerializerT, TagTs...> {};
207 
208 template <typename RetT, typename... ArgTs,
209           template <typename> class ResultSerializer, typename... SPSTagTs>
210 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
211                                    SPSTagTs...> {
212 public:
213   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
214   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
215 
216   template <typename HandlerT>
apply(HandlerT && H,const char * ArgData,size_t ArgSize)217   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
218                                      size_t ArgSize) {
219     ArgTuple Args;
220     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
221       return WrapperFunctionResult::createOutOfBandError(
222           "Could not deserialize arguments for wrapper function call");
223 
224     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
225         std::forward<HandlerT>(H), Args, ArgIndices{});
226 
227     return ResultSerializer<decltype(HandlerResult)>::serialize(
228         std::move(HandlerResult));
229   }
230 
231 private:
232   template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)233   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
234                           std::index_sequence<I...>) {
235     SPSInputBuffer IB(ArgData, ArgSize);
236     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
237   }
238 };
239 
240 // Map function pointers to function types.
241 template <typename RetT, typename... ArgTs,
242           template <typename> class ResultSerializer, typename... SPSTagTs>
243 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
244                                    SPSTagTs...>
245     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
246                                           SPSTagTs...> {};
247 
248 // Map non-const member function types to function types.
249 template <typename ClassT, typename RetT, typename... ArgTs,
250           template <typename> class ResultSerializer, typename... SPSTagTs>
251 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
252                                    SPSTagTs...>
253     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
254                                           SPSTagTs...> {};
255 
256 // Map const member function types to function types.
257 template <typename ClassT, typename RetT, typename... ArgTs,
258           template <typename> class ResultSerializer, typename... SPSTagTs>
259 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
260                                    ResultSerializer, SPSTagTs...>
261     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
262                                           SPSTagTs...> {};
263 
/// Primary template for asynchronous handlers given as callable objects
/// (lambdas, functors). It re-dispatches on the type of the object's
/// operator(), which a function-like partial specialization then handles.
template <typename CallableT, template <typename> class SerializerT,
          typename... TagTs>
class WrapperFunctionAsyncHandlerHelper
    : public WrapperFunctionAsyncHandlerHelper<
          decltype(&std::remove_reference_t<CallableT>::operator()),
          SerializerT, TagTs...> {};
270 
271 template <typename RetT, typename SendResultT, typename... ArgTs,
272           template <typename> class ResultSerializer, typename... SPSTagTs>
273 class WrapperFunctionAsyncHandlerHelper<RetT(SendResultT, ArgTs...),
274                                         ResultSerializer, SPSTagTs...> {
275 public:
276   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
277   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
278 
279   template <typename HandlerT, typename SendWrapperFunctionResultT>
applyAsync(HandlerT && H,SendWrapperFunctionResultT && SendWrapperFunctionResult,const char * ArgData,size_t ArgSize)280   static void applyAsync(HandlerT &&H,
281                          SendWrapperFunctionResultT &&SendWrapperFunctionResult,
282                          const char *ArgData, size_t ArgSize) {
283     ArgTuple Args;
284     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) {
285       SendWrapperFunctionResult(WrapperFunctionResult::createOutOfBandError(
286           "Could not deserialize arguments for wrapper function call"));
287       return;
288     }
289 
290     auto SendResult =
291         [SendWFR = std::move(SendWrapperFunctionResult)](auto Result) mutable {
292           using ResultT = decltype(Result);
293           SendWFR(ResultSerializer<ResultT>::serialize(std::move(Result)));
294         };
295 
296     callAsync(std::forward<HandlerT>(H), std::move(SendResult), std::move(Args),
297               ArgIndices{});
298   }
299 
300 private:
301   template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)302   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
303                           std::index_sequence<I...>) {
304     SPSInputBuffer IB(ArgData, ArgSize);
305     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
306   }
307 
308   template <typename HandlerT, typename SerializeAndSendResultT,
309             typename ArgTupleT, std::size_t... I>
callAsync(HandlerT && H,SerializeAndSendResultT && SerializeAndSendResult,ArgTupleT Args,std::index_sequence<I...>)310   static void callAsync(HandlerT &&H,
311                         SerializeAndSendResultT &&SerializeAndSendResult,
312                         ArgTupleT Args, std::index_sequence<I...>) {
313     (void)Args; // Silence a buggy GCC warning.
314     return std::forward<HandlerT>(H)(std::move(SerializeAndSendResult),
315                                      std::move(std::get<I>(Args))...);
316   }
317 };
318 
319 // Map function pointers to function types.
320 template <typename RetT, typename... ArgTs,
321           template <typename> class ResultSerializer, typename... SPSTagTs>
322 class WrapperFunctionAsyncHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
323                                         SPSTagTs...>
324     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
325                                                SPSTagTs...> {};
326 
327 // Map non-const member function types to function types.
328 template <typename ClassT, typename RetT, typename... ArgTs,
329           template <typename> class ResultSerializer, typename... SPSTagTs>
330 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...),
331                                         ResultSerializer, SPSTagTs...>
332     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
333                                                SPSTagTs...> {};
334 
335 // Map const member function types to function types.
336 template <typename ClassT, typename RetT, typename... ArgTs,
337           template <typename> class ResultSerializer, typename... SPSTagTs>
338 class WrapperFunctionAsyncHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
339                                         ResultSerializer, SPSTagTs...>
340     : public WrapperFunctionAsyncHandlerHelper<RetT(ArgTs...), ResultSerializer,
341                                                SPSTagTs...> {};
342 
343 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
344 public:
serialize(RetT Result)345   static WrapperFunctionResult serialize(RetT Result) {
346     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
347         Result);
348   }
349 };
350 
351 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
352 public:
serialize(Error Err)353   static WrapperFunctionResult serialize(Error Err) {
354     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
355         toSPSSerializable(std::move(Err)));
356   }
357 };
358 
359 template <typename SPSRetTagT>
360 class ResultSerializer<SPSRetTagT, ErrorSuccess> {
361 public:
serialize(ErrorSuccess Err)362   static WrapperFunctionResult serialize(ErrorSuccess Err) {
363     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
364         toSPSSerializable(std::move(Err)));
365   }
366 };
367 
368 template <typename SPSRetTagT, typename T>
369 class ResultSerializer<SPSRetTagT, Expected<T>> {
370 public:
serialize(Expected<T> E)371   static WrapperFunctionResult serialize(Expected<T> E) {
372     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
373         toSPSSerializable(std::move(E)));
374   }
375 };
376 
377 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
378 public:
makeValue()379   static RetT makeValue() { return RetT(); }
makeSafe(RetT & Result)380   static void makeSafe(RetT &Result) {}
381 
deserialize(RetT & Result,const char * ArgData,size_t ArgSize)382   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
383     SPSInputBuffer IB(ArgData, ArgSize);
384     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
385       return make_error<StringError>(
386           "Error deserializing return value from blob in call",
387           inconvertibleErrorCode());
388     return Error::success();
389   }
390 };
391 
392 template <> class ResultDeserializer<SPSError, Error> {
393 public:
makeValue()394   static Error makeValue() { return Error::success(); }
makeSafe(Error & Err)395   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
396 
deserialize(Error & Err,const char * ArgData,size_t ArgSize)397   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
398     SPSInputBuffer IB(ArgData, ArgSize);
399     SPSSerializableError BSE;
400     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
401       return make_error<StringError>(
402           "Error deserializing return value from blob in call",
403           inconvertibleErrorCode());
404     Err = fromSPSSerializable(std::move(BSE));
405     return Error::success();
406   }
407 };
408 
409 template <typename SPSTagT, typename T>
410 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
411 public:
makeValue()412   static Expected<T> makeValue() { return T(); }
makeSafe(Expected<T> & E)413   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
414 
deserialize(Expected<T> & E,const char * ArgData,size_t ArgSize)415   static Error deserialize(Expected<T> &E, const char *ArgData,
416                            size_t ArgSize) {
417     SPSInputBuffer IB(ArgData, ArgSize);
418     SPSSerializableExpected<T> BSE;
419     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
420       return make_error<StringError>(
421           "Error deserializing return value from blob in call",
422           inconvertibleErrorCode());
423     E = fromSPSSerializable(std::move(BSE));
424     return Error::success();
425   }
426 };
427 
/// Empty primary template with no members; nothing in this file specializes
/// or uses it.
template <typename SPSRetTagT, typename RetT>
class AsyncCallResultHelper {
  // Did you forget to use Error / Expected in your handler?
};
431 
432 } // end namespace detail
433 
434 template <typename SPSSignature> class WrapperFunction;
435 
436 template <typename SPSRetTagT, typename... SPSTagTs>
437 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
438 private:
439   template <typename RetT>
440   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
441 
442 public:
443   /// Call a wrapper function. Caller should be callable as
444   /// WrapperFunctionResult Fn(const char *ArgData, size_t ArgSize);
445   template <typename CallerFn, typename RetT, typename... ArgTs>
call(const CallerFn & Caller,RetT & Result,const ArgTs &...Args)446   static Error call(const CallerFn &Caller, RetT &Result,
447                     const ArgTs &...Args) {
448 
449     // RetT might be an Error or Expected value. Set the checked flag now:
450     // we don't want the user to have to check the unused result if this
451     // operation fails.
452     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
453 
454     auto ArgBuffer =
455         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
456             Args...);
457     if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
458       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
459 
460     WrapperFunctionResult ResultBuffer =
461         Caller(ArgBuffer.data(), ArgBuffer.size());
462     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
463       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
464 
465     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
466         Result, ResultBuffer.data(), ResultBuffer.size());
467   }
468 
469   /// Call an async wrapper function.
470   /// Caller should be callable as
471   /// void Fn(unique_function<void(WrapperFunctionResult)> SendResult,
472   ///         WrapperFunctionResult ArgBuffer);
473   template <typename AsyncCallerFn, typename SendDeserializedResultFn,
474             typename... ArgTs>
callAsync(AsyncCallerFn && Caller,SendDeserializedResultFn && SendDeserializedResult,const ArgTs &...Args)475   static void callAsync(AsyncCallerFn &&Caller,
476                         SendDeserializedResultFn &&SendDeserializedResult,
477                         const ArgTs &...Args) {
478     using RetT = typename std::tuple_element<
479         1, typename detail::WrapperFunctionHandlerHelper<
480                std::remove_reference_t<SendDeserializedResultFn>,
481                ResultSerializer, SPSRetTagT>::ArgTuple>::type;
482 
483     auto ArgBuffer =
484         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
485             Args...);
486     if (auto *ErrMsg = ArgBuffer.getOutOfBandError()) {
487       SendDeserializedResult(
488           make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
489           detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue());
490       return;
491     }
492 
493     auto SendSerializedResult = [SDR = std::move(SendDeserializedResult)](
494                                     WrapperFunctionResult R) mutable {
495       RetT RetVal = detail::ResultDeserializer<SPSRetTagT, RetT>::makeValue();
496       detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(RetVal);
497 
498       if (auto *ErrMsg = R.getOutOfBandError()) {
499         SDR(make_error<StringError>(ErrMsg, inconvertibleErrorCode()),
500             std::move(RetVal));
501         return;
502       }
503 
504       SPSInputBuffer IB(R.data(), R.size());
505       if (auto Err = detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
506               RetVal, R.data(), R.size())) {
507         SDR(std::move(Err), std::move(RetVal));
508         return;
509       }
510 
511       SDR(Error::success(), std::move(RetVal));
512     };
513 
514     Caller(std::move(SendSerializedResult), ArgBuffer.data(), ArgBuffer.size());
515   }
516 
517   /// Handle a call to a wrapper function.
518   template <typename HandlerT>
handle(const char * ArgData,size_t ArgSize,HandlerT && Handler)519   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
520                                       HandlerT &&Handler) {
521     using WFHH =
522         detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
523                                              ResultSerializer, SPSTagTs...>;
524     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
525   }
526 
527   /// Handle a call to an async wrapper function.
528   template <typename HandlerT, typename SendResultT>
handleAsync(const char * ArgData,size_t ArgSize,HandlerT && Handler,SendResultT && SendResult)529   static void handleAsync(const char *ArgData, size_t ArgSize,
530                           HandlerT &&Handler, SendResultT &&SendResult) {
531     using WFAHH = detail::WrapperFunctionAsyncHandlerHelper<
532         std::remove_reference_t<HandlerT>, ResultSerializer, SPSTagTs...>;
533     WFAHH::applyAsync(std::forward<HandlerT>(Handler),
534                       std::forward<SendResultT>(SendResult), ArgData, ArgSize);
535   }
536 
537 private:
makeSerializable(const T & Value)538   template <typename T> static const T &makeSerializable(const T &Value) {
539     return Value;
540   }
541 
makeSerializable(Error Err)542   static detail::SPSSerializableError makeSerializable(Error Err) {
543     return detail::toSPSSerializable(std::move(Err));
544   }
545 
546   template <typename T>
makeSerializable(Expected<T> E)547   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
548     return detail::toSPSSerializable(std::move(E));
549   }
550 };
551 
552 template <typename... SPSTagTs>
553 class WrapperFunction<void(SPSTagTs...)>
554     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
555 
556 public:
557   template <typename CallerFn, typename... ArgTs>
call(const CallerFn & Caller,const ArgTs &...Args)558   static Error call(const CallerFn &Caller, const ArgTs &...Args) {
559     SPSEmpty BE;
560     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(Caller, BE, Args...);
561   }
562 
563   template <typename AsyncCallerFn, typename SendDeserializedResultFn,
564             typename... ArgTs>
callAsync(AsyncCallerFn && Caller,SendDeserializedResultFn && SendDeserializedResult,const ArgTs &...Args)565   static void callAsync(AsyncCallerFn &&Caller,
566                         SendDeserializedResultFn &&SendDeserializedResult,
567                         const ArgTs &...Args) {
568     WrapperFunction<SPSEmpty(SPSTagTs...)>::callAsync(
569         std::forward<AsyncCallerFn>(Caller),
570         [SDR = std::move(SendDeserializedResult)](Error SerializeErr,
571                                                   SPSEmpty E) mutable {
572           SDR(std::move(SerializeErr));
573         },
574         Args...);
575   }
576 
577   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
578   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handleAsync;
579 };
580 
581 /// A function object that takes an ExecutorAddr as its first argument,
582 /// casts that address to a ClassT*, then calls the given method on that
583 /// pointer passing in the remaining function arguments. This utility
584 /// removes some of the boilerplate from writing wrappers for method calls.
585 ///
586 ///   @code{.cpp}
587 ///   class MyClass {
588 ///   public:
589 ///     void myMethod(uint32_t, bool) { ... }
590 ///   };
591 ///
592 ///   // SPS Method signature -- note MyClass object address as first argument.
593 ///   using SPSMyMethodWrapperSignature =
594 ///     SPSTuple<SPSExecutorAddr, uint32_t, bool>;
595 ///
596 ///   WrapperFunctionResult
597 ///   myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
598 ///     return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
599 ///        ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
600 ///   }
601 ///   @endcode
602 ///
603 template <typename RetT, typename ClassT, typename... ArgTs>
604 class MethodWrapperHandler {
605 public:
606   using MethodT = RetT (ClassT::*)(ArgTs...);
MethodWrapperHandler(MethodT M)607   MethodWrapperHandler(MethodT M) : M(M) {}
operator()608   RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
609     return (ObjAddr.toPtr<ClassT*>()->*M)(std::forward<ArgTs>(Args)...);
610   }
611 
612 private:
613   MethodT M;
614 };
615 
616 /// Create a MethodWrapperHandler object from the given method pointer.
617 template <typename RetT, typename ClassT, typename... ArgTs>
618 MethodWrapperHandler<RetT, ClassT, ArgTs...>
makeMethodWrapperHandler(RetT (ClassT::* Method)(ArgTs...))619 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
620   return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
621 }
622 
623 /// Represents a serialized wrapper function call.
624 /// Serializing calls themselves allows us to batch them: We can make one
625 /// "run-wrapper-functions" utility and send it a list of calls to run.
626 ///
627 /// The motivating use-case for this API is JITLink allocation actions, where
628 /// we want to run multiple functions to finalize linked memory without having
629 /// to make separate IPC calls for each one.
630 class WrapperFunctionCall {
631 public:
632   using ArgDataBufferType = SmallVector<char, 24>;
633 
634   /// Create a WrapperFunctionCall using the given SPS serializer to serialize
635   /// the arguments.
636   template <typename SPSSerializer, typename... ArgTs>
Create(ExecutorAddr FnAddr,const ArgTs &...Args)637   static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
638                                               const ArgTs &...Args) {
639     ArgDataBufferType ArgData;
640     ArgData.resize(SPSSerializer::size(Args...));
641     SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
642                        ArgData.size());
643     if (SPSSerializer::serialize(OB, Args...))
644       return WrapperFunctionCall(FnAddr, std::move(ArgData));
645     return make_error<StringError>("Cannot serialize arguments for "
646                                    "AllocActionCall",
647                                    inconvertibleErrorCode());
648   }
649 
650   WrapperFunctionCall() = default;
651 
652   /// Create a WrapperFunctionCall from a target function and arg buffer.
WrapperFunctionCall(ExecutorAddr FnAddr,ArgDataBufferType ArgData)653   WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
654       : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
655 
656   /// Returns the address to be called.
getCallee()657   const ExecutorAddr &getCallee() const { return FnAddr; }
658 
659   /// Returns the argument data.
getArgData()660   const ArgDataBufferType &getArgData() const { return ArgData; }
661 
662   /// WrapperFunctionCalls convert to true if the callee is non-null.
663   explicit operator bool() const { return !!FnAddr; }
664 
665   /// Run call returning raw WrapperFunctionResult.
run()666   shared::WrapperFunctionResult run() const {
667     using FnTy =
668         shared::CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
669     return shared::WrapperFunctionResult(
670         FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
671   }
672 
673   /// Run call and deserialize result using SPS.
674   template <typename SPSRetT, typename RetT>
675   std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet(RetT & RetVal)676   runWithSPSRet(RetT &RetVal) const {
677     auto WFR = run();
678     if (const char *ErrMsg = WFR.getOutOfBandError())
679       return make_error<StringError>(ErrMsg, inconvertibleErrorCode());
680     shared::SPSInputBuffer IB(WFR.data(), WFR.size());
681     if (!shared::SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
682       return make_error<StringError>("Could not deserialize result from "
683                                      "serialized wrapper function call",
684                                      inconvertibleErrorCode());
685     return Error::success();
686   }
687 
688   /// Overload for SPS functions returning void.
689   template <typename SPSRetT>
690   std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet()691   runWithSPSRet() const {
692     shared::SPSEmpty E;
693     return runWithSPSRet<shared::SPSEmpty>(E);
694   }
695 
696   /// Run call and deserialize an SPSError result. SPSError returns and
697   /// deserialization failures are merged into the returned error.
runWithSPSRetErrorMerged()698   Error runWithSPSRetErrorMerged() const {
699     detail::SPSSerializableError RetErr;
700     if (auto Err = runWithSPSRet<SPSError>(RetErr))
701       return Err;
702     return detail::fromSPSSerializable(std::move(RetErr));
703   }
704 
705 private:
706   orc::ExecutorAddr FnAddr;
707   ArgDataBufferType ArgData;
708 };
709 
710 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
711 
712 template <>
713 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
714 public:
size(const WrapperFunctionCall & WFC)715   static size_t size(const WrapperFunctionCall &WFC) {
716     return SPSWrapperFunctionCall::AsArgList::size(WFC.getCallee(),
717                                                    WFC.getArgData());
718   }
719 
serialize(SPSOutputBuffer & OB,const WrapperFunctionCall & WFC)720   static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
721     return SPSWrapperFunctionCall::AsArgList::serialize(OB, WFC.getCallee(),
722                                                         WFC.getArgData());
723   }
724 
deserialize(SPSInputBuffer & IB,WrapperFunctionCall & WFC)725   static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
726     ExecutorAddr FnAddr;
727     WrapperFunctionCall::ArgDataBufferType ArgData;
728     if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
729       return false;
730     WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
731     return true;
732   }
733 };
734 
735 } // end namespace shared
736 } // end namespace orc
737 } // end namespace llvm
738 
739 #endif // LLVM_EXECUTIONENGINE_ORC_SHARED_WRAPPERFUNCTIONUTILS_H
740