1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "orc_rt/c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "executor_address.h" 20 #include "simple_packed_serialization.h" 21 #include <type_traits> 22 23 namespace __orc_rt { 24 25 /// C++ wrapper function result: Same as CWrapperFunctionResult but 26 /// auto-releases memory. 27 class WrapperFunctionResult { 28 public: 29 /// Create a default WrapperFunctionResult. 30 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } 31 32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 33 /// instance takes ownership of the result object and will automatically 34 /// call dispose on the result upon destruction. 35 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} 36 37 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 39 40 WrapperFunctionResult(WrapperFunctionResult &&Other) { 41 __orc_rt_CWrapperFunctionResultInit(&R); 42 std::swap(R, Other.R); 43 } 44 45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 46 __orc_rt_CWrapperFunctionResult Tmp; 47 __orc_rt_CWrapperFunctionResultInit(&Tmp); 48 std::swap(Tmp, Other.R); 49 std::swap(R, Tmp); 50 return *this; 51 } 52 53 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } 54 55 /// Relinquish ownership of and return the 56 /// __orc_rt_CWrapperFunctionResult. 57 __orc_rt_CWrapperFunctionResult release() { 58 __orc_rt_CWrapperFunctionResult Tmp; 59 __orc_rt_CWrapperFunctionResultInit(&Tmp); 60 std::swap(R, Tmp); 61 return Tmp; 62 } 63 64 /// Get a pointer to the data contained in this instance. 65 char *data() { return __orc_rt_CWrapperFunctionResultData(&R); } 66 67 /// Returns the size of the data contained in this instance. 68 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } 69 70 /// Returns true if this value is equivalent to a default-constructed 71 /// WrapperFunctionResult. 72 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } 73 74 /// Create a WrapperFunctionResult with the given size and return a pointer 75 /// to the underlying memory. 76 static WrapperFunctionResult allocate(size_t Size) { 77 WrapperFunctionResult R; 78 R.R = __orc_rt_CWrapperFunctionResultAllocate(Size); 79 return R; 80 } 81 82 /// Copy from the given char range. 83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 84 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 85 } 86 87 /// Copy from the given null-terminated string (includes the null-terminator). 88 static WrapperFunctionResult copyFrom(const char *Source) { 89 return __orc_rt_CreateCWrapperFunctionResultFromString(Source); 90 } 91 92 /// Copy from the given std::string (includes the null terminator). 93 static WrapperFunctionResult copyFrom(const std::string &Source) { 94 return copyFrom(Source.c_str()); 95 } 96 97 /// Create an out-of-band error by copying the given string. 98 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 99 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 100 } 101 102 /// Create an out-of-band error by copying the given string. 103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 104 return createOutOfBandError(Msg.c_str()); 105 } 106 107 template <typename SPSArgListT, typename... ArgTs> 108 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { 109 auto Result = allocate(SPSArgListT::size(Args...)); 110 SPSOutputBuffer OB(Result.data(), Result.size()); 111 if (!SPSArgListT::serialize(OB, Args...)) 112 return createOutOfBandError( 113 "Error serializing arguments to blob in call"); 114 return Result; 115 } 116 117 /// If this value is an out-of-band error then this returns the error message, 118 /// otherwise returns nullptr. 119 const char *getOutOfBandError() const { 120 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 121 } 122 123 private: 124 __orc_rt_CWrapperFunctionResult R; 125 }; 126 127 namespace detail { 128 129 template <typename RetT> class WrapperFunctionHandlerCaller { 130 public: 131 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 132 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 133 std::index_sequence<I...>) { 134 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 135 } 136 }; 137 138 template <> class WrapperFunctionHandlerCaller<void> { 139 public: 140 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 142 std::index_sequence<I...>) { 143 std::forward<HandlerT>(H)(std::get<I>(Args)...); 144 return SPSEmpty(); 145 } 146 }; 147 148 template <typename WrapperFunctionImplT, 149 template <typename> class ResultSerializer, typename... SPSTagTs> 150 class WrapperFunctionHandlerHelper 151 : public WrapperFunctionHandlerHelper< 152 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 153 ResultSerializer, SPSTagTs...> {}; 154 155 template <typename RetT, typename... ArgTs, 156 template <typename> class ResultSerializer, typename... SPSTagTs> 157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 158 SPSTagTs...> { 159 public: 160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 162 163 template <typename HandlerT> 164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 165 size_t ArgSize) { 166 ArgTuple Args; 167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 168 return WrapperFunctionResult::createOutOfBandError( 169 "Could not deserialize arguments for wrapper function call"); 170 171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 172 std::forward<HandlerT>(H), Args, ArgIndices{}); 173 174 return ResultSerializer<decltype(HandlerResult)>::serialize( 175 std::move(HandlerResult)); 176 } 177 178 private: 179 template <std::size_t... I> 180 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 181 std::index_sequence<I...>) { 182 SPSInputBuffer IB(ArgData, ArgSize); 183 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 184 } 185 }; 186 187 // Map function pointers to function types. 188 template <typename RetT, typename... ArgTs, 189 template <typename> class ResultSerializer, typename... SPSTagTs> 190 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 191 SPSTagTs...> 192 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 193 SPSTagTs...> {}; 194 195 // Map non-const member function types to function types. 196 template <typename ClassT, typename RetT, typename... ArgTs, 197 template <typename> class ResultSerializer, typename... SPSTagTs> 198 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 199 SPSTagTs...> 200 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 201 SPSTagTs...> {}; 202 203 // Map const member function types to function types. 204 template <typename ClassT, typename RetT, typename... ArgTs, 205 template <typename> class ResultSerializer, typename... SPSTagTs> 206 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 207 ResultSerializer, SPSTagTs...> 208 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 209 SPSTagTs...> {}; 210 211 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 212 public: 213 static WrapperFunctionResult serialize(RetT Result) { 214 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); 215 } 216 }; 217 218 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 219 public: 220 static WrapperFunctionResult serialize(Error Err) { 221 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 222 toSPSSerializable(std::move(Err))); 223 } 224 }; 225 226 template <typename SPSRetTagT, typename T> 227 class ResultSerializer<SPSRetTagT, Expected<T>> { 228 public: 229 static WrapperFunctionResult serialize(Expected<T> E) { 230 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 231 toSPSSerializable(std::move(E))); 232 } 233 }; 234 235 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 236 public: 237 static void makeSafe(RetT &Result) {} 238 239 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 240 SPSInputBuffer IB(ArgData, ArgSize); 241 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 242 return make_error<StringError>( 243 "Error deserializing return value from blob in call"); 244 return Error::success(); 245 } 246 }; 247 248 template <> class ResultDeserializer<SPSError, Error> { 249 public: 250 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 251 252 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 253 SPSInputBuffer IB(ArgData, ArgSize); 254 SPSSerializableError BSE; 255 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 256 return make_error<StringError>( 257 "Error deserializing return value from blob in call"); 258 Err = fromSPSSerializable(std::move(BSE)); 259 return Error::success(); 260 } 261 }; 262 263 template <typename SPSTagT, typename T> 264 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 265 public: 266 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 267 268 static Error deserialize(Expected<T> &E, const char *ArgData, 269 size_t ArgSize) { 270 SPSInputBuffer IB(ArgData, ArgSize); 271 SPSSerializableExpected<T> BSE; 272 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 273 return make_error<StringError>( 274 "Error deserializing return value from blob in call"); 275 E = fromSPSSerializable(std::move(BSE)); 276 return Error::success(); 277 } 278 }; 279 280 } // end namespace detail 281 282 template <typename SPSSignature> class WrapperFunction; 283 284 template <typename SPSRetTagT, typename... SPSTagTs> 285 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 286 private: 287 template <typename RetT> 288 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 289 290 public: 291 template <typename RetT, typename... ArgTs> 292 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 293 294 // RetT might be an Error or Expected value. Set the checked flag now: 295 // we don't want the user to have to check the unused result if this 296 // operation fails. 297 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 298 299 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 300 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set"); 301 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) 302 return make_error<StringError>("__orc_rt_jit_dispatch not set"); 303 304 auto ArgBuffer = 305 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); 306 if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) 307 return make_error<StringError>(ErrMsg); 308 309 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch( 310 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size()); 311 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 312 return make_error<StringError>(ErrMsg); 313 314 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 315 Result, ResultBuffer.data(), ResultBuffer.size()); 316 } 317 318 template <typename HandlerT> 319 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 320 HandlerT &&Handler) { 321 using WFHH = 322 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, 323 ResultSerializer, SPSTagTs...>; 324 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 325 } 326 327 private: 328 template <typename T> static const T &makeSerializable(const T &Value) { 329 return Value; 330 } 331 332 static detail::SPSSerializableError makeSerializable(Error Err) { 333 return detail::toSPSSerializable(std::move(Err)); 334 } 335 336 template <typename T> 337 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 338 return detail::toSPSSerializable(std::move(E)); 339 } 340 }; 341 342 template <typename... SPSTagTs> 343 class WrapperFunction<void(SPSTagTs...)> 344 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 345 public: 346 template <typename... ArgTs> 347 static Error call(const void *FnTag, const ArgTs &...Args) { 348 SPSEmpty BE; 349 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 350 } 351 352 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 353 }; 354 355 /// A function object that takes an ExecutorAddr as its first argument, 356 /// casts that address to a ClassT*, then calls the given method on that 357 /// pointer passing in the remaining function arguments. This utility 358 /// removes some of the boilerplate from writing wrappers for method calls. 359 /// 360 /// @code{.cpp} 361 /// class MyClass { 362 /// public: 363 /// void myMethod(uint32_t, bool) { ... } 364 /// }; 365 /// 366 /// // SPS Method signature -- note MyClass object address as first argument. 367 /// using SPSMyMethodWrapperSignature = 368 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; 369 /// 370 /// WrapperFunctionResult 371 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { 372 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( 373 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); 374 /// } 375 /// @endcode 376 /// 377 template <typename RetT, typename ClassT, typename... ArgTs> 378 class MethodWrapperHandler { 379 public: 380 using MethodT = RetT (ClassT::*)(ArgTs...); 381 MethodWrapperHandler(MethodT M) : M(M) {} 382 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { 383 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); 384 } 385 386 private: 387 MethodT M; 388 }; 389 390 /// Create a MethodWrapperHandler object from the given method pointer. 391 template <typename RetT, typename ClassT, typename... ArgTs> 392 MethodWrapperHandler<RetT, ClassT, ArgTs...> 393 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { 394 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); 395 } 396 397 /// Represents a call to a wrapper function. 398 class WrapperFunctionCall { 399 public: 400 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a 401 // smallvector. 402 using ArgDataBufferType = std::vector<char>; 403 404 /// Create a WrapperFunctionCall using the given SPS serializer to serialize 405 /// the arguments. 406 template <typename SPSSerializer, typename... ArgTs> 407 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, 408 const ArgTs &...Args) { 409 ArgDataBufferType ArgData; 410 ArgData.resize(SPSSerializer::size(Args...)); 411 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), 412 ArgData.size()); 413 if (SPSSerializer::serialize(OB, Args...)) 414 return WrapperFunctionCall(FnAddr, std::move(ArgData)); 415 return make_error<StringError>("Cannot serialize arguments for " 416 "AllocActionCall"); 417 } 418 419 WrapperFunctionCall() = default; 420 421 /// Create a WrapperFunctionCall from a target function and arg buffer. 422 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) 423 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} 424 425 /// Returns the address to be called. 426 const ExecutorAddr &getCallee() const { return FnAddr; } 427 428 /// Returns the argument data. 429 const ArgDataBufferType &getArgData() const { return ArgData; } 430 431 /// WrapperFunctionCalls convert to true if the callee is non-null. 432 explicit operator bool() const { return !!FnAddr; } 433 434 /// Run call returning raw WrapperFunctionResult. 435 WrapperFunctionResult run() const { 436 using FnTy = 437 __orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize); 438 return WrapperFunctionResult( 439 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); 440 } 441 442 /// Run call and deserialize result using SPS. 443 template <typename SPSRetT, typename RetT> 444 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> 445 runWithSPSRet(RetT &RetVal) const { 446 auto WFR = run(); 447 if (const char *ErrMsg = WFR.getOutOfBandError()) 448 return make_error<StringError>(ErrMsg); 449 SPSInputBuffer IB(WFR.data(), WFR.size()); 450 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) 451 return make_error<StringError>("Could not deserialize result from " 452 "serialized wrapper function call"); 453 return Error::success(); 454 } 455 456 /// Overload for SPS functions returning void. 457 template <typename SPSRetT> 458 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> 459 runWithSPSRet() const { 460 SPSEmpty E; 461 return runWithSPSRet<SPSEmpty>(E); 462 } 463 464 /// Run call and deserialize an SPSError result. SPSError returns and 465 /// deserialization failures are merged into the returned error. 466 Error runWithSPSRetErrorMerged() const { 467 detail::SPSSerializableError RetErr; 468 if (auto Err = runWithSPSRet<SPSError>(RetErr)) 469 return Err; 470 return detail::fromSPSSerializable(std::move(RetErr)); 471 } 472 473 private: 474 ExecutorAddr FnAddr; 475 std::vector<char> ArgData; 476 }; 477 478 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; 479 480 template <> 481 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { 482 public: 483 static size_t size(const WrapperFunctionCall &WFC) { 484 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size( 485 WFC.getCallee(), WFC.getArgData()); 486 } 487 488 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { 489 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize( 490 OB, WFC.getCallee(), WFC.getArgData()); 491 } 492 493 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { 494 ExecutorAddr FnAddr; 495 WrapperFunctionCall::ArgDataBufferType ArgData; 496 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) 497 return false; 498 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); 499 return true; 500 } 501 }; 502 503 } // end namespace __orc_rt 504 505 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 506