1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "simple_packed_serialization.h" 20 #include <type_traits> 21 22 namespace __orc_rt { 23 24 /// C++ wrapper function result: Same as CWrapperFunctionResult but 25 /// auto-releases memory. 26 class WrapperFunctionResult { 27 public: 28 /// Create a default WrapperFunctionResult. 29 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } 30 31 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 32 /// instance takes ownership of the result object and will automatically 33 /// call dispose on the result upon destruction. 34 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} 35 36 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 38 39 WrapperFunctionResult(WrapperFunctionResult &&Other) { 40 __orc_rt_CWrapperFunctionResultInit(&R); 41 std::swap(R, Other.R); 42 } 43 44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 45 __orc_rt_CWrapperFunctionResult Tmp; 46 __orc_rt_CWrapperFunctionResultInit(&Tmp); 47 std::swap(Tmp, Other.R); 48 std::swap(R, Tmp); 49 return *this; 50 } 51 52 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } 53 54 /// Relinquish ownership of and return the 55 /// __orc_rt_CWrapperFunctionResult. 56 __orc_rt_CWrapperFunctionResult release() { 57 __orc_rt_CWrapperFunctionResult Tmp; 58 __orc_rt_CWrapperFunctionResultInit(&Tmp); 59 std::swap(R, Tmp); 60 return Tmp; 61 } 62 63 /// Get a pointer to the data contained in this instance. 64 const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); } 65 66 /// Returns the size of the data contained in this instance. 67 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } 68 69 /// Returns true if this value is equivalent to a default-constructed 70 /// WrapperFunctionResult. 71 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } 72 73 /// Create a WrapperFunctionResult with the given size and return a pointer 74 /// to the underlying memory. 75 static char *allocate(WrapperFunctionResult &R, size_t Size) { 76 __orc_rt_DisposeCWrapperFunctionResult(&R.R); 77 __orc_rt_CWrapperFunctionResultInit(&R.R); 78 return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size); 79 } 80 81 /// Copy from the given char range. 82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 83 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 84 } 85 86 /// Copy from the given null-terminated string (includes the null-terminator). 87 static WrapperFunctionResult copyFrom(const char *Source) { 88 return __orc_rt_CreateCWrapperFunctionResultFromString(Source); 89 } 90 91 /// Copy from the given std::string (includes the null terminator). 92 static WrapperFunctionResult copyFrom(const std::string &Source) { 93 return copyFrom(Source.c_str()); 94 } 95 96 /// Create an out-of-band error by copying the given string. 97 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 98 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 99 } 100 101 /// Create an out-of-band error by copying the given string. 102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 103 return createOutOfBandError(Msg.c_str()); 104 } 105 106 /// If this value is an out-of-band error then this returns the error message, 107 /// otherwise returns nullptr. 108 const char *getOutOfBandError() const { 109 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 110 } 111 112 private: 113 __orc_rt_CWrapperFunctionResult R; 114 }; 115 116 namespace detail { 117 118 template <typename SPSArgListT, typename... ArgTs> 119 Expected<WrapperFunctionResult> 120 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { 121 WrapperFunctionResult Result; 122 char *DataPtr = 123 WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); 124 SPSOutputBuffer OB(DataPtr, Result.size()); 125 if (!SPSArgListT::serialize(OB, Args...)) 126 return make_error<StringError>( 127 "Error serializing arguments to blob in call"); 128 return std::move(Result); 129 } 130 131 template <typename RetT> class WrapperFunctionHandlerCaller { 132 public: 133 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 134 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 135 std::index_sequence<I...>) { 136 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 137 } 138 }; 139 140 template <> class WrapperFunctionHandlerCaller<void> { 141 public: 142 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 143 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 144 std::index_sequence<I...>) { 145 std::forward<HandlerT>(H)(std::get<I>(Args)...); 146 return SPSEmpty(); 147 } 148 }; 149 150 template <typename WrapperFunctionImplT, 151 template <typename> class ResultSerializer, typename... SPSTagTs> 152 class WrapperFunctionHandlerHelper 153 : public WrapperFunctionHandlerHelper< 154 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 155 ResultSerializer, SPSTagTs...> {}; 156 157 template <typename RetT, typename... ArgTs, 158 template <typename> class ResultSerializer, typename... SPSTagTs> 159 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 160 SPSTagTs...> { 161 public: 162 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 163 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 164 165 template <typename HandlerT> 166 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 167 size_t ArgSize) { 168 ArgTuple Args; 169 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 170 return WrapperFunctionResult::createOutOfBandError( 171 "Could not deserialize arguments for wrapper function call"); 172 173 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 174 std::forward<HandlerT>(H), Args, ArgIndices{}); 175 176 if (auto Result = ResultSerializer<decltype(HandlerResult)>::serialize( 177 std::move(HandlerResult))) 178 return std::move(*Result); 179 else 180 return WrapperFunctionResult::createOutOfBandError( 181 toString(Result.takeError())); 182 } 183 184 private: 185 template <std::size_t... I> 186 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 187 std::index_sequence<I...>) { 188 SPSInputBuffer IB(ArgData, ArgSize); 189 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 190 } 191 192 }; 193 194 // Map function references to function types. 195 template <typename RetT, typename... ArgTs, 196 template <typename> class ResultSerializer, typename... SPSTagTs> 197 class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer, 198 SPSTagTs...> 199 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 200 SPSTagTs...> {}; 201 202 // Map non-const member function types to function types. 203 template <typename ClassT, typename RetT, typename... ArgTs, 204 template <typename> class ResultSerializer, typename... SPSTagTs> 205 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 206 SPSTagTs...> 207 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 208 SPSTagTs...> {}; 209 210 // Map const member function types to function types. 211 template <typename ClassT, typename RetT, typename... ArgTs, 212 template <typename> class ResultSerializer, typename... SPSTagTs> 213 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 214 ResultSerializer, SPSTagTs...> 215 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 216 SPSTagTs...> {}; 217 218 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 219 public: 220 static Expected<WrapperFunctionResult> serialize(RetT Result) { 221 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 222 Result); 223 } 224 }; 225 226 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 227 public: 228 static Expected<WrapperFunctionResult> serialize(Error Err) { 229 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 230 toSPSSerializable(std::move(Err))); 231 } 232 }; 233 234 template <typename SPSRetTagT, typename T> 235 class ResultSerializer<SPSRetTagT, Expected<T>> { 236 public: 237 static Expected<WrapperFunctionResult> serialize(Expected<T> E) { 238 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 239 toSPSSerializable(std::move(E))); 240 } 241 }; 242 243 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 244 public: 245 static void makeSafe(RetT &Result) {} 246 247 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 248 SPSInputBuffer IB(ArgData, ArgSize); 249 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 250 return make_error<StringError>( 251 "Error deserializing return value from blob in call"); 252 return Error::success(); 253 } 254 }; 255 256 template <> class ResultDeserializer<SPSError, Error> { 257 public: 258 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 259 260 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 261 SPSInputBuffer IB(ArgData, ArgSize); 262 SPSSerializableError BSE; 263 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 264 return make_error<StringError>( 265 "Error deserializing return value from blob in call"); 266 Err = fromSPSSerializable(std::move(BSE)); 267 return Error::success(); 268 } 269 }; 270 271 template <typename SPSTagT, typename T> 272 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 273 public: 274 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 275 276 static Error deserialize(Expected<T> &E, const char *ArgData, 277 size_t ArgSize) { 278 SPSInputBuffer IB(ArgData, ArgSize); 279 SPSSerializableExpected<T> BSE; 280 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 281 return make_error<StringError>( 282 "Error deserializing return value from blob in call"); 283 E = fromSPSSerializable(std::move(BSE)); 284 return Error::success(); 285 } 286 }; 287 288 } // end namespace detail 289 290 template <typename SPSSignature> class WrapperFunction; 291 292 template <typename SPSRetTagT, typename... SPSTagTs> 293 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 294 private: 295 template <typename RetT> 296 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 297 298 public: 299 template <typename RetT, typename... ArgTs> 300 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 301 302 // RetT might be an Error or Expected value. Set the checked flag now: 303 // we don't want the user to have to check the unused result if this 304 // operation fails. 305 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 306 307 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 308 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set"); 309 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) 310 return make_error<StringError>("__orc_rt_jit_dispatch not set"); 311 312 auto ArgBuffer = 313 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 314 Args...); 315 if (!ArgBuffer) 316 return ArgBuffer.takeError(); 317 318 WrapperFunctionResult ResultBuffer = 319 __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, 320 ArgBuffer->data(), ArgBuffer->size()); 321 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 322 return make_error<StringError>(ErrMsg); 323 324 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 325 Result, ResultBuffer.data(), ResultBuffer.size()); 326 } 327 328 template <typename HandlerT> 329 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 330 HandlerT &&Handler) { 331 using WFHH = 332 detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer, 333 SPSTagTs...>; 334 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 335 } 336 337 private: 338 template <typename T> static const T &makeSerializable(const T &Value) { 339 return Value; 340 } 341 342 static detail::SPSSerializableError makeSerializable(Error Err) { 343 return detail::toSPSSerializable(std::move(Err)); 344 } 345 346 template <typename T> 347 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 348 return detail::toSPSSerializable(std::move(E)); 349 } 350 }; 351 352 template <typename... SPSTagTs> 353 class WrapperFunction<void(SPSTagTs...)> 354 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 355 public: 356 template <typename... ArgTs> 357 static Error call(const void *FnTag, const ArgTs &...Args) { 358 SPSEmpty BE; 359 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 360 } 361 362 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 363 }; 364 365 } // end namespace __orc_rt 366 367 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 368