xref: /linux/tools/testing/selftests/vDSO/vdso_test_getrandom.c (revision a0482e3446cea426bf16571e0000423ed5b25af0)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2022-2024 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include <assert.h>
7 #include <pthread.h>
8 #include <stdint.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <time.h>
13 #include <unistd.h>
14 #include <sched.h>
15 #include <signal.h>
16 #include <sys/auxv.h>
17 #include <sys/mman.h>
18 #include <sys/random.h>
19 #include <sys/syscall.h>
20 #include <sys/ptrace.h>
21 #include <sys/wait.h>
22 #include <sys/types.h>
23 #include <linux/random.h>
24 #include <linux/ptrace.h>
25 
26 #include "../kselftest.h"
27 #include "parse_vdso.h"
28 #include "vdso_config.h"
29 #include "vdso_call.h"
30 
#ifndef timespecsub
/*
 * Store *tsp - *usp into *vsp. When the nanosecond difference is negative,
 * borrow one second so tv_nsec ends up non-negative.
 */
#define timespecsub(tsp, usp, vsp)						\
	do {									\
		(vsp)->tv_sec = (tsp)->tv_sec - (usp)->tv_sec;			\
		(vsp)->tv_nsec = (tsp)->tv_nsec - (usp)->tv_nsec;		\
		if ((vsp)->tv_nsec < 0) {					\
			(vsp)->tv_nsec += 1000000000L;				\
			--(vsp)->tv_sec;					\
		}								\
	} while (0)
#endif
42 
/*
 * Abort the whole run with a TAP failure when @condition is false. The
 * message quotes the condition's source text.
 */
#define ksft_assert(condition)							\
	do {									\
		if (!(condition))						\
			ksft_exit_fail_msg("Assertion failed: %s\n", #condition); \
	} while (0)
45 
46 static struct {
47 	pthread_mutex_t lock;
48 	void **states;
49 	size_t len, cap;
50 	ssize_t(*fn)(void *, size_t, unsigned long, void *, size_t);
51 	struct vgetrandom_opaque_params params;
52 } vgrnd = {
53 	.lock = PTHREAD_MUTEX_INITIALIZER
54 };
55 
vgetrandom_get_state(void)56 static void *vgetrandom_get_state(void)
57 {
58 	void *state = NULL;
59 
60 	pthread_mutex_lock(&vgrnd.lock);
61 	if (!vgrnd.len) {
62 		size_t page_size = getpagesize();
63 		size_t new_cap;
64 		size_t alloc_size, num = sysconf(_SC_NPROCESSORS_ONLN); /* Just a decent heuristic. */
65 		size_t state_size_aligned, cache_line_size = sysconf(_SC_LEVEL1_DCACHE_LINESIZE) ?: 1;
66 		void *new_block, *new_states;
67 
68 		state_size_aligned = (vgrnd.params.size_of_opaque_state + cache_line_size - 1) & (~(cache_line_size - 1));
69 		alloc_size = (num * state_size_aligned + page_size - 1) & (~(page_size - 1));
70 		num = (page_size / state_size_aligned) * (alloc_size / page_size);
71 		new_block = mmap(0, alloc_size, vgrnd.params.mmap_prot, vgrnd.params.mmap_flags, -1, 0);
72 		if (new_block == MAP_FAILED)
73 			goto out;
74 
75 		new_cap = vgrnd.cap + num;
76 		new_states = reallocarray(vgrnd.states, new_cap, sizeof(*vgrnd.states));
77 		if (!new_states)
78 			goto unmap;
79 		vgrnd.cap = new_cap;
80 		vgrnd.states = new_states;
81 
82 		for (size_t i = 0; i < num; ++i) {
83 			if (((uintptr_t)new_block & (page_size - 1)) + vgrnd.params.size_of_opaque_state > page_size)
84 				new_block = (void *)(((uintptr_t)new_block + page_size - 1) & (~(page_size - 1)));
85 			vgrnd.states[i] = new_block;
86 			new_block += state_size_aligned;
87 		}
88 		vgrnd.len = num;
89 		goto success;
90 
91 	unmap:
92 		munmap(new_block, alloc_size);
93 		goto out;
94 	}
95 success:
96 	state = vgrnd.states[--vgrnd.len];
97 
98 out:
99 	pthread_mutex_unlock(&vgrnd.lock);
100 	return state;
101 }
102 
103 __attribute__((unused)) /* Example for libc implementors */
vgetrandom_put_state(void * state)104 static void vgetrandom_put_state(void *state)
105 {
106 	if (!state)
107 		return;
108 	pthread_mutex_lock(&vgrnd.lock);
109 	vgrnd.states[vgrnd.len++] = state;
110 	pthread_mutex_unlock(&vgrnd.lock);
111 }
112 
vgetrandom_init(void)113 static void vgetrandom_init(void)
114 {
115 	const char *version = versions[VDSO_VERSION];
116 	const char *name = names[VDSO_NAMES][6];
117 	unsigned long sysinfo_ehdr = getauxval(AT_SYSINFO_EHDR);
118 	ssize_t ret;
119 
120 	if (!sysinfo_ehdr)
121 		ksft_exit_skip("AT_SYSINFO_EHDR is not present\n");
122 	vdso_init_from_sysinfo_ehdr(sysinfo_ehdr);
123 	vgrnd.fn = (__typeof__(vgrnd.fn))vdso_sym(version, name);
124 	if (!vgrnd.fn)
125 		ksft_exit_skip("%s@%s symbol is missing from vDSO\n", name, version);
126 	ret = VDSO_CALL(vgrnd.fn, 5, NULL, 0, 0, &vgrnd.params, ~0UL);
127 	if (ret == -ENOSYS)
128 		ksft_exit_skip("CPU does not have runtime support\n");
129 	else if (ret)
130 		ksft_exit_fail_msg("Failed to fetch vgetrandom params: %zd\n", ret);
131 }
132 
vgetrandom(void * buf,size_t len,unsigned long flags)133 static ssize_t vgetrandom(void *buf, size_t len, unsigned long flags)
134 {
135 	static __thread void *state;
136 
137 	if (!state) {
138 		state = vgetrandom_get_state();
139 		ksft_assert(state);
140 	}
141 	return VDSO_CALL(vgrnd.fn, 5, buf, len, flags, state, vgrnd.params.size_of_opaque_state);
142 }
143 
/* Benchmark sizing: calls made per thread, and thread count for bench-multi. */
enum {
	TRIALS = 25000000,
	THREADS = 256,
};
145 
test_vdso_getrandom(void * ctx)146 static void *test_vdso_getrandom(void *ctx)
147 {
148 	for (size_t i = 0; i < TRIALS; ++i) {
149 		unsigned int val;
150 		ssize_t ret = vgetrandom(&val, sizeof(val), 0);
151 		ksft_assert(ret == sizeof(val));
152 	}
153 	return NULL;
154 }
155 
test_libc_getrandom(void * ctx)156 static void *test_libc_getrandom(void *ctx)
157 {
158 	for (size_t i = 0; i < TRIALS; ++i) {
159 		unsigned int val;
160 		ssize_t ret = getrandom(&val, sizeof(val), 0);
161 		ksft_assert(ret == sizeof(val));
162 	}
163 	return NULL;
164 }
165 
test_syscall_getrandom(void * ctx)166 static void *test_syscall_getrandom(void *ctx)
167 {
168 	for (size_t i = 0; i < TRIALS; ++i) {
169 		unsigned int val;
170 		ssize_t ret = syscall(__NR_getrandom, &val, sizeof(val), 0);
171 		ksft_assert(ret == sizeof(val));
172 	}
173 	return NULL;
174 }
175 
bench_single(void)176 static void bench_single(void)
177 {
178 	struct timespec start, end, diff;
179 
180 	clock_gettime(CLOCK_MONOTONIC, &start);
181 	test_vdso_getrandom(NULL);
182 	clock_gettime(CLOCK_MONOTONIC, &end);
183 	timespecsub(&end, &start, &diff);
184 	printf("   vdso: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
185 
186 	clock_gettime(CLOCK_MONOTONIC, &start);
187 	test_libc_getrandom(NULL);
188 	clock_gettime(CLOCK_MONOTONIC, &end);
189 	timespecsub(&end, &start, &diff);
190 	printf("   libc: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
191 
192 	clock_gettime(CLOCK_MONOTONIC, &start);
193 	test_syscall_getrandom(NULL);
194 	clock_gettime(CLOCK_MONOTONIC, &end);
195 	timespecsub(&end, &start, &diff);
196 	printf("syscall: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
197 }
198 
bench_multi(void)199 static void bench_multi(void)
200 {
201 	struct timespec start, end, diff;
202 	pthread_t threads[THREADS];
203 
204 	clock_gettime(CLOCK_MONOTONIC, &start);
205 	for (size_t i = 0; i < THREADS; ++i)
206 		ksft_assert(pthread_create(&threads[i], NULL, test_vdso_getrandom, NULL) == 0);
207 	for (size_t i = 0; i < THREADS; ++i)
208 		pthread_join(threads[i], NULL);
209 	clock_gettime(CLOCK_MONOTONIC, &end);
210 	timespecsub(&end, &start, &diff);
211 	printf("   vdso: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
212 
213 	clock_gettime(CLOCK_MONOTONIC, &start);
214 	for (size_t i = 0; i < THREADS; ++i)
215 		ksft_assert(pthread_create(&threads[i], NULL, test_libc_getrandom, NULL) == 0);
216 	for (size_t i = 0; i < THREADS; ++i)
217 		pthread_join(threads[i], NULL);
218 	clock_gettime(CLOCK_MONOTONIC, &end);
219 	timespecsub(&end, &start, &diff);
220 	printf("   libc: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
221 
222 	clock_gettime(CLOCK_MONOTONIC, &start);
223 	for (size_t i = 0; i < THREADS; ++i)
224 		ksft_assert(pthread_create(&threads[i], NULL, test_syscall_getrandom, NULL) == 0);
225 	for (size_t i = 0; i < THREADS; ++i)
226 		pthread_join(threads[i], NULL);
227 	clock_gettime(CLOCK_MONOTONIC, &end);
228 	timespecsub(&end, &start, &diff);
229 	printf("   syscall: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
230 }
231 
/*
 * "fill" mode: request an oddly sized stack buffer of random bytes from the
 * vDSO in an endless loop. Results are not checked; never returns.
 */
static void fill(void)
{
	uint8_t buf[323929];

	while (1)
		(void)vgetrandom(buf, sizeof(buf), 0);
}
237 }
238 
kselftest(void)239 static void kselftest(void)
240 {
241 	uint8_t weird_size[1263];
242 	pid_t child;
243 
244 	ksft_print_header();
245 	vgetrandom_init();
246 	ksft_set_plan(2);
247 
248 	for (size_t i = 0; i < 1000; ++i) {
249 		ssize_t ret = vgetrandom(weird_size, sizeof(weird_size), 0);
250 		ksft_assert(ret == sizeof(weird_size));
251 	}
252 
253 	ksft_test_result_pass("getrandom: PASS\n");
254 
255 	unshare(CLONE_NEWUSER);
256 	ksft_assert(unshare(CLONE_NEWTIME) == 0);
257 	child = fork();
258 	ksft_assert(child >= 0);
259 	if (!child) {
260 		vgetrandom_init();
261 		child = getpid();
262 		ksft_assert(ptrace(PTRACE_TRACEME, 0, NULL, NULL) == 0);
263 		ksft_assert(kill(child, SIGSTOP) == 0);
264 		ksft_assert(vgetrandom(weird_size, sizeof(weird_size), 0) == sizeof(weird_size));
265 		_exit(0);
266 	}
267 	for (;;) {
268 		struct ptrace_syscall_info info = { 0 };
269 		int status;
270 		ksft_assert(waitpid(child, &status, 0) >= 0);
271 		if (WIFEXITED(status)) {
272 			ksft_assert(WEXITSTATUS(status) == 0);
273 			break;
274 		}
275 		ksft_assert(WIFSTOPPED(status));
276 		if (WSTOPSIG(status) == SIGSTOP)
277 			ksft_assert(ptrace(PTRACE_SETOPTIONS, child, 0, PTRACE_O_TRACESYSGOOD) == 0);
278 		else if (WSTOPSIG(status) == (SIGTRAP | 0x80)) {
279 			ksft_assert(ptrace(PTRACE_GET_SYSCALL_INFO, child, sizeof(info), &info) > 0);
280 			if (info.op == PTRACE_SYSCALL_INFO_ENTRY && info.entry.nr == __NR_getrandom &&
281 			    info.entry.args[0] == (uintptr_t)weird_size && info.entry.args[1] == sizeof(weird_size))
282 				ksft_exit_fail_msg("vgetrandom passed buffer to syscall getrandom unexpectedly\n");
283 		}
284 		ksft_assert(ptrace(PTRACE_SYSCALL, child, 0, 0) == 0);
285 	}
286 
287 	ksft_test_result_pass("getrandom timens: PASS\n");
288 
289 	ksft_exit_pass();
290 }
291 
/* Print the command-line synopsis, naming the program as @argv0, to stderr. */
static void usage(const char *argv0)
{
	(void)fprintf(stderr, "Usage: %s [bench-single|bench-multi|fill]\n",
		      argv0);
}
295 }
296 
main(int argc,char * argv[])297 int main(int argc, char *argv[])
298 {
299 	if (argc == 1) {
300 		kselftest();
301 		return 0;
302 	}
303 
304 	if (argc != 2) {
305 		usage(argv[0]);
306 		return 1;
307 	}
308 
309 	vgetrandom_init();
310 
311 	if (!strcmp(argv[1], "bench-single"))
312 		bench_single();
313 	else if (!strcmp(argv[1], "bench-multi"))
314 		bench_multi();
315 	else if (!strcmp(argv[1], "fill"))
316 		fill();
317 	else {
318 		usage(argv[0]);
319 		return 1;
320 	}
321 	return 0;
322 }
323