//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file is a part of the ORC runtime support library. // //===----------------------------------------------------------------------===// #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H #define ORC_RT_WRAPPER_FUNCTION_UTILS_H #include "c_api.h" #include "common.h" #include "error.h" #include "simple_packed_serialization.h" #include namespace __orc_rt { /// C++ wrapper function result: Same as CWrapperFunctionResult but /// auto-releases memory. class WrapperFunctionResult { public: /// Create a default WrapperFunctionResult. WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This /// instance takes ownership of the result object and will automatically /// call dispose on the result upon destruction. WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} WrapperFunctionResult(const WrapperFunctionResult &) = delete; WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; WrapperFunctionResult(WrapperFunctionResult &&Other) { __orc_rt_CWrapperFunctionResultInit(&R); std::swap(R, Other.R); } WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { __orc_rt_CWrapperFunctionResult Tmp; __orc_rt_CWrapperFunctionResultInit(&Tmp); std::swap(Tmp, Other.R); std::swap(R, Tmp); return *this; } ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } /// Relinquish ownership of and return the /// __orc_rt_CWrapperFunctionResult. __orc_rt_CWrapperFunctionResult release() { __orc_rt_CWrapperFunctionResult Tmp; __orc_rt_CWrapperFunctionResultInit(&Tmp); std::swap(R, Tmp); return Tmp; } /// Get a pointer to the data contained in this instance. const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); } /// Returns the size of the data contained in this instance. size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } /// Returns true if this value is equivalent to a default-constructed /// WrapperFunctionResult. bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } /// Create a WrapperFunctionResult with the given size and return a pointer /// to the underlying memory. static char *allocate(WrapperFunctionResult &R, size_t Size) { __orc_rt_DisposeCWrapperFunctionResult(&R.R); __orc_rt_CWrapperFunctionResultInit(&R.R); return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size); } /// Copy from the given char range. static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); } /// Copy from the given null-terminated string (includes the null-terminator). static WrapperFunctionResult copyFrom(const char *Source) { return __orc_rt_CreateCWrapperFunctionResultFromString(Source); } /// Copy from the given std::string (includes the null terminator). static WrapperFunctionResult copyFrom(const std::string &Source) { return copyFrom(Source.c_str()); } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const char *Msg) { return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); } /// Create an out-of-band error by copying the given string. static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { return createOutOfBandError(Msg.c_str()); } /// If this value is an out-of-band error then this returns the error message, /// otherwise returns nullptr. const char *getOutOfBandError() const { return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); } private: __orc_rt_CWrapperFunctionResult R; }; namespace detail { template Expected serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { WrapperFunctionResult Result; char *DataPtr = WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); SPSOutputBuffer OB(DataPtr, Result.size()); if (!SPSArgListT::serialize(OB, Args...)) return make_error( "Error serializing arguments to blob in call"); return std::move(Result); } template class WrapperFunctionHandlerCaller { public: template static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { return std::forward(H)(std::get(Args)...); } }; template <> class WrapperFunctionHandlerCaller { public: template static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, std::index_sequence) { std::forward(H)(std::get(Args)...); return SPSEmpty(); } }; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper< decltype(&std::remove_reference_t::operator()), ResultSerializer, SPSTagTs...> {}; template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper { public: using ArgTuple = std::tuple...>; using ArgIndices = std::make_index_sequence::value>; template static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, size_t ArgSize) { ArgTuple Args; if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) return WrapperFunctionResult::createOutOfBandError( "Could not deserialize arguments for wrapper function call"); auto HandlerResult = WrapperFunctionHandlerCaller::call( std::forward(H), Args, ArgIndices{}); if (auto Result = ResultSerializer::serialize( std::move(HandlerResult))) return std::move(*Result); else return WrapperFunctionResult::createOutOfBandError( toString(Result.takeError())); } private: template static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, std::index_sequence) { SPSInputBuffer IB(ArgData, ArgSize); return SPSArgList::deserialize(IB, std::get(Args)...); } }; // Map function references to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map non-const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; // Map const member function types to function types. template class ResultSerializer, typename... SPSTagTs> class WrapperFunctionHandlerHelper : public WrapperFunctionHandlerHelper {}; template class ResultSerializer { public: static Expected serialize(RetT Result) { return serializeViaSPSToWrapperFunctionResult>( Result); } }; template class ResultSerializer { public: static Expected serialize(Error Err) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(Err))); } }; template class ResultSerializer> { public: static Expected serialize(Expected E) { return serializeViaSPSToWrapperFunctionResult>( toSPSSerializable(std::move(E))); } }; template class ResultDeserializer { public: static void makeSafe(RetT &Result) {} static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); if (!SPSArgList::deserialize(IB, Result)) return make_error( "Error deserializing return value from blob in call"); return Error::success(); } }; template <> class ResultDeserializer { public: static void makeSafe(Error &Err) { cantFail(std::move(Err)); } static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableError BSE; if (!SPSArgList::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call"); Err = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; template class ResultDeserializer, Expected> { public: static void makeSafe(Expected &E) { cantFail(E.takeError()); } static Error deserialize(Expected &E, const char *ArgData, size_t ArgSize) { SPSInputBuffer IB(ArgData, ArgSize); SPSSerializableExpected BSE; if (!SPSArgList>::deserialize(IB, BSE)) return make_error( "Error deserializing return value from blob in call"); E = fromSPSSerializable(std::move(BSE)); return Error::success(); } }; } // end namespace detail template class WrapperFunction; template class WrapperFunction { private: template using ResultSerializer = detail::ResultSerializer; public: template static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { // RetT might be an Error or Expected value. Set the checked flag now: // we don't want the user to have to check the unused result if this // operation fails. detail::ResultDeserializer::makeSafe(Result); if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) return make_error("__orc_rt_jit_dispatch_ctx not set"); if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) return make_error("__orc_rt_jit_dispatch not set"); auto ArgBuffer = detail::serializeViaSPSToWrapperFunctionResult>( Args...); if (!ArgBuffer) return ArgBuffer.takeError(); WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer->data(), ArgBuffer->size()); if (auto ErrMsg = ResultBuffer.getOutOfBandError()) return make_error(ErrMsg); return detail::ResultDeserializer::deserialize( Result, ResultBuffer.data(), ResultBuffer.size()); } template static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, HandlerT &&Handler) { using WFHH = detail::WrapperFunctionHandlerHelper; return WFHH::apply(std::forward(Handler), ArgData, ArgSize); } private: template static const T &makeSerializable(const T &Value) { return Value; } static detail::SPSSerializableError makeSerializable(Error Err) { return detail::toSPSSerializable(std::move(Err)); } template static detail::SPSSerializableExpected makeSerializable(Expected E) { return detail::toSPSSerializable(std::move(E)); } }; template class WrapperFunction : private WrapperFunction { public: template static Error call(const void *FnTag, const ArgTs &...Args) { SPSEmpty BE; return WrapperFunction::call(FnTag, BE, Args...); } using WrapperFunction::handle; }; } // end namespace __orc_rt #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H