1 //===-- wrapper_function_utils_test.cpp -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file is a part of the ORC runtime.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "common.h"
14 #include "jit_dispatch.h"
15 #include "wrapper_function_utils.h"
16 #include "gtest/gtest.h"
17
18 using namespace orc_rt;
19
namespace {
// Payload shared by the WrapperFunctionResult tests in this file.
constexpr const char *TestString = "test string";
} // end anonymous namespace
23
// A default-constructed WrapperFunctionResult must report itself as empty,
// have zero size, and carry no out-of-band error.
TEST(WrapperFunctionUtilsTest, DefaultWrapperFunctionResult) {
  WrapperFunctionResult Default;

  EXPECT_TRUE(Default.empty());
  EXPECT_EQ(Default.size(), 0U);
  EXPECT_EQ(Default.getOutOfBandError(), nullptr);
}
30
// Wraps a C-level orc_rt_WrapperFunctionResult (created from TestString) in
// the C++ WrapperFunctionResult and checks size, content, and error state.
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCStruct) {
  orc_rt_WrapperFunctionResult CR =
      orc_rt_CreateWrapperFunctionResultFromString(TestString);
  WrapperFunctionResult R(CR);

  // The expected size counts the terminating NUL (strlen + 1).
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ prints both strings on mismatch, unlike
  // EXPECT_TRUE(strcmp(...) == 0).
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
40
// Builds a result with copyFrom(pointer, size) and checks size, content, and
// error state.
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromRange) {
  // The size passed in includes the terminating NUL.
  auto R = WrapperFunctionResult::copyFrom(TestString, strlen(TestString) + 1);

  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ prints both strings on mismatch, unlike
  // EXPECT_TRUE(strcmp(...) == 0).
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
48
// Builds a result with copyFrom(const char *) and checks size, content, and
// error state.
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromCString) {
  auto R = WrapperFunctionResult::copyFrom(TestString);

  // The expected size counts the terminating NUL (strlen + 1).
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ prints both strings on mismatch, unlike
  // EXPECT_TRUE(strcmp(...) == 0).
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
56
// Builds a result with copyFrom(std::string) and checks size, content, and
// error state.
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromStdString) {
  auto R = WrapperFunctionResult::copyFrom(std::string(TestString));

  // The expected size counts the terminating NUL (strlen + 1).
  EXPECT_EQ(R.size(), strlen(TestString) + 1);
  // EXPECT_STREQ prints both strings on mismatch, unlike
  // EXPECT_TRUE(strcmp(...) == 0).
  EXPECT_STREQ(R.data(), TestString);
  EXPECT_FALSE(R.empty());
  EXPECT_EQ(R.getOutOfBandError(), nullptr);
}
64
// Builds an out-of-band-error result from TestString and checks that the
// result is non-empty and reports the same message.
TEST(WrapperFunctionUtilsTest, WrapperFunctionResultFromOutOfBandError) {
  auto R = WrapperFunctionResult::createOutOfBandError(TestString);

  EXPECT_FALSE(R.empty());
  // getOutOfBandError() can be null (other tests in this file compare it to
  // nullptr). EXPECT_STREQ handles a null pointer as a normal failure, where
  // strcmp on null would be undefined behavior.
  EXPECT_STREQ(R.getOutOfBandError(), TestString);
}
70
// Creating a WrapperFunctionCall with an empty SPS argument list and a
// default ExecutorAddr is expected to succeed.
TEST(WrapperFunctionUtilsTest, WrapperFunctionCCallCreateEmpty) {
  auto EmptyCall = WrapperFunctionCall::Create<SPSArgList<>>(ExecutorAddr());
  EXPECT_TRUE(static_cast<bool>(EmptyCall));
}
74
// Handler with a void() signature that does nothing; voidNoopWrapper passes
// it to WrapperFunction<void()>::handle.
static void voidNoop() {
  // Intentionally empty.
}
76
voidNoopWrapper(const char * ArgData,size_t ArgSize)77 static orc_rt_WrapperFunctionResult voidNoopWrapper(const char *ArgData,
78 size_t ArgSize) {
79 return WrapperFunction<void()>::handle(ArgData, ArgSize, voidNoop).release();
80 }
81
addWrapper(const char * ArgData,size_t ArgSize)82 static orc_rt_WrapperFunctionResult addWrapper(const char *ArgData,
83 size_t ArgSize) {
84 return WrapperFunction<int32_t(int32_t, int32_t)>::handle(
85 ArgData, ArgSize,
86 [](int32_t X, int32_t Y) -> int32_t { return X + Y; })
87 .release();
88 }
89
90 extern "C" __orc_rt_Opaque __orc_rt_jit_dispatch_ctx{};
91
92 extern "C" orc_rt_WrapperFunctionResult
__orc_rt_jit_dispatch(__orc_rt_Opaque * Ctx,const void * FnTag,const char * ArgData,size_t ArgSize)93 __orc_rt_jit_dispatch(__orc_rt_Opaque *Ctx, const void *FnTag,
94 const char *ArgData, size_t ArgSize) {
95 using WrapperFunctionType =
96 orc_rt_WrapperFunctionResult (*)(const char *, size_t);
97
98 return reinterpret_cast<WrapperFunctionType>(const_cast<void *>(FnTag))(
99 ArgData, ArgSize);
100 }
101
// Calls voidNoopWrapper through JITDispatch and expects the call to report no
// error.
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallVoidNoopAndHandle) {
  // Named cast instead of a C-style cast: the wrapper's address is passed as
  // the dispatch tag, which this file's __orc_rt_jit_dispatch casts back to a
  // function pointer.
  EXPECT_FALSE(!!WrapperFunction<void()>::call(
      JITDispatch(reinterpret_cast<void *>(&voidNoopWrapper))));
}
106
// Calls addWrapper through JITDispatch with arguments (1, 2) and expects no
// error and a result of 3.
TEST(WrapperFunctionUtilsTest, WrapperFunctionCallAddWrapperAndHandle) {
  // Initialized so that the EXPECT_EQ below never reads an indeterminate
  // value if the call fails before writing the result.
  int32_t Result = 0;

  // Named cast instead of a C-style cast for the dispatch tag.
  EXPECT_FALSE(!!WrapperFunction<int32_t(int32_t, int32_t)>::call(
      JITDispatch(reinterpret_cast<void *>(&addWrapper)), Result, 1, 2));
  EXPECT_EQ(Result, int32_t{3});
}
113
// Small class holding one int32_t; its addMethod is the target of
// addMethodWrapper.
class AddClass {
public:
  AddClass(int32_t X) : Base(X) {}

  // Returns the stored value plus Y. addMethodWrapper takes
  // &AddClass::addMethod, so the signature is kept exactly as is.
  int32_t addMethod(int32_t Y) {
    const int32_t Sum = Base + Y;
    return Sum;
  }

private:
  int32_t Base; // Value supplied at construction.
};
122
addMethodWrapper(const char * ArgData,size_t ArgSize)123 static orc_rt_WrapperFunctionResult addMethodWrapper(const char *ArgData,
124 size_t ArgSize) {
125 return WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::handle(
126 ArgData, ArgSize, makeMethodWrapperHandler(&AddClass::addMethod))
127 .release();
128 }
129
// Calls addMethodWrapper through JITDispatch, passing the address of an
// AddClass constructed with 1 and the argument 2; expects no error and a
// result of 3.
TEST(WrapperFunctionUtilsTest, WrapperFunctionMethodCallAndHandleRet) {
  // Initialized so that the EXPECT_EQ below never reads an indeterminate
  // value if the call fails before writing the result.
  int32_t Result = 0;
  AddClass AddObj(1);

  // Named cast instead of a C-style cast for the dispatch tag.
  EXPECT_FALSE(!!WrapperFunction<int32_t(SPSExecutorAddr, int32_t)>::call(
      JITDispatch(reinterpret_cast<void *>(&addMethodWrapper)), Result,
      ExecutorAddr::fromPtr(&AddObj), 2));
  EXPECT_EQ(Result, int32_t{3});
}
138
sumArrayWrapper(const char * ArgData,size_t ArgSize)139 static orc_rt_WrapperFunctionResult sumArrayWrapper(const char *ArgData,
140 size_t ArgSize) {
141 return WrapperFunction<int8_t(SPSExecutorAddrRange)>::handle(
142 ArgData, ArgSize,
143 [](ExecutorAddrRange R) {
144 int8_t Sum = 0;
145 for (char C : R.toSpan<char>())
146 Sum += C;
147 return Sum;
148 })
149 .release();
150 }
151
// Exercises WrapperFunctionCall::Create followed by run / runWithSPSRet for
// three wrappers: sumArrayWrapper, voidNoopWrapper, and addWrapper.
TEST(WrapperFunctionUtilsTest, SerializedWrapperFunctionCallTest) {
  // Case 1: wrapper function call returning a raw result buffer.
  {
    char Bytes[] = {1, 2, 3, 4};
    ExecutorAddrRange ByteRange(ExecutorAddr::fromPtr(Bytes),
                                ExecutorAddrDiff(sizeof(Bytes)));

    auto Call =
        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
            ExecutorAddr::fromPtr(sumArrayWrapper), ByteRange));

    WrapperFunctionResult Ret(Call.run());
    // One byte of result, holding 1 + 2 + 3 + 4.
    EXPECT_EQ(Ret.size(), 1U);
    EXPECT_EQ(Ret.data()[0], 10);
  }

  // Case 2: call to a void function.
  {
    auto Call =
        cantFail(WrapperFunctionCall::Create<SPSArgList<SPSExecutorAddrRange>>(
            ExecutorAddr::fromPtr(voidNoopWrapper), ExecutorAddrRange()));

    auto CallErr = Call.runWithSPSRet<void>();
    EXPECT_FALSE(static_cast<bool>(CallErr));
  }

  // Case 3: call with arguments and a return value.
  {
    auto Call =
        cantFail(WrapperFunctionCall::Create<SPSArgList<int32_t, int32_t>>(
            ExecutorAddr::fromPtr(addWrapper), 2, 4));

    int32_t Sum = 0;
    auto CallErr = Call.runWithSPSRet<int32_t>(Sum);
    EXPECT_FALSE(static_cast<bool>(CallErr));
    EXPECT_EQ(Sum, 6);
  }
}
189