xref: /freebsd/contrib/llvm-project/compiler-rt/lib/orc/wrapper_function_utils.h (revision ec0ea6efa1ad229d75c394c1a9b9cac33af2b1d3)
1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime support library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15 
16 #include "c_api.h"
17 #include "common.h"
18 #include "error.h"
19 #include "simple_packed_serialization.h"
20 #include <type_traits>
21 
22 namespace __orc_rt {
23 
24 /// C++ wrapper function result: Same as CWrapperFunctionResult but
25 /// auto-releases memory.
26 class WrapperFunctionResult {
27 public:
28   /// Create a default WrapperFunctionResult.
29   WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); }
30 
31   /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
32   /// instance takes ownership of the result object and will automatically
33   /// call dispose on the result upon destruction.
34   WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {}
35 
36   WrapperFunctionResult(const WrapperFunctionResult &) = delete;
37   WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
38 
39   WrapperFunctionResult(WrapperFunctionResult &&Other) {
40     __orc_rt_CWrapperFunctionResultInit(&R);
41     std::swap(R, Other.R);
42   }
43 
44   WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
45     __orc_rt_CWrapperFunctionResult Tmp;
46     __orc_rt_CWrapperFunctionResultInit(&Tmp);
47     std::swap(Tmp, Other.R);
48     std::swap(R, Tmp);
49     return *this;
50   }
51 
52   ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); }
53 
54   /// Relinquish ownership of and return the
55   /// __orc_rt_CWrapperFunctionResult.
56   __orc_rt_CWrapperFunctionResult release() {
57     __orc_rt_CWrapperFunctionResult Tmp;
58     __orc_rt_CWrapperFunctionResultInit(&Tmp);
59     std::swap(R, Tmp);
60     return Tmp;
61   }
62 
63   /// Get a pointer to the data contained in this instance.
64   const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); }
65 
66   /// Returns the size of the data contained in this instance.
67   size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); }
68 
69   /// Returns true if this value is equivalent to a default-constructed
70   /// WrapperFunctionResult.
71   bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); }
72 
73   /// Create a WrapperFunctionResult with the given size and return a pointer
74   /// to the underlying memory.
75   static char *allocate(WrapperFunctionResult &R, size_t Size) {
76     __orc_rt_DisposeCWrapperFunctionResult(&R.R);
77     __orc_rt_CWrapperFunctionResultInit(&R.R);
78     return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size);
79   }
80 
81   /// Copy from the given char range.
82   static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
83     return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
84   }
85 
86   /// Copy from the given null-terminated string (includes the null-terminator).
87   static WrapperFunctionResult copyFrom(const char *Source) {
88     return __orc_rt_CreateCWrapperFunctionResultFromString(Source);
89   }
90 
91   /// Copy from the given std::string (includes the null terminator).
92   static WrapperFunctionResult copyFrom(const std::string &Source) {
93     return copyFrom(Source.c_str());
94   }
95 
96   /// Create an out-of-band error by copying the given string.
97   static WrapperFunctionResult createOutOfBandError(const char *Msg) {
98     return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
99   }
100 
101   /// Create an out-of-band error by copying the given string.
102   static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
103     return createOutOfBandError(Msg.c_str());
104   }
105 
106   /// If this value is an out-of-band error then this returns the error message,
107   /// otherwise returns nullptr.
108   const char *getOutOfBandError() const {
109     return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
110   }
111 
112 private:
113   __orc_rt_CWrapperFunctionResult R;
114 };
115 
116 namespace detail {
117 
118 template <typename SPSArgListT, typename... ArgTs>
119 Expected<WrapperFunctionResult>
120 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) {
121   WrapperFunctionResult Result;
122   char *DataPtr =
123       WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...));
124   SPSOutputBuffer OB(DataPtr, Result.size());
125   if (!SPSArgListT::serialize(OB, Args...))
126     return make_error<StringError>(
127         "Error serializing arguments to blob in call");
128   return std::move(Result);
129 }
130 
/// Invoke a handler with the tuple elements selected by an index sequence.
/// The handler's return value is passed through unchanged, including its
/// reference-ness.
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  template <typename HandlerT, typename ArgTupleT, std::size_t... I>
  static auto call(HandlerT &&Handler, ArgTupleT &Packed,
                   std::index_sequence<I...>)
      -> decltype(std::forward<HandlerT>(Handler)(std::get<I>(Packed)...)) {
    return std::forward<HandlerT>(Handler)(std::get<I>(Packed)...);
  }
};
139 
140 template <> class WrapperFunctionHandlerCaller<void> {
141 public:
142   template <typename HandlerT, typename ArgTupleT, std::size_t... I>
143   static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
144                        std::index_sequence<I...>) {
145     std::forward<HandlerT>(H)(std::get<I>(Args)...);
146     return SPSEmpty();
147   }
148 };
149 
/// Primary template, used for callable objects (e.g. lambdas). It recovers the
/// handler's signature from the type of `&T::operator()` and defers to the
/// member-function-pointer specializations. Because the address of
/// operator() is taken, it must not be overloaded or a template.
template <typename HandlerT, template <typename> class ResultSerializer,
          typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference_t<HandlerT>::operator()),
          ResultSerializer, SPSTagTs...> {};
156 
157 template <typename RetT, typename... ArgTs,
158           template <typename> class ResultSerializer, typename... SPSTagTs>
159 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
160                                    SPSTagTs...> {
161 public:
162   using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
163   using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
164 
165   template <typename HandlerT>
166   static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
167                                      size_t ArgSize) {
168     ArgTuple Args;
169     if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
170       return WrapperFunctionResult::createOutOfBandError(
171           "Could not deserialize arguments for wrapper function call");
172 
173     auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
174         std::forward<HandlerT>(H), Args, ArgIndices{});
175 
176     if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize(
177             std::move(HandlerResult)))
178       return std::move(*Result);
179     else
180       return WrapperFunctionResult::createOutOfBandError(
181           toString(Result.takeError()));
182   }
183 
184 private:
185   template <std::size_t... I>
186   static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
187                           std::index_sequence<I...>) {
188     SPSInputBuffer IB(ArgData, ArgSize);
189     return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
190   }
191 
192 };
193 
194 // Map function references to function types.
195 template <typename RetT, typename... ArgTs,
196           template <typename> class ResultSerializer, typename... SPSTagTs>
197 class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer,
198                                    SPSTagTs...>
199     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
200                                           SPSTagTs...> {};
201 
202 // Map non-const member function types to function types.
203 template <typename ClassT, typename RetT, typename... ArgTs,
204           template <typename> class ResultSerializer, typename... SPSTagTs>
205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
206                                    SPSTagTs...>
207     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
208                                           SPSTagTs...> {};
209 
210 // Map const member function types to function types.
211 template <typename ClassT, typename RetT, typename... ArgTs,
212           template <typename> class ResultSerializer, typename... SPSTagTs>
213 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
214                                    ResultSerializer, SPSTagTs...>
215     : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
216                                           SPSTagTs...> {};
217 
218 template <typename SPSRetTagT, typename RetT> class ResultSerializer {
219 public:
220   static Expected<WrapperFunctionResult> serialize(RetT Result) {
221     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
222         Result);
223   }
224 };
225 
226 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
227 public:
228   static Expected<WrapperFunctionResult> serialize(Error Err) {
229     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
230         toSPSSerializable(std::move(Err)));
231   }
232 };
233 
234 template <typename SPSRetTagT, typename T>
235 class ResultSerializer<SPSRetTagT, Expected<T>> {
236 public:
237   static Expected<WrapperFunctionResult> serialize(Expected<T> E) {
238     return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>(
239         toSPSSerializable(std::move(E)));
240   }
241 };
242 
243 template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
244 public:
245   static void makeSafe(RetT &Result) {}
246 
247   static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
248     SPSInputBuffer IB(ArgData, ArgSize);
249     if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
250       return make_error<StringError>(
251           "Error deserializing return value from blob in call");
252     return Error::success();
253   }
254 };
255 
256 template <> class ResultDeserializer<SPSError, Error> {
257 public:
258   static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
259 
260   static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
261     SPSInputBuffer IB(ArgData, ArgSize);
262     SPSSerializableError BSE;
263     if (!SPSArgList<SPSError>::deserialize(IB, BSE))
264       return make_error<StringError>(
265           "Error deserializing return value from blob in call");
266     Err = fromSPSSerializable(std::move(BSE));
267     return Error::success();
268   }
269 };
270 
271 template <typename SPSTagT, typename T>
272 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
273 public:
274   static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
275 
276   static Error deserialize(Expected<T> &E, const char *ArgData,
277                            size_t ArgSize) {
278     SPSInputBuffer IB(ArgData, ArgSize);
279     SPSSerializableExpected<T> BSE;
280     if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
281       return make_error<StringError>(
282           "Error deserializing return value from blob in call");
283     E = fromSPSSerializable(std::move(BSE));
284     return Error::success();
285   }
286 };
287 
288 } // end namespace detail
289 
290 template <typename SPSSignature> class WrapperFunction;
291 
292 template <typename SPSRetTagT, typename... SPSTagTs>
293 class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
294 private:
295   template <typename RetT>
296   using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
297 
298 public:
299   template <typename RetT, typename... ArgTs>
300   static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
301 
302     // RetT might be an Error or Expected value. Set the checked flag now:
303     // we don't want the user to have to check the unused result if this
304     // operation fails.
305     detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
306 
307     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
308       return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
309     if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
310       return make_error<StringError>("__orc_rt_jit_dispatch not set");
311 
312     auto ArgBuffer =
313         detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>(
314             Args...);
315     if (!ArgBuffer)
316       return ArgBuffer.takeError();
317 
318     WrapperFunctionResult ResultBuffer =
319         __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag,
320                               ArgBuffer->data(), ArgBuffer->size());
321     if (auto ErrMsg = ResultBuffer.getOutOfBandError())
322       return make_error<StringError>(ErrMsg);
323 
324     return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
325         Result, ResultBuffer.data(), ResultBuffer.size());
326   }
327 
328   template <typename HandlerT>
329   static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
330                                       HandlerT &&Handler) {
331     using WFHH =
332         detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer,
333                                              SPSTagTs...>;
334     return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
335   }
336 
337 private:
338   template <typename T> static const T &makeSerializable(const T &Value) {
339     return Value;
340   }
341 
342   static detail::SPSSerializableError makeSerializable(Error Err) {
343     return detail::toSPSSerializable(std::move(Err));
344   }
345 
346   template <typename T>
347   static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
348     return detail::toSPSSerializable(std::move(E));
349   }
350 };
351 
352 template <typename... SPSTagTs>
353 class WrapperFunction<void(SPSTagTs...)>
354     : private WrapperFunction<SPSEmpty(SPSTagTs...)> {
355 public:
356   template <typename... ArgTs>
357   static Error call(const void *FnTag, const ArgTs &...Args) {
358     SPSEmpty BE;
359     return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
360   }
361 
362   using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
363 };
364 
365 } // end namespace __orc_rt
366 
367 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
368