1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16 #include "error.h"
17 #include "executor_address.h"
18 #include "orc_rt/c_api.h"
19 #include "simple_packed_serialization.h"
20 #include <type_traits>
21
22 namespace orc_rt {
23
24 /// C++ wrapper function result: Same as orc_rt_WrapperFunctionResult but
25 /// auto-releases memory.
26 class WrapperFunctionResult {
27 public:
28 /// Create a default WrapperFunctionResult.
WrapperFunctionResult()29 WrapperFunctionResult() { orc_rt_WrapperFunctionResultInit(&R); }
30
31 /// Create a WrapperFunctionResult from a WrapperFunctionResult. This
32 /// instance takes ownership of the result object and will automatically
33 /// call dispose on the result upon destruction.
WrapperFunctionResult(orc_rt_WrapperFunctionResult R)34 WrapperFunctionResult(orc_rt_WrapperFunctionResult R) : R(R) {}
35
36 WrapperFunctionResult(const WrapperFunctionResult &) = delete;
37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
38
WrapperFunctionResult(WrapperFunctionResult && Other)39 WrapperFunctionResult(WrapperFunctionResult &&Other) {
40 orc_rt_WrapperFunctionResultInit(&R);
41 std::swap(R, Other.R);
42 }
43
44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
45 orc_rt_WrapperFunctionResult Tmp;
46 orc_rt_WrapperFunctionResultInit(&Tmp);
47 std::swap(Tmp, Other.R);
48 std::swap(R, Tmp);
49 return *this;
50 }
51
~WrapperFunctionResult()52 ~WrapperFunctionResult() { orc_rt_DisposeWrapperFunctionResult(&R); }
53
54 /// Relinquish ownership of and return the
55 /// orc_rt_WrapperFunctionResult.
release()56 orc_rt_WrapperFunctionResult release() {
57 orc_rt_WrapperFunctionResult Tmp;
58 orc_rt_WrapperFunctionResultInit(&Tmp);
59 std::swap(R, Tmp);
60 return Tmp;
61 }
62
63 /// Get a pointer to the data contained in this instance.
data()64 char *data() { return orc_rt_WrapperFunctionResultData(&R); }
65
66 /// Returns the size of the data contained in this instance.
size()67 size_t size() const { return orc_rt_WrapperFunctionResultSize(&R); }
68
69 /// Returns true if this value is equivalent to a default-constructed
70 /// WrapperFunctionResult.
empty()71 bool empty() const { return orc_rt_WrapperFunctionResultEmpty(&R); }
72
73 /// Create a WrapperFunctionResult with the given size and return a pointer
74 /// to the underlying memory.
allocate(size_t Size)75 static WrapperFunctionResult allocate(size_t Size) {
76 WrapperFunctionResult R;
77 R.R = orc_rt_WrapperFunctionResultAllocate(Size);
78 return R;
79 }
80
81 /// Copy from the given char range.
copyFrom(const char * Source,size_t Size)82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
83 return orc_rt_CreateWrapperFunctionResultFromRange(Source, Size);
84 }
85
86 /// Copy from the given null-terminated string (includes the null-terminator).
copyFrom(const char * Source)87 static WrapperFunctionResult copyFrom(const char *Source) {
88 return orc_rt_CreateWrapperFunctionResultFromString(Source);
89 }
90
91 /// Copy from the given std::string (includes the null terminator).
copyFrom(const std::string & Source)92 static WrapperFunctionResult copyFrom(const std::string &Source) {
93 return copyFrom(Source.c_str());
94 }
95
96 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const char * Msg)97 static WrapperFunctionResult createOutOfBandError(const char *Msg) {
98 return orc_rt_CreateWrapperFunctionResultFromOutOfBandError(Msg);
99 }
100
101 /// Create an out-of-band error by copying the given string.
createOutOfBandError(const std::string & Msg)102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
103 return createOutOfBandError(Msg.c_str());
104 }
105
106 template <typename SPSArgListT, typename... ArgTs>
fromSPSArgs(const ArgTs &...Args)107 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
108 auto Result = allocate(SPSArgListT::size(Args...));
109 SPSOutputBuffer OB(Result.data(), Result.size());
110 if (!SPSArgListT::serialize(OB, Args...))
111 return createOutOfBandError(
112 "Error serializing arguments to blob in call");
113 return Result;
114 }
115
116 /// If this value is an out-of-band error then this returns the error message,
117 /// otherwise returns nullptr.
getOutOfBandError()118 const char *getOutOfBandError() const {
119 return orc_rt_WrapperFunctionResultGetOutOfBandError(&R);
120 }
121
122 private:
123 orc_rt_WrapperFunctionResult R;
124 };
125
126 namespace detail {
127
128 template <typename RetT> class WrapperFunctionHandlerCaller {
129 public:
130 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
decltype(auto)131 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,
132 std::index_sequence<I...>) {
133 return std::forward<HandlerT>(H)(std::get<I>(Args)...);
134 }
135 };
136
137 template <> class WrapperFunctionHandlerCaller<void> {
138 public:
139 template <typename HandlerT, typename ArgTupleT, std::size_t... I>
call(HandlerT && H,ArgTupleT & Args,std::index_sequence<I...>)140 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
141 std::index_sequence<I...>) {
142 std::forward<HandlerT>(H)(std::get<I>(Args)...);
143 return SPSEmpty();
144 }
145 };
146
147 template <typename WrapperFunctionImplT,
148 template <typename> class ResultSerializer, typename... SPSTagTs>
149 class WrapperFunctionHandlerHelper
150 : public WrapperFunctionHandlerHelper<
151 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),
152 ResultSerializer, SPSTagTs...> {};
153
154 template <typename RetT, typename... ArgTs,
155 template <typename> class ResultSerializer, typename... SPSTagTs>
156 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
157 SPSTagTs...> {
158 public:
159 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
160 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
161
162 template <typename HandlerT>
apply(HandlerT && H,const char * ArgData,size_t ArgSize)163 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
164 size_t ArgSize) {
165 ArgTuple Args;
166 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
167 return WrapperFunctionResult::createOutOfBandError(
168 "Could not deserialize arguments for wrapper function call");
169
170 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
171 std::forward<HandlerT>(H), Args, ArgIndices{});
172
173 return ResultSerializer<decltype(HandlerResult)>::serialize(
174 std::move(HandlerResult));
175 }
176
177 private:
178 template <std::size_t... I>
deserialize(const char * ArgData,size_t ArgSize,ArgTuple & Args,std::index_sequence<I...>)179 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
180 std::index_sequence<I...>) {
181 SPSInputBuffer IB(ArgData, ArgSize);
182 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
183 }
184 };
185
186 // Map function pointers to function types.
187 template <typename RetT, typename... ArgTs,
188 template <typename> class ResultSerializer, typename... SPSTagTs>
189 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
190 SPSTagTs...>
191 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
192 SPSTagTs...> {};
193
194 // Map non-const member function types to function types.
195 template <typename ClassT, typename RetT, typename... ArgTs,
196 template <typename> class ResultSerializer, typename... SPSTagTs>
197 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
198 SPSTagTs...>
199 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
200 SPSTagTs...> {};
201
202 // Map const member function types to function types.
203 template <typename ClassT, typename RetT, typename... ArgTs,
204 template <typename> class ResultSerializer, typename... SPSTagTs>
205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
206 ResultSerializer, SPSTagTs...>
207 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
208 SPSTagTs...> {};
209
210 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
211 public:
serialize(RetT Result)212 static WrapperFunctionResult serialize(RetT Result) {
213 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
214 }
215 };
216
217 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
218 public:
serialize(Error Err)219 static WrapperFunctionResult serialize(Error Err) {
220 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
221 toSPSSerializable(std::move(Err)));
222 }
223 };
224
225 template <typename SPSRetTagT, typename T>
226 class ResultSerializer<SPSRetTagT, Expected<T>> {
227 public:
serialize(Expected<T> E)228 static WrapperFunctionResult serialize(Expected<T> E) {
229 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
230 toSPSSerializable(std::move(E)));
231 }
232 };
233
234 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
235 public:
makeSafe(RetT & Result)236 static void makeSafe(RetT &Result) {}
237
deserialize(RetT & Result,const char * ArgData,size_t ArgSize)238 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
239 SPSInputBuffer IB(ArgData, ArgSize);
240 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
241 return make_error<StringError>(
242 "Error deserializing return value from blob in call");
243 return Error::success();
244 }
245 };
246
247 template <> class ResultDeserializer<SPSError, Error> {
248 public:
makeSafe(Error & Err)249 static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
250
deserialize(Error & Err,const char * ArgData,size_t ArgSize)251 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
252 SPSInputBuffer IB(ArgData, ArgSize);
253 SPSSerializableError BSE;
254 if (!SPSArgList<SPSError>::deserialize(IB, BSE))
255 return make_error<StringError>(
256 "Error deserializing return value from blob in call");
257 Err = fromSPSSerializable(std::move(BSE));
258 return Error::success();
259 }
260 };
261
262 template <typename SPSTagT, typename T>
263 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
264 public:
makeSafe(Expected<T> & E)265 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
266
deserialize(Expected<T> & E,const char * ArgData,size_t ArgSize)267 static Error deserialize(Expected<T> &E, const char *ArgData,
268 size_t ArgSize) {
269 SPSInputBuffer IB(ArgData, ArgSize);
270 SPSSerializableExpected<T> BSE;
271 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
272 return make_error<StringError>(
273 "Error deserializing return value from blob in call");
274 E = fromSPSSerializable(std::move(BSE));
275 return Error::success();
276 }
277 };
278
279 } // end namespace detail
280
281 template <typename SPSSignature> class WrapperFunction;
282
283 template <typename SPSRetTagT, typename... SPSTagTs>
284 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
285 private:
286 template <typename RetT>
287 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
288
289 public:
290 template <typename DispatchFn, typename RetT, typename... ArgTs>
call(DispatchFn && Dispatch,RetT & Result,const ArgTs &...Args)291 static Error call(DispatchFn &&Dispatch, RetT &Result, const ArgTs &...Args) {
292
293 // RetT might be an Error or Expected value. Set the checked flag now:
294 // we don't want the user to have to check the unused result if this
295 // operation fails.
296 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
297
298 auto ArgBuffer =
299 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
300 if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
301 return make_error<StringError>(ErrMsg);
302
303 WrapperFunctionResult ResultBuffer =
304 Dispatch(ArgBuffer.data(), ArgBuffer.size());
305
306 if (auto ErrMsg = ResultBuffer.getOutOfBandError())
307 return make_error<StringError>(ErrMsg);
308
309 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
310 Result, ResultBuffer.data(), ResultBuffer.size());
311 }
312
313 template <typename HandlerT>
handle(const char * ArgData,size_t ArgSize,HandlerT && Handler)314 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
315 HandlerT &&Handler) {
316 using WFHH =
317 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
318 ResultSerializer, SPSTagTs...>;
319 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
320 }
321
322 private:
makeSerializable(const T & Value)323 template <typename T> static const T &makeSerializable(const T &Value) {
324 return Value;
325 }
326
makeSerializable(Error Err)327 static detail::SPSSerializableError makeSerializable(Error Err) {
328 return detail::toSPSSerializable(std::move(Err));
329 }
330
331 template <typename T>
makeSerializable(Expected<T> E)332 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
333 return detail::toSPSSerializable(std::move(E));
334 }
335 };
336
337 template <typename... SPSTagTs>
338 class WrapperFunction<void(SPSTagTs...)>
339 : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
340 public:
341 template <typename DispatchFn, typename... ArgTs>
call(DispatchFn && Dispatch,const ArgTs &...Args)342 static Error call(DispatchFn &&Dispatch, const ArgTs &...Args) {
343 SPSEmpty BE;
344 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(
345 std::forward<DispatchFn>(Dispatch), BE, Args...);
346 }
347
348 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
349 };
350
351 /// A function object that takes an ExecutorAddr as its first argument,
352 /// casts that address to a ClassT*, then calls the given method on that
353 /// pointer passing in the remaining function arguments. This utility
354 /// removes some of the boilerplate from writing wrappers for method calls.
355 ///
356 /// @code{.cpp}
357 /// class MyClass {
358 /// public:
359 /// void myMethod(uint32_t, bool) { ... }
360 /// };
361 ///
362 /// // SPS Method signature -- note MyClass object address as first argument.
363 /// using SPSMyMethodWrapperSignature =
364 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
365 ///
366 /// WrapperFunctionResult
367 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
368 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
369 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
370 /// }
371 /// @endcode
372 ///
373 template <typename RetT, typename ClassT, typename... ArgTs>
374 class MethodWrapperHandler {
375 public:
376 using MethodT = RetT (ClassT::*)(ArgTs...);
MethodWrapperHandler(MethodT M)377 MethodWrapperHandler(MethodT M) : M(M) {}
operator()378 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
379 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
380 }
381
382 private:
383 MethodT M;
384 };
385
386 /// Create a MethodWrapperHandler object from the given method pointer.
387 template <typename RetT, typename ClassT, typename... ArgTs>
388 MethodWrapperHandler<RetT, ClassT, ArgTs...>
makeMethodWrapperHandler(RetT (ClassT::* Method)(ArgTs...))389 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
390 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
391 }
392
393 /// Represents a call to a wrapper function.
394 class WrapperFunctionCall {
395 public:
396 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
397 // smallvector.
398 using ArgDataBufferType = std::vector<char>;
399
400 /// Create a WrapperFunctionCall using the given SPS serializer to serialize
401 /// the arguments.
402 template <typename SPSSerializer, typename... ArgTs>
Create(ExecutorAddr FnAddr,const ArgTs &...Args)403 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
404 const ArgTs &...Args) {
405 ArgDataBufferType ArgData;
406 ArgData.resize(SPSSerializer::size(Args...));
407 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
408 ArgData.size());
409 if (SPSSerializer::serialize(OB, Args...))
410 return WrapperFunctionCall(FnAddr, std::move(ArgData));
411 return make_error<StringError>("Cannot serialize arguments for "
412 "AllocActionCall");
413 }
414
415 WrapperFunctionCall() = default;
416
417 /// Create a WrapperFunctionCall from a target function and arg buffer.
WrapperFunctionCall(ExecutorAddr FnAddr,ArgDataBufferType ArgData)418 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
419 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
420
421 /// Returns the address to be called.
getCallee()422 const ExecutorAddr &getCallee() const { return FnAddr; }
423
424 /// Returns the argument data.
getArgData()425 const ArgDataBufferType &getArgData() const { return ArgData; }
426
427 /// WrapperFunctionCalls convert to true if the callee is non-null.
428 explicit operator bool() const { return !!FnAddr; }
429
430 /// Run call returning raw WrapperFunctionResult.
run()431 WrapperFunctionResult run() const {
432 using FnTy =
433 orc_rt_WrapperFunctionResult(const char *ArgData, size_t ArgSize);
434 return WrapperFunctionResult(
435 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
436 }
437
438 /// Run call and deserialize result using SPS.
439 template <typename SPSRetT, typename RetT>
440 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet(RetT & RetVal)441 runWithSPSRet(RetT &RetVal) const {
442 auto WFR = run();
443 if (const char *ErrMsg = WFR.getOutOfBandError())
444 return make_error<StringError>(ErrMsg);
445 SPSInputBuffer IB(WFR.data(), WFR.size());
446 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
447 return make_error<StringError>("Could not deserialize result from "
448 "serialized wrapper function call");
449 return Error::success();
450 }
451
452 /// Overload for SPS functions returning void.
453 template <typename SPSRetT>
454 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
runWithSPSRet()455 runWithSPSRet() const {
456 SPSEmpty E;
457 return runWithSPSRet<SPSEmpty>(E);
458 }
459
460 /// Run call and deserialize an SPSError result. SPSError returns and
461 /// deserialization failures are merged into the returned error.
runWithSPSRetErrorMerged()462 Error runWithSPSRetErrorMerged() const {
463 detail::SPSSerializableError RetErr;
464 if (auto Err = runWithSPSRet<SPSError>(RetErr))
465 return Err;
466 return detail::fromSPSSerializable(std::move(RetErr));
467 }
468
469 private:
470 ExecutorAddr FnAddr;
471 std::vector<char> ArgData;
472 };
473
474 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
475
476 template <>
477 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
478 public:
size(const WrapperFunctionCall & WFC)479 static size_t size(const WrapperFunctionCall &WFC) {
480 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
481 WFC.getCallee(), WFC.getArgData());
482 }
483
serialize(SPSOutputBuffer & OB,const WrapperFunctionCall & WFC)484 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
485 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
486 OB, WFC.getCallee(), WFC.getArgData());
487 }
488
deserialize(SPSInputBuffer & IB,WrapperFunctionCall & WFC)489 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
490 ExecutorAddr FnAddr;
491 WrapperFunctionCall::ArgDataBufferType ArgData;
492 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
493 return false;
494 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
495 return true;
496 }
497 };
498
499 } // namespace orc_rt
500
501 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
502