1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "orc_rt/c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "executor_address.h" 20 #include "simple_packed_serialization.h" 21 #include <type_traits> 22 23 namespace __orc_rt { 24 25 /// C++ wrapper function result: Same as CWrapperFunctionResult but 26 /// auto-releases memory. 27 class WrapperFunctionResult { 28 public: 29 /// Create a default WrapperFunctionResult. 30 WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); } 31 32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 33 /// instance takes ownership of the result object and will automatically 34 /// call dispose on the result upon destruction. 35 WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {} 36 37 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 39 40 WrapperFunctionResult(WrapperFunctionResult &&Other) { 41 orc_rt_CWrapperFunctionResultInit(&R); 42 std::swap(R, Other.R); 43 } 44 45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 46 orc_rt_CWrapperFunctionResult Tmp; 47 orc_rt_CWrapperFunctionResultInit(&Tmp); 48 std::swap(Tmp, Other.R); 49 std::swap(R, Tmp); 50 return *this; 51 } 52 53 ~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); } 54 55 /// Relinquish ownership of and return the 56 /// orc_rt_CWrapperFunctionResult. 57 orc_rt_CWrapperFunctionResult release() { 58 orc_rt_CWrapperFunctionResult Tmp; 59 orc_rt_CWrapperFunctionResultInit(&Tmp); 60 std::swap(R, Tmp); 61 return Tmp; 62 } 63 64 /// Get a pointer to the data contained in this instance. 65 char *data() { return orc_rt_CWrapperFunctionResultData(&R); } 66 67 /// Returns the size of the data contained in this instance. 68 size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); } 69 70 /// Returns true if this value is equivalent to a default-constructed 71 /// WrapperFunctionResult. 72 bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); } 73 74 /// Create a WrapperFunctionResult with the given size and return a pointer 75 /// to the underlying memory. 76 static WrapperFunctionResult allocate(size_t Size) { 77 WrapperFunctionResult R; 78 R.R = orc_rt_CWrapperFunctionResultAllocate(Size); 79 return R; 80 } 81 82 /// Copy from the given char range. 83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 84 return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 85 } 86 87 /// Copy from the given null-terminated string (includes the null-terminator). 88 static WrapperFunctionResult copyFrom(const char *Source) { 89 return orc_rt_CreateCWrapperFunctionResultFromString(Source); 90 } 91 92 /// Copy from the given std::string (includes the null terminator). 93 static WrapperFunctionResult copyFrom(const std::string &Source) { 94 return copyFrom(Source.c_str()); 95 } 96 97 /// Create an out-of-band error by copying the given string. 98 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 99 return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 100 } 101 102 /// Create an out-of-band error by copying the given string. 103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 104 return createOutOfBandError(Msg.c_str()); 105 } 106 107 template <typename SPSArgListT, typename... ArgTs> 108 static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) { 109 auto Result = allocate(SPSArgListT::size(Args...)); 110 SPSOutputBuffer OB(Result.data(), Result.size()); 111 if (!SPSArgListT::serialize(OB, Args...)) 112 return createOutOfBandError( 113 "Error serializing arguments to blob in call"); 114 return Result; 115 } 116 117 /// If this value is an out-of-band error then this returns the error message, 118 /// otherwise returns nullptr. 119 const char *getOutOfBandError() const { 120 return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 121 } 122 123 private: 124 orc_rt_CWrapperFunctionResult R; 125 }; 126 127 namespace detail { 128 129 template <typename RetT> class WrapperFunctionHandlerCaller { 130 public: 131 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 132 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 133 std::index_sequence<I...>) { 134 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 135 } 136 }; 137 138 template <> class WrapperFunctionHandlerCaller<void> { 139 public: 140 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 141 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 142 std::index_sequence<I...>) { 143 std::forward<HandlerT>(H)(std::get<I>(Args)...); 144 return SPSEmpty(); 145 } 146 }; 147 148 template <typename WrapperFunctionImplT, 149 template <typename> class ResultSerializer, typename... SPSTagTs> 150 class WrapperFunctionHandlerHelper 151 : public WrapperFunctionHandlerHelper< 152 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 153 ResultSerializer, SPSTagTs...> {}; 154 155 template <typename RetT, typename... ArgTs, 156 template <typename> class ResultSerializer, typename... SPSTagTs> 157 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 158 SPSTagTs...> { 159 public: 160 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 161 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 162 163 template <typename HandlerT> 164 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 165 size_t ArgSize) { 166 ArgTuple Args; 167 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 168 return WrapperFunctionResult::createOutOfBandError( 169 "Could not deserialize arguments for wrapper function call"); 170 171 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 172 std::forward<HandlerT>(H), Args, ArgIndices{}); 173 174 return ResultSerializer<decltype(HandlerResult)>::serialize( 175 std::move(HandlerResult)); 176 } 177 178 private: 179 template <std::size_t... I> 180 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 181 std::index_sequence<I...>) { 182 SPSInputBuffer IB(ArgData, ArgSize); 183 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 184 } 185 }; 186 187 // Map function pointers to function types. 188 template <typename RetT, typename... ArgTs, 189 template <typename> class ResultSerializer, typename... SPSTagTs> 190 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 191 SPSTagTs...> 192 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 193 SPSTagTs...> {}; 194 195 // Map non-const member function types to function types. 196 template <typename ClassT, typename RetT, typename... ArgTs, 197 template <typename> class ResultSerializer, typename... SPSTagTs> 198 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 199 SPSTagTs...> 200 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 201 SPSTagTs...> {}; 202 203 // Map const member function types to function types. 204 template <typename ClassT, typename RetT, typename... ArgTs, 205 template <typename> class ResultSerializer, typename... SPSTagTs> 206 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 207 ResultSerializer, SPSTagTs...> 208 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 209 SPSTagTs...> {}; 210 211 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 212 public: 213 static WrapperFunctionResult serialize(RetT Result) { 214 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result); 215 } 216 }; 217 218 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 219 public: 220 static WrapperFunctionResult serialize(Error Err) { 221 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 222 toSPSSerializable(std::move(Err))); 223 } 224 }; 225 226 template <typename SPSRetTagT, typename T> 227 class ResultSerializer<SPSRetTagT, Expected<T>> { 228 public: 229 static WrapperFunctionResult serialize(Expected<T> E) { 230 return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>( 231 toSPSSerializable(std::move(E))); 232 } 233 }; 234 235 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 236 public: 237 static void makeSafe(RetT &Result) {} 238 239 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 240 SPSInputBuffer IB(ArgData, ArgSize); 241 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 242 return make_error<StringError>( 243 "Error deserializing return value from blob in call"); 244 return Error::success(); 245 } 246 }; 247 248 template <> class ResultDeserializer<SPSError, Error> { 249 public: 250 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 251 252 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 253 SPSInputBuffer IB(ArgData, ArgSize); 254 SPSSerializableError BSE; 255 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 256 return make_error<StringError>( 257 "Error deserializing return value from blob in call"); 258 Err = fromSPSSerializable(std::move(BSE)); 259 return Error::success(); 260 } 261 }; 262 263 template <typename SPSTagT, typename T> 264 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 265 public: 266 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 267 268 static Error deserialize(Expected<T> &E, const char *ArgData, 269 size_t ArgSize) { 270 SPSInputBuffer IB(ArgData, ArgSize); 271 SPSSerializableExpected<T> BSE; 272 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 273 return make_error<StringError>( 274 "Error deserializing return value from blob in call"); 275 E = fromSPSSerializable(std::move(BSE)); 276 return Error::success(); 277 } 278 }; 279 280 } // end namespace detail 281 282 template <typename SPSSignature> class WrapperFunction; 283 284 template <typename SPSRetTagT, typename... SPSTagTs> 285 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 286 private: 287 template <typename RetT> 288 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 289 290 public: 291 template <typename RetT, typename... ArgTs> 292 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 293 294 // RetT might be an Error or Expected value. Set the checked flag now: 295 // we don't want the user to have to check the unused result if this 296 // operation fails. 297 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 298 299 // Since the functions cannot be zero/unresolved on Windows, the following 300 // reference taking would always be non-zero, thus generating a compiler 301 // warning otherwise. 302 #if !defined(_WIN32) 303 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 304 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set"); 305 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) 306 return make_error<StringError>("__orc_rt_jit_dispatch not set"); 307 #endif 308 auto ArgBuffer = 309 WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...); 310 if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) 311 return make_error<StringError>(ErrMsg); 312 313 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch( 314 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size()); 315 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 316 return make_error<StringError>(ErrMsg); 317 318 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 319 Result, ResultBuffer.data(), ResultBuffer.size()); 320 } 321 322 template <typename HandlerT> 323 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 324 HandlerT &&Handler) { 325 using WFHH = 326 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, 327 ResultSerializer, SPSTagTs...>; 328 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 329 } 330 331 private: 332 template <typename T> static const T &makeSerializable(const T &Value) { 333 return Value; 334 } 335 336 static detail::SPSSerializableError makeSerializable(Error Err) { 337 return detail::toSPSSerializable(std::move(Err)); 338 } 339 340 template <typename T> 341 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 342 return detail::toSPSSerializable(std::move(E)); 343 } 344 }; 345 346 template <typename... SPSTagTs> 347 class WrapperFunction<void(SPSTagTs...)> 348 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 349 public: 350 template <typename... ArgTs> 351 static Error call(const void *FnTag, const ArgTs &...Args) { 352 SPSEmpty BE; 353 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 354 } 355 356 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 357 }; 358 359 /// A function object that takes an ExecutorAddr as its first argument, 360 /// casts that address to a ClassT*, then calls the given method on that 361 /// pointer passing in the remaining function arguments. This utility 362 /// removes some of the boilerplate from writing wrappers for method calls. 363 /// 364 /// @code{.cpp} 365 /// class MyClass { 366 /// public: 367 /// void myMethod(uint32_t, bool) { ... } 368 /// }; 369 /// 370 /// // SPS Method signature -- note MyClass object address as first argument. 371 /// using SPSMyMethodWrapperSignature = 372 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; 373 /// 374 /// WrapperFunctionResult 375 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { 376 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( 377 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); 378 /// } 379 /// @endcode 380 /// 381 template <typename RetT, typename ClassT, typename... ArgTs> 382 class MethodWrapperHandler { 383 public: 384 using MethodT = RetT (ClassT::*)(ArgTs...); 385 MethodWrapperHandler(MethodT M) : M(M) {} 386 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { 387 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); 388 } 389 390 private: 391 MethodT M; 392 }; 393 394 /// Create a MethodWrapperHandler object from the given method pointer. 395 template <typename RetT, typename ClassT, typename... ArgTs> 396 MethodWrapperHandler<RetT, ClassT, ArgTs...> 397 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { 398 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); 399 } 400 401 /// Represents a call to a wrapper function. 402 class WrapperFunctionCall { 403 public: 404 // FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a 405 // smallvector. 406 using ArgDataBufferType = std::vector<char>; 407 408 /// Create a WrapperFunctionCall using the given SPS serializer to serialize 409 /// the arguments. 410 template <typename SPSSerializer, typename... ArgTs> 411 static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr, 412 const ArgTs &...Args) { 413 ArgDataBufferType ArgData; 414 ArgData.resize(SPSSerializer::size(Args...)); 415 SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(), 416 ArgData.size()); 417 if (SPSSerializer::serialize(OB, Args...)) 418 return WrapperFunctionCall(FnAddr, std::move(ArgData)); 419 return make_error<StringError>("Cannot serialize arguments for " 420 "AllocActionCall"); 421 } 422 423 WrapperFunctionCall() = default; 424 425 /// Create a WrapperFunctionCall from a target function and arg buffer. 426 WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData) 427 : FnAddr(FnAddr), ArgData(std::move(ArgData)) {} 428 429 /// Returns the address to be called. 430 const ExecutorAddr &getCallee() const { return FnAddr; } 431 432 /// Returns the argument data. 433 const ArgDataBufferType &getArgData() const { return ArgData; } 434 435 /// WrapperFunctionCalls convert to true if the callee is non-null. 436 explicit operator bool() const { return !!FnAddr; } 437 438 /// Run call returning raw WrapperFunctionResult. 439 WrapperFunctionResult run() const { 440 using FnTy = 441 orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize); 442 return WrapperFunctionResult( 443 FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size())); 444 } 445 446 /// Run call and deserialize result using SPS. 447 template <typename SPSRetT, typename RetT> 448 std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error> 449 runWithSPSRet(RetT &RetVal) const { 450 auto WFR = run(); 451 if (const char *ErrMsg = WFR.getOutOfBandError()) 452 return make_error<StringError>(ErrMsg); 453 SPSInputBuffer IB(WFR.data(), WFR.size()); 454 if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal)) 455 return make_error<StringError>("Could not deserialize result from " 456 "serialized wrapper function call"); 457 return Error::success(); 458 } 459 460 /// Overload for SPS functions returning void. 461 template <typename SPSRetT> 462 std::enable_if_t<std::is_same<SPSRetT, void>::value, Error> 463 runWithSPSRet() const { 464 SPSEmpty E; 465 return runWithSPSRet<SPSEmpty>(E); 466 } 467 468 /// Run call and deserialize an SPSError result. SPSError returns and 469 /// deserialization failures are merged into the returned error. 470 Error runWithSPSRetErrorMerged() const { 471 detail::SPSSerializableError RetErr; 472 if (auto Err = runWithSPSRet<SPSError>(RetErr)) 473 return Err; 474 return detail::fromSPSSerializable(std::move(RetErr)); 475 } 476 477 private: 478 ExecutorAddr FnAddr; 479 std::vector<char> ArgData; 480 }; 481 482 using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>; 483 484 template <> 485 class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> { 486 public: 487 static size_t size(const WrapperFunctionCall &WFC) { 488 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size( 489 WFC.getCallee(), WFC.getArgData()); 490 } 491 492 static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) { 493 return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize( 494 OB, WFC.getCallee(), WFC.getArgData()); 495 } 496 497 static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) { 498 ExecutorAddr FnAddr; 499 WrapperFunctionCall::ArgDataBufferType ArgData; 500 if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData)) 501 return false; 502 WFC = WrapperFunctionCall(FnAddr, std::move(ArgData)); 503 return true; 504 } 505 }; 506 507 } // end namespace __orc_rt 508 509 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 510