xref: /linux/tools/testing/selftests/vDSO/vdso_test_getrandom.c (revision 170aafe35cb98e0f3fbacb446ea86389fbce22ea)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2022-2024 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
4  */
5 
6 #include <assert.h>
7 #include <pthread.h>
8 #include <stdint.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <time.h>
13 #include <unistd.h>
14 #include <signal.h>
15 #include <sys/auxv.h>
16 #include <sys/mman.h>
17 #include <sys/random.h>
18 #include <sys/syscall.h>
19 #include <sys/types.h>
20 #include <linux/random.h>
21 
22 #include "../kselftest.h"
23 #include "parse_vdso.h"
24 
/*
 * Fallback definition of timespecsub(), used only when the C library headers
 * have not already provided one.
 *
 * Computes *vsp = *tsp - *usp for three struct timespec pointers. When the
 * nanosecond difference is negative, one second is borrowed so that tv_nsec
 * ends up non-negative. Arguments are evaluated more than once, so they must
 * be free of side effects.
 */
#ifndef timespecsub
#define	timespecsub(tsp, usp, vsp)					\
	do {								\
		(vsp)->tv_sec = (tsp)->tv_sec - (usp)->tv_sec;		\
		(vsp)->tv_nsec = (tsp)->tv_nsec - (usp)->tv_nsec;	\
		if ((vsp)->tv_nsec < 0) {				\
			(vsp)->tv_sec--;				\
			(vsp)->tv_nsec += 1000000000L;			\
		}							\
	} while (0)
#endif
36 
/*
 * Process-wide pool of opaque vgetrandom states that are not currently
 * assigned to a thread. vgetrandom_get_state() pops entries from it and
 * vgetrandom_put_state() pushes them back.
 */
static struct {
	pthread_mutex_t lock;	/* Taken around every access to the fields below. */
	void **states;		/* Stack of pointers to free states. */
	size_t len, cap;	/* Free entries in @states / slots allocated for it. */
} grnd_allocator = {
	.lock = PTHREAD_MUTEX_INITIALIZER
};
44 
/*
 * Global vgetrandom context, filled in once by vgetrandom_init().
 */
static struct {
	/* vDSO entry point: (buf, len, flags, opaque_state, opaque_len). */
	ssize_t(*fn)(void *, size_t, unsigned long, void *, size_t);
	/* Per-thread slot holding that thread's opaque state pointer. */
	pthread_key_t key;
	/* pthread_once() guard for vgetrandom_init(). */
	pthread_once_t initialized;
	/* State size and mmap prot/flags, as reported by the vDSO function. */
	struct vgetrandom_opaque_params params;
} grnd_ctx = {
	.initialized = PTHREAD_ONCE_INIT
};
53 
54 static void *vgetrandom_get_state(void)
55 {
56 	void *state = NULL;
57 
58 	pthread_mutex_lock(&grnd_allocator.lock);
59 	if (!grnd_allocator.len) {
60 		size_t page_size = getpagesize();
61 		size_t new_cap;
62 		size_t alloc_size, num = sysconf(_SC_NPROCESSORS_ONLN); /* Just a decent heuristic. */
63 		void *new_block, *new_states;
64 
65 		alloc_size = (num * grnd_ctx.params.size_of_opaque_state + page_size - 1) & (~(page_size - 1));
66 		num = (page_size / grnd_ctx.params.size_of_opaque_state) * (alloc_size / page_size);
67 		new_block = mmap(0, alloc_size, grnd_ctx.params.mmap_prot, grnd_ctx.params.mmap_flags, -1, 0);
68 		if (new_block == MAP_FAILED)
69 			goto out;
70 
71 		new_cap = grnd_allocator.cap + num;
72 		new_states = reallocarray(grnd_allocator.states, new_cap, sizeof(*grnd_allocator.states));
73 		if (!new_states)
74 			goto unmap;
75 		grnd_allocator.cap = new_cap;
76 		grnd_allocator.states = new_states;
77 
78 		for (size_t i = 0; i < num; ++i) {
79 			if (((uintptr_t)new_block & (page_size - 1)) + grnd_ctx.params.size_of_opaque_state > page_size)
80 				new_block = (void *)(((uintptr_t)new_block + page_size - 1) & (~(page_size - 1)));
81 			grnd_allocator.states[i] = new_block;
82 			new_block += grnd_ctx.params.size_of_opaque_state;
83 		}
84 		grnd_allocator.len = num;
85 		goto success;
86 
87 	unmap:
88 		munmap(new_block, alloc_size);
89 		goto out;
90 	}
91 success:
92 	state = grnd_allocator.states[--grnd_allocator.len];
93 
94 out:
95 	pthread_mutex_unlock(&grnd_allocator.lock);
96 	return state;
97 }
98 
99 static void vgetrandom_put_state(void *state)
100 {
101 	if (!state)
102 		return;
103 	pthread_mutex_lock(&grnd_allocator.lock);
104 	grnd_allocator.states[grnd_allocator.len++] = state;
105 	pthread_mutex_unlock(&grnd_allocator.lock);
106 }
107 
108 static void vgetrandom_init(void)
109 {
110 	if (pthread_key_create(&grnd_ctx.key, vgetrandom_put_state) != 0)
111 		return;
112 	unsigned long sysinfo_ehdr = getauxval(AT_SYSINFO_EHDR);
113 	if (!sysinfo_ehdr) {
114 		printf("AT_SYSINFO_EHDR is not present!\n");
115 		exit(KSFT_SKIP);
116 	}
117 	vdso_init_from_sysinfo_ehdr(sysinfo_ehdr);
118 	grnd_ctx.fn = (__typeof__(grnd_ctx.fn))vdso_sym("LINUX_2.6", "__vdso_getrandom");
119 	if (!grnd_ctx.fn) {
120 		printf("__vdso_getrandom is missing!\n");
121 		exit(KSFT_FAIL);
122 	}
123 	if (grnd_ctx.fn(NULL, 0, 0, &grnd_ctx.params, ~0UL) != 0) {
124 		printf("failed to fetch vgetrandom params!\n");
125 		exit(KSFT_FAIL);
126 	}
127 }
128 
129 static ssize_t vgetrandom(void *buf, size_t len, unsigned long flags)
130 {
131 	void *state;
132 
133 	pthread_once(&grnd_ctx.initialized, vgetrandom_init);
134 	state = pthread_getspecific(grnd_ctx.key);
135 	if (!state) {
136 		state = vgetrandom_get_state();
137 		if (pthread_setspecific(grnd_ctx.key, state) != 0) {
138 			vgetrandom_put_state(state);
139 			state = NULL;
140 		}
141 		if (!state) {
142 			printf("vgetrandom_get_state failed!\n");
143 			exit(KSFT_FAIL);
144 		}
145 	}
146 	return grnd_ctx.fn(buf, len, flags, state, grnd_ctx.params.size_of_opaque_state);
147 }
148 
/*
 * Benchmark sizing: TRIALS is the number of calls each test_*_getrandom()
 * worker makes, THREADS the number of workers bench_multi() starts per
 * variant.
 */
enum { TRIALS = 25000000, THREADS = 256 };
150 
151 static void *test_vdso_getrandom(void *)
152 {
153 	for (size_t i = 0; i < TRIALS; ++i) {
154 		unsigned int val;
155 		ssize_t ret = vgetrandom(&val, sizeof(val), 0);
156 		assert(ret == sizeof(val));
157 	}
158 	return NULL;
159 }
160 
161 static void *test_libc_getrandom(void *)
162 {
163 	for (size_t i = 0; i < TRIALS; ++i) {
164 		unsigned int val;
165 		ssize_t ret = getrandom(&val, sizeof(val), 0);
166 		assert(ret == sizeof(val));
167 	}
168 	return NULL;
169 }
170 
171 static void *test_syscall_getrandom(void *)
172 {
173 	for (size_t i = 0; i < TRIALS; ++i) {
174 		unsigned int val;
175 		ssize_t ret = syscall(__NR_getrandom, &val, sizeof(val), 0);
176 		assert(ret == sizeof(val));
177 	}
178 	return NULL;
179 }
180 
181 static void bench_single(void)
182 {
183 	struct timespec start, end, diff;
184 
185 	clock_gettime(CLOCK_MONOTONIC, &start);
186 	test_vdso_getrandom(NULL);
187 	clock_gettime(CLOCK_MONOTONIC, &end);
188 	timespecsub(&end, &start, &diff);
189 	printf("   vdso: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
190 
191 	clock_gettime(CLOCK_MONOTONIC, &start);
192 	test_libc_getrandom(NULL);
193 	clock_gettime(CLOCK_MONOTONIC, &end);
194 	timespecsub(&end, &start, &diff);
195 	printf("   libc: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
196 
197 	clock_gettime(CLOCK_MONOTONIC, &start);
198 	test_syscall_getrandom(NULL);
199 	clock_gettime(CLOCK_MONOTONIC, &end);
200 	timespecsub(&end, &start, &diff);
201 	printf("syscall: %u times in %lu.%09lu seconds\n", TRIALS, diff.tv_sec, diff.tv_nsec);
202 }
203 
204 static void bench_multi(void)
205 {
206 	struct timespec start, end, diff;
207 	pthread_t threads[THREADS];
208 
209 	clock_gettime(CLOCK_MONOTONIC, &start);
210 	for (size_t i = 0; i < THREADS; ++i)
211 		assert(pthread_create(&threads[i], NULL, test_vdso_getrandom, NULL) == 0);
212 	for (size_t i = 0; i < THREADS; ++i)
213 		pthread_join(threads[i], NULL);
214 	clock_gettime(CLOCK_MONOTONIC, &end);
215 	timespecsub(&end, &start, &diff);
216 	printf("   vdso: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
217 
218 	clock_gettime(CLOCK_MONOTONIC, &start);
219 	for (size_t i = 0; i < THREADS; ++i)
220 		assert(pthread_create(&threads[i], NULL, test_libc_getrandom, NULL) == 0);
221 	for (size_t i = 0; i < THREADS; ++i)
222 		pthread_join(threads[i], NULL);
223 	clock_gettime(CLOCK_MONOTONIC, &end);
224 	timespecsub(&end, &start, &diff);
225 	printf("   libc: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
226 
227 	clock_gettime(CLOCK_MONOTONIC, &start);
228 	for (size_t i = 0; i < THREADS; ++i)
229 		assert(pthread_create(&threads[i], NULL, test_syscall_getrandom, NULL) == 0);
230 	for (size_t i = 0; i < THREADS; ++i)
231 		pthread_join(threads[i], NULL);
232 	clock_gettime(CLOCK_MONOTONIC, &end);
233 	timespecsub(&end, &start, &diff);
234 	printf("   syscall: %u x %u times in %lu.%09lu seconds\n", TRIALS, THREADS, diff.tv_sec, diff.tv_nsec);
235 }
236 
/*
 * Request an odd-sized (323929-byte) buffer from vgetrandom() over and over.
 * The loop has no exit condition, so this function never returns.
 */
static void fill(void)
{
	uint8_t scratch[323929];

	while (1)
		vgetrandom(scratch, sizeof(scratch), 0);
}
243 
244 static void kselftest(void)
245 {
246 	uint8_t weird_size[1263];
247 
248 	ksft_print_header();
249 	ksft_set_plan(1);
250 
251 	for (size_t i = 0; i < 1000; ++i) {
252 		ssize_t ret = vgetrandom(weird_size, sizeof(weird_size), 0);
253 		if (ret != sizeof(weird_size))
254 			exit(KSFT_FAIL);
255 	}
256 
257 	ksft_test_result_pass("getrandom: PASS\n");
258 	exit(KSFT_PASS);
259 }
260 
/*
 * Write the one-line command synopsis to stderr, using @argv0 as the
 * program name.
 */
static void usage(const char *argv0)
{
	FILE *out = stderr;

	fprintf(out, "Usage: %s [bench-single|bench-multi|fill]\n", argv0);
}
265 
266 int main(int argc, char *argv[])
267 {
268 	if (argc == 1) {
269 		kselftest();
270 		return 0;
271 	}
272 
273 	if (argc != 2) {
274 		usage(argv[0]);
275 		return 1;
276 	}
277 	if (!strcmp(argv[1], "bench-single"))
278 		bench_single();
279 	else if (!strcmp(argv[1], "bench-multi"))
280 		bench_multi();
281 	else if (!strcmp(argv[1], "fill"))
282 		fill();
283 	else {
284 		usage(argv[0]);
285 		return 1;
286 	}
287 	return 0;
288 }
289